added pali inference
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- big_vision/__init__.py +0 -0
- big_vision/__pycache__/__init__.cpython-310.pyc +0 -0
- big_vision/__pycache__/sharding.cpython-310.pyc +0 -0
- big_vision/__pycache__/utils.cpython-310.pyc +0 -0
- big_vision/configs/__init__.py +0 -0
- big_vision/configs/bit_i1k.py +102 -0
- big_vision/configs/bit_i21k.py +85 -0
- big_vision/configs/common.py +188 -0
- big_vision/configs/common_fewshot.py +56 -0
- big_vision/configs/load_and_eval.py +143 -0
- big_vision/configs/mlp_mixer_i1k.py +120 -0
- big_vision/configs/proj/cappa/README.md +37 -0
- big_vision/configs/proj/cappa/cappa_architecture.png +0 -0
- big_vision/configs/proj/cappa/pretrain.py +140 -0
- big_vision/configs/proj/clippo/README.md +85 -0
- big_vision/configs/proj/clippo/clippo_colab.ipynb +0 -0
- big_vision/configs/proj/clippo/train_clippo.py +199 -0
- big_vision/configs/proj/distill/README.md +43 -0
- big_vision/configs/proj/distill/bigsweep_flowers_pet.py +164 -0
- big_vision/configs/proj/distill/bigsweep_food_sun.py +213 -0
- big_vision/configs/proj/distill/bit_i1k.py +167 -0
- big_vision/configs/proj/distill/common.py +27 -0
- big_vision/configs/proj/flexivit/README.md +64 -0
- big_vision/configs/proj/flexivit/i1k_deit3_distill.py +187 -0
- big_vision/configs/proj/flexivit/i21k_distill.py +216 -0
- big_vision/configs/proj/flexivit/i21k_sup.py +144 -0
- big_vision/configs/proj/flexivit/timing.py +53 -0
- big_vision/configs/proj/givt/README.md +111 -0
- big_vision/configs/proj/givt/givt_coco_panoptic.py +186 -0
- big_vision/configs/proj/givt/givt_demo_colab.ipynb +309 -0
- big_vision/configs/proj/givt/givt_imagenet2012.py +222 -0
- big_vision/configs/proj/givt/givt_nyu_depth.py +198 -0
- big_vision/configs/proj/givt/givt_overview.png +0 -0
- big_vision/configs/proj/givt/vae_coco_panoptic.py +136 -0
- big_vision/configs/proj/givt/vae_nyu_depth.py +158 -0
- big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py +134 -0
- big_vision/configs/proj/image_text/README.md +65 -0
- big_vision/configs/proj/image_text/SigLIP_demo.ipynb +0 -0
- big_vision/configs/proj/image_text/common.py +127 -0
- big_vision/configs/proj/image_text/lit.ipynb +0 -0
- big_vision/configs/proj/image_text/siglip_lit_coco.py +115 -0
- big_vision/configs/proj/paligemma/README.md +270 -0
- big_vision/configs/proj/paligemma/finetune_paligemma.ipynb +0 -0
- big_vision/configs/proj/paligemma/paligemma.png +0 -0
- big_vision/configs/proj/paligemma/transfers/activitynet_cap.py +209 -0
- big_vision/configs/proj/paligemma/transfers/activitynet_qa.py +213 -0
- big_vision/configs/proj/paligemma/transfers/ai2d.py +170 -0
- big_vision/configs/proj/paligemma/transfers/aokvqa_da.py +161 -0
- big_vision/configs/proj/paligemma/transfers/aokvqa_mc.py +169 -0
- big_vision/configs/proj/paligemma/transfers/chartqa.py +180 -0
big_vision/__init__.py
ADDED
File without changes
|
big_vision/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes). View file
|
|
big_vision/__pycache__/sharding.cpython-310.pyc
ADDED
Binary file (6.31 kB). View file
|
|
big_vision/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (51 kB). View file
|
|
big_vision/configs/__init__.py
ADDED
File without changes
|
big_vision/configs/bit_i1k.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370
|
17 |
+
|
18 |
+
Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128:
|
19 |
+
|
20 |
+
big_vision.train \
|
21 |
+
--config big_vision/configs/bit_i1k.py \
|
22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
23 |
+
--config.model.depth 50 --config.model.width 1
|
24 |
+
"""
|
25 |
+
|
26 |
+
# from big_vision.configs.common_fewshot import get_fewshot_lsr
|
27 |
+
import ml_collections as mlc
|
28 |
+
|
29 |
+
|
30 |
+
def get_config(runlocal=False):
|
31 |
+
"""Config for training on ImageNet-1k."""
|
32 |
+
config = mlc.ConfigDict()
|
33 |
+
|
34 |
+
config.seed = 0
|
35 |
+
config.total_epochs = 90
|
36 |
+
config.num_classes = 1000
|
37 |
+
config.loss = 'softmax_xent'
|
38 |
+
|
39 |
+
config.input = dict()
|
40 |
+
config.input.data = dict(
|
41 |
+
name='imagenet2012',
|
42 |
+
split='train[:99%]',
|
43 |
+
)
|
44 |
+
config.input.batch_size = 4096
|
45 |
+
config.input.cache_raw = True # Needs up to 120GB of RAM!
|
46 |
+
config.input.shuffle_buffer_size = 250_000 # Per host.
|
47 |
+
|
48 |
+
pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
|
49 |
+
pp_common += '|value_range(-1, 1)|keep("image", "labels")'
|
50 |
+
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
|
51 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
52 |
+
|
53 |
+
config.log_training_steps = 50
|
54 |
+
config.ckpt_steps = 1000
|
55 |
+
|
56 |
+
# Model section
|
57 |
+
config.model_name = 'bit'
|
58 |
+
config.model = dict(
|
59 |
+
depth=50, # You can also pass e.g. [3, 5, 10, 2]
|
60 |
+
width=1.0,
|
61 |
+
)
|
62 |
+
|
63 |
+
# Optimizer section
|
64 |
+
config.optax_name = 'big_vision.momentum_hp'
|
65 |
+
config.grad_clip_norm = 1.0
|
66 |
+
|
67 |
+
# linear scaling rule. Don't forget to sweep if sweeping batch_size.
|
68 |
+
config.wd = (1e-4 / 256) * config.input.batch_size
|
69 |
+
config.lr = (0.1 / 256) * config.input.batch_size
|
70 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=1000)
|
71 |
+
|
72 |
+
# Eval section
|
73 |
+
def get_eval(split, dataset='imagenet2012'):
|
74 |
+
return dict(
|
75 |
+
type='classification',
|
76 |
+
data=dict(name=dataset, split=split),
|
77 |
+
pp_fn=pp_eval.format(lbl='label'),
|
78 |
+
loss_name=config.loss,
|
79 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
80 |
+
cache='final_data',
|
81 |
+
)
|
82 |
+
config.evals = {}
|
83 |
+
config.evals.train = get_eval('train[:2%]')
|
84 |
+
config.evals.minival = get_eval('train[99%:]')
|
85 |
+
config.evals.val = get_eval('validation')
|
86 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
87 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
88 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
89 |
+
|
90 |
+
# config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
|
91 |
+
# config.evals.fewshot.log_steps = 1000
|
92 |
+
|
93 |
+
if runlocal:
|
94 |
+
config.input.batch_size = 32
|
95 |
+
config.input.cache_raw = False
|
96 |
+
config.input.shuffle_buffer_size = 100
|
97 |
+
|
98 |
+
local_eval = config.evals.val
|
99 |
+
config.evals = {'val': local_eval}
|
100 |
+
config.evals.val.cache = 'none'
|
101 |
+
|
102 |
+
return config
|
big_vision/configs/bit_i21k.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""A config for pre-training BiT on ImageNet-21k.
|
17 |
+
|
18 |
+
This config relies on the Imagenet-21k tfds dataset, which is not yet
|
19 |
+
available publicly in TFDS. We intend to add the dataset to public TFDS soon,
|
20 |
+
and this config will then be runnable.
|
21 |
+
"""
|
22 |
+
|
23 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
24 |
+
import ml_collections as mlc
|
25 |
+
|
26 |
+
|
27 |
+
def get_config():
|
28 |
+
"""Config for training on imagenet-21k."""
|
29 |
+
config = mlc.ConfigDict()
|
30 |
+
|
31 |
+
config.seed = 0
|
32 |
+
config.total_epochs = 90
|
33 |
+
config.num_classes = 21843
|
34 |
+
config.init_head_bias = -10.0
|
35 |
+
config.loss = 'sigmoid_xent'
|
36 |
+
|
37 |
+
config.input = dict()
|
38 |
+
config.input.data = dict(
|
39 |
+
name='imagenet21k',
|
40 |
+
split='full[51200:]',
|
41 |
+
)
|
42 |
+
config.input.batch_size = 4096
|
43 |
+
config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
|
44 |
+
|
45 |
+
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
|
46 |
+
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
|
47 |
+
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
|
48 |
+
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
|
49 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)'
|
50 |
+
|
51 |
+
config.log_training_steps = 50
|
52 |
+
config.ckpt_steps = 1000
|
53 |
+
|
54 |
+
# Model section
|
55 |
+
config.model_name = 'bit_paper'
|
56 |
+
config.model = dict(depth=50, width=1.0)
|
57 |
+
|
58 |
+
# Optimizer section
|
59 |
+
config.optax_name = 'big_vision.momentum_hp'
|
60 |
+
config.grad_clip_norm = 1.0
|
61 |
+
|
62 |
+
# linear scaling rule. Don't forget to sweep if sweeping batch_size.
|
63 |
+
config.lr = (0.03 / 256) * config.input.batch_size
|
64 |
+
config.wd = (3e-5 / 256) * config.input.batch_size
|
65 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=5000)
|
66 |
+
|
67 |
+
# Evaluations on i21k itself.
|
68 |
+
def eval_i21k(split):
|
69 |
+
return dict(
|
70 |
+
type='classification',
|
71 |
+
data={**config.input.data, 'split': split},
|
72 |
+
pp_fn=pp_eval + pp_common_i21k,
|
73 |
+
loss_name=config.loss,
|
74 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
75 |
+
)
|
76 |
+
config.evals = {}
|
77 |
+
config.evals.test = eval_i21k('full[:25_600]')
|
78 |
+
config.evals.val = eval_i21k('full[25_600:51_200]')
|
79 |
+
config.evals.train = eval_i21k('full[51_200:76_800]')
|
80 |
+
|
81 |
+
# Few-shot evaluators
|
82 |
+
config.evals.fewshot = get_fewshot_lsr()
|
83 |
+
config.evals.fewshot.log_steps = 25_000
|
84 |
+
|
85 |
+
return config
|
big_vision/configs/common.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""A few things commonly used across A LOT of config files."""
|
16 |
+
|
17 |
+
import string
|
18 |
+
|
19 |
+
import ml_collections as mlc
|
20 |
+
|
21 |
+
|
22 |
+
def input_for_quicktest(config_input, quicktest):
|
23 |
+
if quicktest:
|
24 |
+
config_input.batch_size = 8
|
25 |
+
config_input.shuffle_buffer_size = 10
|
26 |
+
config_input.cache_raw = False
|
27 |
+
|
28 |
+
|
29 |
+
def parse_arg(arg, lazy=False, **spec):
|
30 |
+
"""Makes ConfigDict's get_config single-string argument more usable.
|
31 |
+
|
32 |
+
Example use in the config file:
|
33 |
+
|
34 |
+
import big_vision.configs.common as bvcc
|
35 |
+
def get_config(arg):
|
36 |
+
arg = bvcc.parse_arg(arg,
|
37 |
+
res=(224, int),
|
38 |
+
runlocal=False,
|
39 |
+
schedule='short',
|
40 |
+
)
|
41 |
+
|
42 |
+
# ...
|
43 |
+
|
44 |
+
config.shuffle_buffer = 250_000 if not arg.runlocal else 50
|
45 |
+
|
46 |
+
Ways that values can be passed when launching:
|
47 |
+
|
48 |
+
--config amazing.py:runlocal,schedule=long,res=128
|
49 |
+
--config amazing.py:res=128
|
50 |
+
--config amazing.py:runlocal # A boolean needs no value for "true".
|
51 |
+
--config amazing.py:runlocal=False # Explicit false boolean.
|
52 |
+
--config amazing.py:128 # The first spec entry may be passed unnamed alone.
|
53 |
+
|
54 |
+
Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
|
55 |
+
'false', '' to False).
|
56 |
+
|
57 |
+
Args:
|
58 |
+
arg: the string argument that's passed to get_config.
|
59 |
+
lazy: allow lazy parsing of arguments, which are not in spec. For these,
|
60 |
+
the type is auto-extracted in dependence of most complex possible type.
|
61 |
+
**spec: the name and default values of the expected options.
|
62 |
+
If the value is a tuple, the value's first element is the default value,
|
63 |
+
and the second element is a function called to convert the string.
|
64 |
+
Otherwise the type is automatically extracted from the default value.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
ConfigDict object with extracted type-converted values.
|
68 |
+
"""
|
69 |
+
# Normalize arg and spec layout.
|
70 |
+
arg = arg or '' # Normalize None to empty string
|
71 |
+
spec = {k: get_type_with_default(v) for k, v in spec.items()}
|
72 |
+
|
73 |
+
result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only.
|
74 |
+
|
75 |
+
# Expand convenience-cases for a single parameter without = sign.
|
76 |
+
if arg and ',' not in arg and '=' not in arg:
|
77 |
+
# (think :runlocal) If it's the name of sth in the spec (or there is no
|
78 |
+
# spec), it's that in bool.
|
79 |
+
if arg in spec or not spec:
|
80 |
+
arg = f'{arg}=True'
|
81 |
+
# Otherwise, it is the value for the first entry in the spec.
|
82 |
+
else:
|
83 |
+
arg = f'{list(spec.keys())[0]}={arg}'
|
84 |
+
# Yes, we rely on Py3.7 insertion order!
|
85 |
+
|
86 |
+
# Now, expand the `arg` string into a dict of keys and values:
|
87 |
+
raw_kv = {raw_arg.split('=')[0]:
|
88 |
+
raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
|
89 |
+
for raw_arg in arg.split(',') if raw_arg}
|
90 |
+
|
91 |
+
# And go through the spec, using provided or default value for each:
|
92 |
+
for name, (default, type_fn) in spec.items():
|
93 |
+
val = raw_kv.pop(name, None)
|
94 |
+
result[name] = type_fn(val) if val is not None else default
|
95 |
+
|
96 |
+
if raw_kv:
|
97 |
+
if lazy: # Process args which are not in spec.
|
98 |
+
for k, v in raw_kv.items():
|
99 |
+
result[k] = autotype(v)
|
100 |
+
else:
|
101 |
+
raise ValueError(f'Unhandled config args remain: {raw_kv}')
|
102 |
+
|
103 |
+
return result
|
104 |
+
|
105 |
+
|
106 |
+
def get_type_with_default(v):
|
107 |
+
"""Returns (v, string_to_v_type) with lenient bool parsing."""
|
108 |
+
# For bool, do safe string conversion.
|
109 |
+
if isinstance(v, bool):
|
110 |
+
def strict_bool(x):
|
111 |
+
assert x.lower() in {'true', 'false', ''}
|
112 |
+
return x.lower() == 'true'
|
113 |
+
return (v, strict_bool)
|
114 |
+
# If already a (default, type) tuple, use that.
|
115 |
+
if isinstance(v, (tuple, list)):
|
116 |
+
assert len(v) == 2 and isinstance(v[1], type), (
|
117 |
+
'List or tuple types are currently not supported because we use `,` as'
|
118 |
+
' dumb delimiter. Contributions (probably using ast) welcome. You can'
|
119 |
+
' unblock by using a string with eval(s.replace(";", ",")) or similar')
|
120 |
+
return (v[0], v[1])
|
121 |
+
# Otherwise, derive the type from the default value.
|
122 |
+
return (v, type(v))
|
123 |
+
|
124 |
+
|
125 |
+
def autotype(x):
|
126 |
+
"""Auto-converts string to bool/int/float if possible."""
|
127 |
+
assert isinstance(x, str)
|
128 |
+
if x.lower() in {'true', 'false'}:
|
129 |
+
return x.lower() == 'true' # Returns as bool.
|
130 |
+
try:
|
131 |
+
return int(x) # Returns as int.
|
132 |
+
except ValueError:
|
133 |
+
try:
|
134 |
+
return float(x) # Returns as float.
|
135 |
+
except ValueError:
|
136 |
+
return x # Returns as str.
|
137 |
+
|
138 |
+
|
139 |
+
def pack_arg(**kw):
|
140 |
+
"""Packs key-word args as a string to be parsed by `parse_arg()`."""
|
141 |
+
for v in kw.values():
|
142 |
+
assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
|
143 |
+
return ','.join([f'{k}={v}' for k, v in kw.items()])
|
144 |
+
|
145 |
+
|
146 |
+
def arg(**kw):
|
147 |
+
"""Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
|
148 |
+
return {'config_arg': pack_arg(**kw), **kw}
|
149 |
+
|
150 |
+
|
151 |
+
def _get_field_ref(config_dict, field_name):
|
152 |
+
path = field_name.split('.')
|
153 |
+
for field in path[:-1]:
|
154 |
+
config_dict = getattr(config_dict, field)
|
155 |
+
return config_dict.get_ref(path[-1])
|
156 |
+
|
157 |
+
|
158 |
+
def format_str(format_string, config):
|
159 |
+
"""Format string with reference fields from config.
|
160 |
+
|
161 |
+
This makes it easy to build preprocess strings that contain references to
|
162 |
+
fields tha are edited after. E.g.:
|
163 |
+
|
164 |
+
```
|
165 |
+
config = mlc.ConficDict()
|
166 |
+
config.res = (256, 256)
|
167 |
+
config.pp = bvcc.format_str('resize({res})', config)
|
168 |
+
...
|
169 |
+
# if config.res is modified (e.g. via sweeps) it will propagate to pp field:
|
170 |
+
config.res = (512, 512)
|
171 |
+
assert config.pp == 'resize((512, 512))'
|
172 |
+
```
|
173 |
+
|
174 |
+
Args:
|
175 |
+
format_string: string to format with references.
|
176 |
+
config: ConfigDict to get references to format the string.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
A reference field which renders a string using references to config fields.
|
180 |
+
"""
|
181 |
+
output = ''
|
182 |
+
parts = string.Formatter().parse(format_string)
|
183 |
+
for (literal_text, field_name, format_spec, conversion) in parts:
|
184 |
+
assert not format_spec and not conversion
|
185 |
+
output += literal_text
|
186 |
+
if field_name:
|
187 |
+
output += _get_field_ref(config, field_name).to_str()
|
188 |
+
return output
|
big_vision/configs/common_fewshot.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Most common few-shot eval configuration."""
|
16 |
+
|
17 |
+
import ml_collections as mlc
|
18 |
+
|
19 |
+
|
20 |
+
def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
|
21 |
+
runlocal=False, **kw):
|
22 |
+
"""Returns a standard-ish fewshot eval configuration."""
|
23 |
+
kw.setdefault('representation_layer', 'pre_logits')
|
24 |
+
kw.setdefault('shots', (1, 5, 10, 25))
|
25 |
+
kw.setdefault('l2_reg', 2.0 ** 10)
|
26 |
+
kw.setdefault('num_seeds', 3)
|
27 |
+
kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/
|
28 |
+
|
29 |
+
# Backward-compatible default:
|
30 |
+
if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long
|
31 |
+
kw['log_steps'] = 25_000
|
32 |
+
|
33 |
+
config = mlc.ConfigDict(kw)
|
34 |
+
config.type = 'fewshot_lsr'
|
35 |
+
config.datasets = {
|
36 |
+
'caltech': ('caltech101', 'train', 'test'), # copybara:srtip
|
37 |
+
'cars': ('cars196:2.1.0', 'train', 'test'),
|
38 |
+
'cifar100': ('cifar100', 'train', 'test'),
|
39 |
+
'dtd': ('dtd', 'train', 'test'),
|
40 |
+
# The first 65000 ImageNet samples have at least 30 shots per any class.
|
41 |
+
# Commented out by default because needs manual download.
|
42 |
+
# 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
|
43 |
+
'pets': ('oxford_iiit_pet', 'train', 'test'),
|
44 |
+
'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
|
45 |
+
} if not runlocal else {
|
46 |
+
'pets': ('oxford_iiit_pet', 'train', 'test'),
|
47 |
+
}
|
48 |
+
config.pp_train = (f'decode|resize({resize_resolution})|'
|
49 |
+
f'central_crop({target_resolution})|'
|
50 |
+
f'value_range(-1,1)|keep("image", "label")')
|
51 |
+
config.pp_eval = (f'decode|resize({resize_resolution})|'
|
52 |
+
f'central_crop({target_resolution})|'
|
53 |
+
f'value_range(-1,1)|keep("image", "label")')
|
54 |
+
config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
|
55 |
+
|
56 |
+
return config
|
big_vision/configs/load_and_eval.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pytype: disable=not-writable,attribute-error
|
16 |
+
# pylint: disable=line-too-long,missing-function-docstring
|
17 |
+
r"""A config to load and eval key model using the core train.py.
|
18 |
+
|
19 |
+
The runtime varies widely depending on the model, but each one should reproduce
|
20 |
+
the corresponding paper's numbers.
|
21 |
+
This configuration makes use of the "arg" to get_config to select which model
|
22 |
+
to run, so a few examples are given below:
|
23 |
+
|
24 |
+
Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k:
|
25 |
+
|
26 |
+
big_vision.train \
|
27 |
+
--config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \
|
28 |
+
--config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50
|
29 |
+
|
30 |
+
Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper:
|
31 |
+
|
32 |
+
big_vision.train \
|
33 |
+
--config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \
|
34 |
+
--config.model.variant B/32 --config.model_init howto-i21k-B/32
|
35 |
+
"""
|
36 |
+
|
37 |
+
import big_vision.configs.common as bvcc
|
38 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
39 |
+
|
40 |
+
|
41 |
+
def eval_only(config, batch_size, spec_for_init):
|
42 |
+
"""Set a few configs that turn trainer into (almost) eval-only."""
|
43 |
+
config.total_steps = 0
|
44 |
+
config.input = {}
|
45 |
+
config.input.batch_size = batch_size
|
46 |
+
config.input.data = dict(name='bv:dummy', spec=spec_for_init)
|
47 |
+
config.optax_name = 'identity'
|
48 |
+
config.lr = 0.0
|
49 |
+
|
50 |
+
config.mesh = [('data', -1)]
|
51 |
+
config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
|
52 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
53 |
+
|
54 |
+
return config
|
55 |
+
|
56 |
+
|
57 |
+
def get_config(arg=''):
|
58 |
+
config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)
|
59 |
+
|
60 |
+
# Make the config eval-only by setting some dummies.
|
61 |
+
eval_only(config, config.batch_size, spec_for_init=dict(
|
62 |
+
image=dict(shape=(224, 224, 3), dtype='float32'),
|
63 |
+
))
|
64 |
+
|
65 |
+
config.evals = dict(fewshot=get_fewshot_lsr())
|
66 |
+
|
67 |
+
# Just calls the function with the name given as `config`.
|
68 |
+
# Could also be a giant if-block if you're into that kind of thing.
|
69 |
+
globals()[config.name](config)
|
70 |
+
return config
|
71 |
+
|
72 |
+
|
73 |
+
def bit_paper(config):
|
74 |
+
config.num_classes = 1000
|
75 |
+
|
76 |
+
config.model_name = 'bit_paper'
|
77 |
+
config.model_init = 'M-imagenet2012' # M = i21k, -imagenet2012 = fine-tuned
|
78 |
+
config.model = dict(width=1, depth=50)
|
79 |
+
|
80 |
+
def get_eval(split, lbl, dataset='imagenet2012_real'):
|
81 |
+
return dict(
|
82 |
+
type='classification',
|
83 |
+
data=dict(name=dataset, split=split),
|
84 |
+
loss_name='softmax_xent',
|
85 |
+
cache='none', # Only run once, on low-mem machine.
|
86 |
+
pp_fn=(
|
87 |
+
'decode|resize(384)|value_range(-1, 1)'
|
88 |
+
f'|onehot(1000, key="{lbl}", key_result="labels")'
|
89 |
+
'|keep("image", "labels")'
|
90 |
+
),
|
91 |
+
)
|
92 |
+
config.evals.test = get_eval('validation', 'original_label')
|
93 |
+
config.evals.real = get_eval('validation', 'real_label')
|
94 |
+
config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
|
95 |
+
|
96 |
+
|
97 |
+
def vit_i1k(config):
|
98 |
+
config.num_classes = 1000
|
99 |
+
|
100 |
+
config.model_name = 'vit'
|
101 |
+
config.model_init = '' # Will be set in sweep.
|
102 |
+
config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
|
103 |
+
rep_size=True)
|
104 |
+
|
105 |
+
config.evals.val = dict(
|
106 |
+
type='classification',
|
107 |
+
data=dict(name='imagenet2012', split='validation'),
|
108 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
|
109 |
+
loss_name='softmax_xent',
|
110 |
+
cache='none', # Only run once, on low-mem machine.
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def mlp_mixer_i1k(config):
|
115 |
+
config.num_classes = 1000
|
116 |
+
|
117 |
+
config.model_name = 'mlp_mixer'
|
118 |
+
config.model_init = '' # Will be set in sweep.
|
119 |
+
config.model = dict(variant='L/16')
|
120 |
+
|
121 |
+
config.evals.val = dict(
|
122 |
+
type='classification',
|
123 |
+
data=dict(name='imagenet2012', split='validation'),
|
124 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
|
125 |
+
loss_name='softmax_xent',
|
126 |
+
cache='none', # Only run once, on low-mem machine.
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
def vit_i21k(config):
|
131 |
+
config.num_classes = 21843
|
132 |
+
|
133 |
+
config.model_name = 'vit'
|
134 |
+
config.model_init = '' # Will be set in sweep.
|
135 |
+
config.model = dict(variant='B/32', pool_type='tok')
|
136 |
+
|
137 |
+
config.evals.val = dict(
|
138 |
+
type='classification',
|
139 |
+
data=dict(name='imagenet21k', split='full[:51200]'),
|
140 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")',
|
141 |
+
loss_name='sigmoid_xent',
|
142 |
+
cache='none', # Only run once, on low-mem machine.
|
143 |
+
)
|
big_vision/configs/mlp_mixer_i1k.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k").
|
17 |
+
|
18 |
+
Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128
|
19 |
+
with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m.
|
20 |
+
|
21 |
+
big_vision.train \
|
22 |
+
--config big_vision/configs/mlp_mixer_i1k.py \
|
23 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
24 |
+
"""
|
25 |
+
|
26 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
27 |
+
import ml_collections as mlc
|
28 |
+
|
29 |
+
|
30 |
+
def get_config(mode=None):
|
31 |
+
"""Config for training Mixer on i1k."""
|
32 |
+
config = mlc.ConfigDict()
|
33 |
+
|
34 |
+
config.seed = 0
|
35 |
+
config.total_epochs = 300
|
36 |
+
config.num_classes = 1000
|
37 |
+
config.loss = 'sigmoid_xent'
|
38 |
+
config.init_head_bias = -6.9
|
39 |
+
|
40 |
+
config.input = dict()
|
41 |
+
config.input.data = dict(
|
42 |
+
name='imagenet2012',
|
43 |
+
split='train[:99%]',
|
44 |
+
)
|
45 |
+
config.input.batch_size = 4096
|
46 |
+
config.input.cache_raw = True # Needs up to 120GB of RAM!
|
47 |
+
config.input.shuffle_buffer_size = 250_000
|
48 |
+
|
49 |
+
config.input.pp = (
|
50 |
+
'decode_jpeg_and_inception_crop(224)'
|
51 |
+
'|flip_lr'
|
52 |
+
'|randaug(2,15)'
|
53 |
+
'|value_range(-1, 1)'
|
54 |
+
'|onehot(1000, key="label", key_result="labels")'
|
55 |
+
'|keep("image", "labels")'
|
56 |
+
)
|
57 |
+
pp_eval = (
|
58 |
+
'decode'
|
59 |
+
'|resize_small(256)|central_crop(224)'
|
60 |
+
'|value_range(-1, 1)'
|
61 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
62 |
+
'|keep("image", "labels")'
|
63 |
+
)
|
64 |
+
|
65 |
+
# To continue using the near-defunct randaug op.
|
66 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
67 |
+
|
68 |
+
config.log_training_steps = 50
|
69 |
+
config.ckpt_steps = 1000
|
70 |
+
|
71 |
+
config.prefetch_to_device = 2
|
72 |
+
|
73 |
+
# Model section
|
74 |
+
config.model_name = 'mlp_mixer'
|
75 |
+
config.model = dict()
|
76 |
+
config.model.variant = 'B/16'
|
77 |
+
config.model.stoch_depth = 0.1
|
78 |
+
|
79 |
+
config.mixup = dict(fold_in=None, p=0.5)
|
80 |
+
|
81 |
+
# Optimizer section
|
82 |
+
config.optax_name = 'scale_by_adam'
|
83 |
+
config.grad_clip_norm = 1.
|
84 |
+
|
85 |
+
config.lr = 0.001
|
86 |
+
config.wd = 1e-4
|
87 |
+
config.schedule = dict(
|
88 |
+
decay_type='linear',
|
89 |
+
warmup_steps=10_000,
|
90 |
+
linear_end=1e-5,
|
91 |
+
)
|
92 |
+
|
93 |
+
# Eval section
|
94 |
+
def get_eval(split, dataset='imagenet2012'):
|
95 |
+
return dict(
|
96 |
+
type='classification',
|
97 |
+
data=dict(name=dataset, split=split),
|
98 |
+
pp_fn=pp_eval.format(lbl='label'),
|
99 |
+
loss_name=config.loss,
|
100 |
+
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
|
101 |
+
cache_final=mode != 'gpu8',
|
102 |
+
)
|
103 |
+
config.evals = {}
|
104 |
+
config.evals.train = get_eval('train[:2%]')
|
105 |
+
config.evals.minival = get_eval('train[99%:]')
|
106 |
+
config.evals.val = get_eval('validation')
|
107 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
108 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
109 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
110 |
+
|
111 |
+
config.fewshot = get_fewshot_lsr()
|
112 |
+
|
113 |
+
if mode == 'gpu8':
|
114 |
+
config.total_epochs = 60
|
115 |
+
config.input.batch_size = 512
|
116 |
+
config.input.cache_raw = False
|
117 |
+
if mode == 'regression_test':
|
118 |
+
config.total_epochs = 60
|
119 |
+
|
120 |
+
return config
|
big_vision/configs/proj/cappa/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Image Captioners Are Scalable Vision Learners Too
|
2 |
+
|
3 |
+
*by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915)
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
This directory contains a config for training a CapPa model from scratch.
|
8 |
+
Note that most models in the paper were trained on a proprietary dataset
|
9 |
+
(WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/).
|
10 |
+
|
11 |
+
By default, this config trains on COCO captions as this data set is readily
|
12 |
+
available in [TFDS](https://www.tensorflow.org/datasets) without manual steps.
|
13 |
+
This is not meant to produce a meaningful model, but
|
14 |
+
provides a way for the user to run the config out of the box. Please update the
|
15 |
+
config with with a TFDS-wrapped variant of your favorite image/text data set to
|
16 |
+
train capable models.
|
17 |
+
|
18 |
+
After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows
|
19 |
+
|
20 |
+
```
|
21 |
+
python -m big_vision.trainers.proj.cappa.generative \
|
22 |
+
--config big_vision/configs/proj/cappa/pretrain.py \
|
23 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
|
24 |
+
```
|
25 |
+
|
26 |
+
To run the Cap baseline (autoregressive captioning without parallel prediction),
|
27 |
+
set `config.model.masked_pred_prob = 0.0`.
|
28 |
+
|
29 |
+
### Citation
|
30 |
+
```
|
31 |
+
@inproceedings{tschannen2023image,
|
32 |
+
title={Image Captioners Are Scalable Vision Learners Too},
|
33 |
+
author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas},
|
34 |
+
booktitle={Neural Information Processing Systems (NeurIPS)},
|
35 |
+
year={2023}
|
36 |
+
}
|
37 |
+
```
|
big_vision/configs/proj/cappa/cappa_architecture.png
ADDED
![]() |
big_vision/configs/proj/cappa/pretrain.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions.
|
17 |
+
|
18 |
+
This config is for reference, we never ran a full training on a large
|
19 |
+
image/text data set on public infrastructure.
|
20 |
+
|
21 |
+
big_vision.trainers.proj.cappa.generative \
|
22 |
+
--config big_vision/configs/proj/cappa/pretrain.py \
|
23 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
from big_vision.configs import common_fewshot
|
28 |
+
import big_vision.configs.common as bvcc
|
29 |
+
import ml_collections
|
30 |
+
|
31 |
+
|
32 |
+
def get_config(arg=None):
|
33 |
+
"""Returns the base config."""
|
34 |
+
config = bvcc.parse_arg(arg,
|
35 |
+
runlocal=False,
|
36 |
+
total_steps=366_500,
|
37 |
+
batch_size=8*1024,
|
38 |
+
warmup_steps=10_000,
|
39 |
+
)
|
40 |
+
|
41 |
+
config.evals = {}
|
42 |
+
config.input = {}
|
43 |
+
config.input.batch_size = config.batch_size if not config.runlocal else 8
|
44 |
+
shuffle_buffer_size = 50_000 if not config.runlocal else 50
|
45 |
+
|
46 |
+
res = 224
|
47 |
+
patch_size = 16
|
48 |
+
max_text_tokens = 64
|
49 |
+
|
50 |
+
pp_image = (f'resize({res})|value_range(-1,1)')
|
51 |
+
|
52 |
+
def tokenizer(inkey, outkey):
|
53 |
+
return (f'tokenize(max_len={max_text_tokens}, model="c4_en", '
|
54 |
+
f'eos="sticky", inkey="{inkey}", outkey="{outkey}")')
|
55 |
+
|
56 |
+
pp_coco = (f'decode|{pp_image}|'
|
57 |
+
'coco_captions("captions")|choice(inkey="captions", outkey="text")|'
|
58 |
+
f'{tokenizer("text", "labels")}|keep("image", "labels")')
|
59 |
+
config.input.pp = pp_coco
|
60 |
+
|
61 |
+
# NOTE: "coco_captions" is way too small a dataset to train on. It's simply
|
62 |
+
# used here to serve as a smoke test that the implementation works correctly.
|
63 |
+
config.input.data = dict(name='coco_captions', split='train') # num_examples=82_783
|
64 |
+
config.input.shuffle_buffer_size = shuffle_buffer_size
|
65 |
+
|
66 |
+
config.evals.val_coco = {
|
67 |
+
'type': 'proj.cappa.perplexity',
|
68 |
+
'pred': 'perplexity',
|
69 |
+
'log_steps': 1000,
|
70 |
+
'data': dict(name='coco_captions', split='val'), # num_examples=5_000
|
71 |
+
'pp_fn': pp_coco,
|
72 |
+
}
|
73 |
+
|
74 |
+
# Few-shot metrics
|
75 |
+
config.evals.fewshot = common_fewshot.get_fewshot_lsr(
|
76 |
+
target_resolution=res, resize_resolution=int(256 / 224 * res))
|
77 |
+
config.evals.fewshot.type = 'fewshot_lsr'
|
78 |
+
config.evals.fewshot.log_steps = 5_000 if not config.runlocal else 5
|
79 |
+
config.evals.fewshot.representation_layer = 'pre_logits'
|
80 |
+
config.evals.fewshot.pred = 'enc_rep'
|
81 |
+
config.evals.fewshot.pp_eval = config.evals.fewshot.pp_train
|
82 |
+
|
83 |
+
# NOTE: Scoring of the entire imagenet validation set is rather slow:
|
84 |
+
# ~100 secs / 1k classes / host.
|
85 |
+
config.evals['imagenet/scoring'] = dict(
|
86 |
+
type='proj.cappa.scoring_classifier',
|
87 |
+
pred='score',
|
88 |
+
log_percent=0.1,
|
89 |
+
data=dict(name='imagenet2012', split='validation'),
|
90 |
+
pp_fn=f'decode|{pp_image}|keep("image", "label")',
|
91 |
+
pp_txt=tokenizer('label', 'labels'),
|
92 |
+
)
|
93 |
+
|
94 |
+
for e in config.evals.values():
|
95 |
+
e.skip_first = True
|
96 |
+
|
97 |
+
config.log_training_steps = 50
|
98 |
+
config.ckpt_steps = 1000
|
99 |
+
config.keep_ckpt_steps = None # 10_000
|
100 |
+
|
101 |
+
# Model section
|
102 |
+
config.model_name = 'proj.cappa.cappa'
|
103 |
+
config.model = ml_collections.ConfigDict()
|
104 |
+
config.model.num_layers = 12
|
105 |
+
config.model.num_heads = 12
|
106 |
+
config.model.mlp_dim = 3072
|
107 |
+
config.model.emb_dim = 768
|
108 |
+
config.model.vocab_size = 32_000
|
109 |
+
config.model.patches = (patch_size, patch_size)
|
110 |
+
config.model.seq_len = max_text_tokens
|
111 |
+
config.model.posemb_type = 'learn'
|
112 |
+
|
113 |
+
# Decoder
|
114 |
+
config.model.decoder_num_layers = 6
|
115 |
+
# 0 values here mean to use the same value as for the encoder
|
116 |
+
config.model.decoder_num_heads = 0
|
117 |
+
config.model.decoder_mlp_dim = 0
|
118 |
+
config.model.decoder_emb_dim = 0
|
119 |
+
config.model.dec_dropout_rate = 0.0
|
120 |
+
config.model.masked_pred_prob = 0.75
|
121 |
+
config.model.masking_ratio = 1.0
|
122 |
+
config.model.decoder_bias = False
|
123 |
+
|
124 |
+
config.optax_name = 'big_vision.scale_by_adafactor'
|
125 |
+
config.optax = dict(beta2_cap=0.999)
|
126 |
+
config.grad_clip_norm = 1.0
|
127 |
+
config.label_smoothing = 0.0
|
128 |
+
|
129 |
+
schedule = dict(decay_type='cosine',
|
130 |
+
warmup_steps=config.warmup_steps
|
131 |
+
if not config.runlocal else 5)
|
132 |
+
|
133 |
+
# Standard schedule
|
134 |
+
config.lr = 0.001
|
135 |
+
config.wd = 0.0001
|
136 |
+
config.schedule = schedule
|
137 |
+
|
138 |
+
config.seed = 0
|
139 |
+
|
140 |
+
return config
|
big_vision/configs/proj/clippo/README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Image-and-Language Understanding from Pixels Only
|
2 |
+
|
3 |
+
*by Michael Tschannen, Basil Mustafa, Neil Houlsby* [[arxiv]](https://arxiv.org/abs/2212.08045) [[colab]](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb)
|
4 |
+
|
5 |
+
We provide pretrained CLIP with Pixels Only (CLIPPO) models and code to train such models on image/alt-text data sets.
|
6 |
+
|
7 |
+
### Pretrained models
|
8 |
+
|
9 |
+
Six ViT-B/16 models trained on a mix of [`YFCC-100M`](https://arxiv.org/abs/1503.01817) and [`C4`](https://arxiv.org/abs/1910.10683) (some initialized with an [ImageNet21k-pretrained checkpoint](https://github.com/google-research/vision_transformer#vision-transformer)\) are available.
|
10 |
+
These models were trained using the schedules and hyperparameters described in the paper. We use the full `YFCC-100M` data set, sampling one of the available `title/description/tag` annotations at random for each each example. We drop non-descriptive annotations (e.g. descriptions consisting of digits only) following the filtering procedure outlined in the [LiT paper](https://arxiv.org/abs/2303.04671), Appendix E. The preprocessing for the `C4` data is as described in the paper.
|
11 |
+
|
12 |
+
The tables below show details about the checkpoints and their performance on Vision & Language benchmarks, and [`GLUE`](https://arxiv.org/abs/1804.07461). We also provide a [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb) to load the models, compute embeddings, and perform zero-shot classification.
|
13 |
+
|
14 |
+
##### Checkpoint details
|
15 |
+
|
16 |
+
| model | training dataset | #param. | steps | checkpoint |
|
17 |
+
|:-----------------|:-------------------|:----------|:--------|:-----------|
|
18 |
+
| CLIPPO | YFCC-100M | 93M | 250k | `gs://big_vision/clippo/clippo_b16_yfcc100m.npz` |
|
19 |
+
| CLIPPO I21k init | YFCC-100M | 93M | 250k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init.npz` |
|
20 |
+
| CLIPPO I21k init | YFCC-100M + 25%C4 | 93M | 333k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_25c4.npz` |
|
21 |
+
| CLIPPO I21k init | YFCC-100M + 50%C4 | 93M | 500k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_50c4.npz` |
|
22 |
+
| CLIPPO I21k init | YFCC-100M + 75%C4 | 93M | 500k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_75c4.npz` |
|
23 |
+
| CLIPPO | C4 | 93M | 250k | `gs://big_vision/clippo/clippo_b16_100c4.npz` |
|
24 |
+
|
25 |
+
##### Vision \& Language results
|
26 |
+
|
27 |
+
| model | training dataset | ImageNet 10-shot | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I |
|
28 |
+
|:-----------------|:-------------------|-----------:|----------:|--------:|--------:|
|
29 |
+
| CLIPPO | YFCC-100M | 38.2 | 43.4 | 34.7 | 19.7 |
|
30 |
+
| CLIPPO I21k init | YFCC-100M | 44.7 | 47.4 | 36.1 | 21.3 |
|
31 |
+
| CLIPPO I21k init | YFCC-100M + 25%C4 | 43.8 | 44.8 | 33.3 | 19.4 |
|
32 |
+
| CLIPPO I21k init | YFCC-100M + 50%C4 | 41.2 | 42.0 | 31.4 | 17.8 |
|
33 |
+
| CLIPPO I21k init | YFCC-100M + 75%C4 | 34.5 | 33.4 | 26.6 | 14.6 |
|
34 |
+
|
35 |
+
##### GLUE results
|
36 |
+
|
37 |
+
| model | training dataset | MNLI-M/MM | QQP | QNLI | SST-2 | COLA | STS-B | MRPC | RTE | avg |
|
38 |
+
|:-----------------|:-------------------|:------------|------:|-------:|--------:|-------:|--------:|-------:|------:|------:|
|
39 |
+
| CLIPPO | YFCC-100M | 71.3 / 71.5 | 79.1 | 67.9 | 85.7 | 0.0 | 14.0 | 83.4 | 54.9 | 58.6 |
|
40 |
+
| CLIPPO I21k init | YFCC-100M | 70.0 / 70.1 | 83.7 | 81.6 | 86.1 | 0.0 | 18.5 | 83.0 | 53.1 | 60.7 |
|
41 |
+
| CLIPPO I21k init | YFCC-100M + 25%C4 | 75.7 / 75.1 | 85.2 | 83.5 | 89.6 | 0.0 | 82.3 | 82.7 | 52.7 | 69.7 |
|
42 |
+
| CLIPPO I21k init | YFCC-100M + 50%C4 | 77.4 / 77.4 | 86.0 | 83.9 | 91.7 | 34.5 | 84.5 | 85.1 | 56.3 | 75.2 |
|
43 |
+
| CLIPPO I21k init | YFCC-100M + 75%C4 | 79.8 / 79.1 | 86.5 | 84.3 | 92.0 | 44.5 | 85.3 | 88.2 | 58.5 | 77.6 |
|
44 |
+
| CLIPPO | C4 | 79.9 / 80.2 | 86.7 | 85.2 | 93.3 | 50.9 | 84.7 | 86.3 | 58.5 | 78.4 |
|
45 |
+
|
46 |
+
### Training your own models
|
47 |
+
|
48 |
+
To train your own CLIPPO model, please follow the setup instructions in the [`big_vision` main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup). In the following, we provide the CLIPPO-specific commands required in addition to the setup, assume you are using the Google Cloud TPU setup (potentially with adapted TPU configuration, see table below). If you are using GPUs, please set up your machine directly and only execute the `--command` portions of the commands below from the `big_vision` repository root.
|
49 |
+
|
50 |
+
The text rendering preproprocessing function requires manual download of the Unifont .hex files from [Unifoundry](https://unifoundry.com/unifont/) (please follow link for license):
|
51 |
+
|
52 |
+
```bash
|
53 |
+
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
|
54 |
+
--command "bash big_vision/pp/proj/clippo/download_unifont.sh"
|
55 |
+
```
|
56 |
+
|
57 |
+
Launch the training by running
|
58 |
+
|
59 |
+
```bash
|
60 |
+
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
|
61 |
+
--command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
|
62 |
+
```
|
63 |
+
|
64 |
+
*Important note:* The input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets) which does not provide automatic integration with large image/alt-text datasets out of the box. The above config therefore trains by default on MS-COCO Captions which can be automatically downloaded via TFDS, and additionally initializes the CLIPPO ViT backbone with weights pretrained on ImageNet21k. This setup is not meant to produce good accuracy, but to provide the user with a way to sanity-check their setup. If you want to train on a large data set such as [`LAION-400M`](https://arxiv.org/abs/2111.02114) or [`YFCC-100M`](https://arxiv.org/abs/1503.01817), please follow [these instructions](https://www.tensorflow.org/datasets/add_dataset) to wrap your data set using TFDS, and update the dataset in the config accordingly. Also note that the ImageNet1k evaluations require manual download of the data, see [these instructions](https://github.com/google-research/big_vision#preparing-tfds-data). To train with your own data set and with ImageNet1k-based evaluations, use `--config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True` in the command above.
|
65 |
+
|
66 |
+
##### Expected results
|
67 |
+
|
68 |
+
| train dataset | batch size | #steps | TPU chips | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` |
|
69 |
+
| :--- | ---: | ---: | ---: | :---: | :---: | :---: | :--- |
|
70 |
+
| *MS-COCO (sanity check)* | 4000 | 400 | 32 v3 | 4.2 | 12.6 | 8.6 | `i1k_eval=True` |
|
71 |
+
| LAION-400M | 8192 | 100k |128 v2 | 51.5 | 44.8 | 29.3 | `test_with_coco=False,i1k_eval=True` |
|
72 |
+
| LAION-400M | 10240\* | 100k | 128 v3 | 53.6 | 46.7 | 30.3 | `test_with_coco=False,i1k_eval=True` |
|
73 |
+
|
74 |
+
\* The experiments in the paper use a batch size of 10240 which requires a memory-optimized ViT implementation to run on 128 TPU v2 chips or 128 TPU v3 chips (in which case the TPU memory capacity allows to increase the batch size beyond 10240).
|
75 |
+
|
76 |
+
### Citation
|
77 |
+
|
78 |
+
```
|
79 |
+
@inproceedings{tschannen2023image,
|
80 |
+
title={Image-and-Language Understanding from Pixels Only},
|
81 |
+
author={Tschannen, Michael and Mustafa, Basil and Houlsby, Neil},
|
82 |
+
booktitle={Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
83 |
+
year={2023}
|
84 |
+
}
|
85 |
+
```
|
big_vision/configs/proj/clippo/clippo_colab.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
big_vision/configs/proj/clippo/train_clippo.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045
|
17 |
+
|
18 |
+
IMPORTANT NOTE: This config uses coco_captions by default for demonstration
|
19 |
+
purposes since the TFDS catalog does not provide any large image/alt-text data
|
20 |
+
set; the training will not produce a model with useful accuracy. Please
|
21 |
+
replace the data set below (marked by a comment) with an appropriate image/
|
22 |
+
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
|
23 |
+
with the suffix `:test_with_coco=False` to train on your data set. Refer to
|
24 |
+
the following guide to build a TFDS wrapper for your favorite image/alt-text
|
25 |
+
data set:
|
26 |
+
https://www.tensorflow.org/datasets/add_dataset
|
27 |
+
|
28 |
+
Also note that evaluation on ImageNet requires manual TFDS setup, see
|
29 |
+
https://github.com/google-research/big_vision#preparing-tfds-data
|
30 |
+
|
31 |
+
|
32 |
+
Example training:
|
33 |
+
|
34 |
+
big_vision.trainers.proj.image_text.contrastive \
|
35 |
+
--config big_vision/configs/proj/clippo/train_clippo.py \
|
36 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
|
37 |
+
|
38 |
+
"""
|
39 |
+
|
40 |
+
import big_vision.configs.common as bvcc
|
41 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
42 |
+
from big_vision.configs.proj.image_text import common
|
43 |
+
from ml_collections import ConfigDict
|
44 |
+
|
45 |
+
|
46 |
+
def get_config(arg=None):
|
47 |
+
"""The base configuration."""
|
48 |
+
arg = bvcc.parse_arg(
|
49 |
+
arg, res=224, runlocal=False, variant='B/16',
|
50 |
+
test_with_coco=True, i1k_eval=False)
|
51 |
+
config = ConfigDict()
|
52 |
+
|
53 |
+
config.input = {}
|
54 |
+
if arg.test_with_coco:
|
55 |
+
# Use COCO Captions for sanity-checking
|
56 |
+
config.input.data = dict(name='coco_captions', split='train')
|
57 |
+
val_data = dict(config.input.data)
|
58 |
+
val_data['split'] = 'val'
|
59 |
+
config.input.batch_size = 4000 if not arg.runlocal else 32
|
60 |
+
config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
|
61 |
+
config.total_steps = 400 if not arg.runlocal else 10
|
62 |
+
else:
|
63 |
+
# Please add your favorite image/alt-text dataset here
|
64 |
+
config.input.data = None
|
65 |
+
val_data = None
|
66 |
+
assert config.input.data is not None and val_data is not None, (
|
67 |
+
config.input.data, val_data)
|
68 |
+
|
69 |
+
# The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
|
70 |
+
# memory optimized ViT implementation when running on 128 TPUv2 cores.
|
71 |
+
config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
|
72 |
+
config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
|
73 |
+
config.total_steps = 100_000 if not arg.runlocal else 10
|
74 |
+
|
75 |
+
def tokenizer(inkey, outkey='labels'):
|
76 |
+
return (f'render_unifont('
|
77 |
+
f'inkey="{inkey}", '
|
78 |
+
f'outkey="{outkey}", '
|
79 |
+
f'image_size={arg.res}, '
|
80 |
+
f'lower=True, '
|
81 |
+
f'font_size=16, '
|
82 |
+
f'text_brightness=0, '
|
83 |
+
f'background_brightness=127)|'
|
84 |
+
f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')
|
85 |
+
|
86 |
+
pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
|
87 |
+
if arg.test_with_coco:
|
88 |
+
# Train with augmentation when sanity-checking
|
89 |
+
pp_image_aug = (
|
90 |
+
f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
|
91 |
+
config.input.pp = pp_eval = (
|
92 |
+
f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
|
93 |
+
f'keep("image", "labels")')
|
94 |
+
else:
|
95 |
+
config.input.pp = pp_eval = (
|
96 |
+
f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')
|
97 |
+
|
98 |
+
config.pp_modules = [
|
99 |
+
'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']
|
100 |
+
|
101 |
+
config.log_training_steps = 50
|
102 |
+
config.ckpt_steps = 1000
|
103 |
+
config.keep_ckpt_steps = 5000
|
104 |
+
|
105 |
+
config.loss_use_global_batch = True
|
106 |
+
|
107 |
+
# Define the model
|
108 |
+
config.model_name = 'proj.clippo.one_tower'
|
109 |
+
|
110 |
+
config.model = ConfigDict()
|
111 |
+
config.model.image_model = 'vit'
|
112 |
+
config.model.image = ConfigDict({
|
113 |
+
'variant': arg.variant,
|
114 |
+
'pool_type': 'map',
|
115 |
+
'head_zeroinit': False,
|
116 |
+
})
|
117 |
+
|
118 |
+
if arg.test_with_coco:
|
119 |
+
# Initialize with ImageNet21k pretrained checkpoint for sanity-checking
|
120 |
+
assert arg.variant == 'B/16', arg.variant
|
121 |
+
config.model_init = {'image': 'howto-i21k-B/16'}
|
122 |
+
config.model_load = {}
|
123 |
+
config.model_load['img_load_kw'] = {
|
124 |
+
'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}
|
125 |
+
|
126 |
+
config.model.temperature_init = 10.0
|
127 |
+
config.model.out_dim = 768
|
128 |
+
|
129 |
+
# Define the optimizer
|
130 |
+
config.optax_name = 'big_vision.scale_by_adafactor'
|
131 |
+
config.grad_clip_norm = 1.0
|
132 |
+
|
133 |
+
if arg.test_with_coco:
|
134 |
+
# Short schedule for sanity-checking
|
135 |
+
config.lr = 0.0001
|
136 |
+
config.wd = 0.0003
|
137 |
+
config.schedule = dict(decay_type='rsqrt',
|
138 |
+
timescale=100,
|
139 |
+
warmup_steps=100 if not arg.runlocal else 5,
|
140 |
+
cooldown_steps=100 if not arg.runlocal else 5)
|
141 |
+
else:
|
142 |
+
config.lr = 0.001
|
143 |
+
config.wd = 0.0001
|
144 |
+
config.schedule = dict(decay_type='rsqrt',
|
145 |
+
timescale=10_000,
|
146 |
+
warmup_steps=10_000 if not arg.runlocal else 5,
|
147 |
+
cooldown_steps=10_000 if not arg.runlocal else 5)
|
148 |
+
|
149 |
+
# Eval section (Both few-shot and zero-shot)
|
150 |
+
eval_common = dict(
|
151 |
+
type='proj.image_text.contrastive',
|
152 |
+
use_global_batch=config.loss_use_global_batch,
|
153 |
+
log_steps=1000 if not arg.runlocal else 5,
|
154 |
+
)
|
155 |
+
config.evals = {}
|
156 |
+
sub = '[:4]' if arg.runlocal else ''
|
157 |
+
config.evals.val = {
|
158 |
+
**eval_common,
|
159 |
+
'data': val_data,
|
160 |
+
'pp_fn': pp_eval,
|
161 |
+
}
|
162 |
+
config.evals.coco = {
|
163 |
+
**eval_common,
|
164 |
+
'data': dict(name='coco_captions', split=f'val{sub}'),
|
165 |
+
'pp_fn': (
|
166 |
+
f'{pp_image}|flatten|{tokenizer("captions/text")}|'
|
167 |
+
f'keep("image", "labels")'),
|
168 |
+
}
|
169 |
+
|
170 |
+
if arg.i1k_eval:
|
171 |
+
# Requires manual download, see
|
172 |
+
# https://github.com/google-research/big_vision#preparing-tfds-data
|
173 |
+
config.evals.imagenet = {
|
174 |
+
**eval_common,
|
175 |
+
'data': dict(name='imagenet2012', split=f'validation{sub}'),
|
176 |
+
'pp_fn': (
|
177 |
+
f'{pp_image}|clip_i1k_label_names|'
|
178 |
+
f'{tokenizer("labels")}|keep("image", "labels")'),
|
179 |
+
}
|
180 |
+
config.evals.disclf = dict(
|
181 |
+
type='proj.image_text.discriminative_classifier',
|
182 |
+
pp_txt=tokenizer('texts', 'labels'),
|
183 |
+
prefix='z/0shot/',
|
184 |
+
log_steps=5_000 if not arg.runlocal else 5)
|
185 |
+
|
186 |
+
config.evals.retrieval_coco = common.get_coco(
|
187 |
+
pp_img=f'resize({arg.res})|value_range(-1, 1)',
|
188 |
+
pp_txt=tokenizer('texts'),
|
189 |
+
log_steps=5_000 if not arg.runlocal else 5,
|
190 |
+
)
|
191 |
+
|
192 |
+
# Few-shot metrics
|
193 |
+
config.evals.fewshot = get_fewshot_lsr()
|
194 |
+
config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
|
195 |
+
config.evals.fewshot.representation_layer = 'img/pre_logits'
|
196 |
+
|
197 |
+
config.seed = 0
|
198 |
+
|
199 |
+
return config
|
big_vision/configs/proj/distill/README.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Knowledge distillation: A good teacher is patient and consistent
|
2 |
+
*by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov*
|
3 |
+
|
4 |
+
## Introduction
|
5 |
+
We publish all teacher models, and configurations for the main experiments of
|
6 |
+
the paper, as well as training logs and student models.
|
7 |
+
|
8 |
+
Please read the main [big_vision README](/README.md) to learn how to run
|
9 |
+
configs, and remember that each config file contains an example invocation in
|
10 |
+
the top-level comment.
|
11 |
+
|
12 |
+
## Results
|
13 |
+
|
14 |
+
We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing)
|
15 |
+
of a few runs that we reproduced on Cloud.
|
16 |
+
|
17 |
+
### ImageNet-1k
|
18 |
+
|
19 |
+
The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our
|
20 |
+
distillation runs on ImageNet-1k reported in Figures 1 and 5(left) and the first
|
21 |
+
row of Table1.
|
22 |
+
|
23 |
+
We release both student and teacher models:
|
24 |
+
|
25 |
+
| Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
|
26 |
+
| :--- | :---: | :---: | :---: |
|
27 |
+
| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 |
|
28 |
+
| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 |
|
29 |
+
| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 |
|
30 |
+
| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 |
|
31 |
+
|
32 |
+
### Flowers/Pet/Food/Sun
|
33 |
+
|
34 |
+
The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and
|
35 |
+
[bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the
|
36 |
+
distillation runs on these datasets and shown in Figures 3,4,9-12, and Table4.
|
37 |
+
|
38 |
+
While our open-source release does not currently support doing hyper-parameter
|
39 |
+
sweeps, we still provide an example of the sweeps at the end of the configs
|
40 |
+
for reference.
|
41 |
+
|
42 |
+
### Teacher models
|
43 |
+
Links to all teacher models we used can be found in [common.py](common.py).
|
big_vision/configs/proj/distill/bigsweep_flowers_pet.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237
|
17 |
+
|
18 |
+
While many epochs are required, this is a small dataset, and thus overall it
|
19 |
+
is still fast and possible to run on the relatively small v3-8TPUs (or GPUs).
|
20 |
+
|
21 |
+
This configuration contains the recommended settings from Fig3/Tab4 of the
|
22 |
+
paper, which can be selected via the fast/medium/long config argument.
|
23 |
+
(best settings were selected on a 10% minival)
|
24 |
+
|
25 |
+
For Flowers:
|
26 |
+
- The `fast` variant takes ~1h10m on a v2-8 TPU.
|
27 |
+
Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt
|
28 |
+
- The `long` variant takes ~25h on a v3-32 TPU.
|
29 |
+
Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt
|
30 |
+
For Pet:
|
31 |
+
- The `fast` variant takes ~28min on a v2-8 TPU.
|
32 |
+
Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt
|
33 |
+
- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32.
|
34 |
+
Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt
|
35 |
+
|
36 |
+
big_vision.trainers.proj.distill.distill \
|
37 |
+
--config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \
|
38 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
39 |
+
"""
|
40 |
+
|
41 |
+
import big_vision.configs.common as bvcc
|
42 |
+
import big_vision.configs.proj.distill.common as cd
|
43 |
+
import ml_collections as mlc
|
44 |
+
|
45 |
+
NCLS = dict(flowers=102, pet=37)
|
46 |
+
|
47 |
+
|
48 |
+
def get_config(arg=None):
|
49 |
+
"""Config for massive hypothesis-test on pet."""
|
50 |
+
arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
|
51 |
+
config = mlc.ConfigDict()
|
52 |
+
|
53 |
+
config.input = {}
|
54 |
+
config.input.data = dict(
|
55 |
+
name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data],
|
56 |
+
split=dict(flowers='train', pet='train[:90%]')[arg.data],
|
57 |
+
)
|
58 |
+
config.input.batch_size = 512
|
59 |
+
config.input.cache_raw = True
|
60 |
+
config.input.shuffle_buffer_size = 50_000
|
61 |
+
config.prefetch_to_device = 4
|
62 |
+
|
63 |
+
config.num_classes = NCLS[arg.data]
|
64 |
+
config.total_epochs = {
|
65 |
+
'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
|
66 |
+
'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
|
67 |
+
}[arg.data][arg.variant]
|
68 |
+
|
69 |
+
config.log_training_steps = 100
|
70 |
+
config.ckpt_steps = 2500
|
71 |
+
|
72 |
+
# Model section
|
73 |
+
config.student_name = 'bit_paper'
|
74 |
+
config.student = dict(depth=50, width=1)
|
75 |
+
|
76 |
+
config.teachers = ['prof_m']
|
77 |
+
config.prof_m_name = 'bit_paper'
|
78 |
+
config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
|
79 |
+
config.prof_m = dict(depth=152, width=2)
|
80 |
+
|
81 |
+
# Preprocessing pipeline for student & tacher.
|
82 |
+
pp_common = (
|
83 |
+
'|value_range(-1, 1)'
|
84 |
+
f'|onehot({config.num_classes}, key="label", key_result="labels")'
|
85 |
+
'|keep("image", "labels")'
|
86 |
+
)
|
87 |
+
config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common
|
88 |
+
ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common
|
89 |
+
|
90 |
+
config.mixup = dict(p=1.0)
|
91 |
+
|
92 |
+
# Distillation settings
|
93 |
+
config.distance = 'kl'
|
94 |
+
config.distance_kw = dict(t={
|
95 |
+
'flowers': {'fast': 10., 'medium': 1., 'long': 1.},
|
96 |
+
'pet': {'fast': 5., 'medium': 10., 'long': 2.},
|
97 |
+
}[arg.data][arg.variant])
|
98 |
+
|
99 |
+
# Optimizer section
|
100 |
+
config.grad_clip_norm = 1.0
|
101 |
+
config.optax_name = 'scale_by_adam'
|
102 |
+
config.optax = dict(mu_dtype='bfloat16')
|
103 |
+
|
104 |
+
config.lr = {
|
105 |
+
'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003},
|
106 |
+
'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003},
|
107 |
+
}[arg.data][arg.variant]
|
108 |
+
config.wd = {
|
109 |
+
'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5},
|
110 |
+
'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5},
|
111 |
+
}[arg.data][arg.variant]
|
112 |
+
config.schedule = dict(warmup_steps=1500, decay_type='cosine')
|
113 |
+
config.optim_name = 'adam_hp'
|
114 |
+
|
115 |
+
# Eval section
|
116 |
+
minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]'
|
117 |
+
if arg.data == 'flowers':
|
118 |
+
val_split = 'validation' if not arg.runlocal else 'validation[:16]'
|
119 |
+
test_split = 'test' if not arg.runlocal else 'test[:16]'
|
120 |
+
elif arg.data == 'pet':
|
121 |
+
val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
|
122 |
+
test_split = 'test' if not arg.runlocal else 'test[:16]'
|
123 |
+
|
124 |
+
def get_eval(split):
|
125 |
+
return dict(
|
126 |
+
type='classification',
|
127 |
+
pred='student_fwd',
|
128 |
+
data=dict(name=config.input.data.name, split=split),
|
129 |
+
pp_fn=ppv,
|
130 |
+
loss_name='softmax_xent',
|
131 |
+
log_steps=500,
|
132 |
+
)
|
133 |
+
config.evals = {}
|
134 |
+
config.evals.student_train = get_eval(minitrain_split)
|
135 |
+
config.evals.student_val = get_eval(val_split)
|
136 |
+
config.evals.student_test = get_eval(test_split)
|
137 |
+
|
138 |
+
# Teacher is fixed, so rare evals.
|
139 |
+
teacher = dict(log_steps=100_000, pred='prof_m_fwd')
|
140 |
+
config.evals.teacher_train = {**config.evals.student_train, **teacher}
|
141 |
+
config.evals.teacher_val = {**config.evals.student_val, **teacher}
|
142 |
+
config.evals.teacher_test = {**config.evals.student_test, **teacher}
|
143 |
+
|
144 |
+
# Could in principle also look at agreement on other datasets!
|
145 |
+
def get_dist(split):
|
146 |
+
return dict(
|
147 |
+
type='proj.distill.distance',
|
148 |
+
pred='student_prof_m_fwd',
|
149 |
+
data=dict(name=config.input.data.name, split=split),
|
150 |
+
pp_fn=ppv + '|keep("image")',
|
151 |
+
log_steps=1000,
|
152 |
+
distances=({'kind': 'kl'}, {'kind': 'euclidean'},
|
153 |
+
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
|
154 |
+
)
|
155 |
+
config.evals.dist_train = get_dist(minitrain_split)
|
156 |
+
config.evals.dist_val = get_dist(val_split)
|
157 |
+
config.evals.dist_test = get_dist(test_split)
|
158 |
+
|
159 |
+
# Make a few things much smaller for quick local debugging testruns.
|
160 |
+
if arg.runlocal:
|
161 |
+
config.input.shuffle_buffer_size = 10
|
162 |
+
config.input.batch_size = 8
|
163 |
+
|
164 |
+
return config
|
big_vision/configs/proj/distill/bigsweep_food_sun.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Distilling BiT-R152x2 into BiT-R50x1 on Food101/Sun397 as in https://arxiv.org/abs/2106.05237
|
17 |
+
|
18 |
+
While many epochs are required, this is a small dataset, and thus overall it
|
19 |
+
is still fast and possible to run on the relatively small v3-8TPUs (or GPUs).
|
20 |
+
|
21 |
+
This configuration contains the recommended settings from Fig3/Tab4 of the
|
22 |
+
paper, which can be selected via the fast/medium/long config argument.
|
23 |
+
(best settings were selected on a 10% minival)
|
24 |
+
|
25 |
+
For Food101:
|
26 |
+
- The `fast` variant takes ~45min on a v2-8 TPU.
|
27 |
+
Example logs at gs://big_vision/distill/bit_food_fast_06-19_0547/big_vision_metrics.txt
|
28 |
+
Example logs at gs://big_vision/distill/bit_sun_fast_06-20_1839/big_vision_metrics.txt
|
29 |
+
- The `long` variant takes ~14h on a v3-8 TPU.
|
30 |
+
Example logs at gs://big_vision/distill/bit_food_long_06-19_0614/big_vision_metrics.txt
|
31 |
+
Example logs at gs://big_vision/distill/bit_sun_long_06-20_1912/big_vision_metrics.txt
|
32 |
+
|
33 |
+
big_vision.trainers.proj.distill.distill \
|
34 |
+
--config big_vision/configs/proj/distill/bigsweep_food_sun.py:data=food,variant=fast \
|
35 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
36 |
+
"""
|
37 |
+
|
38 |
+
import big_vision.configs.common as bvcc
|
39 |
+
import big_vision.configs.proj.distill.common as cd
|
40 |
+
import ml_collections as mlc
|
41 |
+
|
42 |
+
H, L = 160, 128
|
43 |
+
NCLS = dict(food=101, sun=397)
|
44 |
+
|
45 |
+
|
46 |
+
def get_config(arg=None):
|
47 |
+
"""Config for massive hypothesis-test on pet."""
|
48 |
+
arg = bvcc.parse_arg(arg, runlocal=False, data='food', variant='medium', crop='inception_crop(128)')
|
49 |
+
config = mlc.ConfigDict()
|
50 |
+
|
51 |
+
config.input = {}
|
52 |
+
config.input.data = dict(
|
53 |
+
name=dict(food='food101', sun='sun397')[arg.data],
|
54 |
+
split=dict(food='train[:90%]', sun='train')[arg.data],
|
55 |
+
)
|
56 |
+
config.input.batch_size = 512
|
57 |
+
config.input.cache_raw = True
|
58 |
+
config.input.shuffle_buffer_size = 50_000
|
59 |
+
config.prefetch_to_device = 4
|
60 |
+
|
61 |
+
config.num_classes = NCLS[arg.data]
|
62 |
+
config.total_epochs = {'fast': 100, 'medium': 1000, 'long': 3000}[arg.variant]
|
63 |
+
|
64 |
+
config.log_training_steps = 50
|
65 |
+
config.ckpt_steps = 2500
|
66 |
+
|
67 |
+
# Model section
|
68 |
+
config.student_name = 'bit_paper'
|
69 |
+
config.student = dict(depth=50, width=1)
|
70 |
+
|
71 |
+
config.teachers = ['prof_m']
|
72 |
+
config.prof_m_name = 'bit_paper'
|
73 |
+
config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
|
74 |
+
config.prof_m = dict(depth=152, width=2)
|
75 |
+
|
76 |
+
# Preprocessing pipeline for student & tacher.
|
77 |
+
pp_common = (
|
78 |
+
'|value_range(-1, 1)'
|
79 |
+
f'|onehot({config.num_classes}, key="label", key_result="labels")'
|
80 |
+
'|keep("image", "labels")'
|
81 |
+
)
|
82 |
+
config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common
|
83 |
+
ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common
|
84 |
+
|
85 |
+
config.mixup = dict(p=1.0)
|
86 |
+
|
87 |
+
# Distillation settings
|
88 |
+
config.distance = 'kl'
|
89 |
+
config.distance_kw = dict(t={
|
90 |
+
'food': {'fast': 10., 'medium': 10., 'long': 5.},
|
91 |
+
'sun': {'fast': 10., 'medium': 10., 'long': 10.},
|
92 |
+
}[arg.data][arg.variant])
|
93 |
+
|
94 |
+
# Optimizer section
|
95 |
+
config.grad_clip_norm = 1.0
|
96 |
+
config.optax_name = 'scale_by_adam'
|
97 |
+
config.optax = dict(mu_dtype='bfloat16')
|
98 |
+
|
99 |
+
config.lr = {
|
100 |
+
'food': {'fast': 0.01, 'medium': 0.001, 'long': 0.01},
|
101 |
+
'sun': {'fast': 0.01, 'medium': 0.001, 'long': 0.01},
|
102 |
+
}[arg.data][arg.variant]
|
103 |
+
config.wd = {
|
104 |
+
'food': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-4},
|
105 |
+
'sun': {'fast': 1e-3, 'medium': 1e-4, 'long': 3e-5},
|
106 |
+
}[arg.data][arg.variant]
|
107 |
+
config.schedule = dict(warmup_steps=1500, decay_type='cosine')
|
108 |
+
config.optim_name = 'adam_hp'
|
109 |
+
|
110 |
+
# Eval section
|
111 |
+
minitrain_split = 'train[:1024]' if not arg.runlocal else 'train[:16]'
|
112 |
+
if arg.data == 'food':
|
113 |
+
val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
|
114 |
+
test_split = 'validation' if not arg.runlocal else 'test[:16]'
|
115 |
+
elif arg.data == 'sun':
|
116 |
+
val_split = 'validation' if not arg.runlocal else 'validation[:16]'
|
117 |
+
test_split = 'test' if not arg.runlocal else 'test[:16]'
|
118 |
+
|
119 |
+
def get_eval(split):
|
120 |
+
return dict(
|
121 |
+
type='classification',
|
122 |
+
pred='student_fwd',
|
123 |
+
data=dict(name=config.input.data.name, split=split),
|
124 |
+
pp_fn=ppv,
|
125 |
+
loss_name='softmax_xent',
|
126 |
+
log_steps=500,
|
127 |
+
)
|
128 |
+
config.evals = {}
|
129 |
+
config.evals.student_train = get_eval(minitrain_split)
|
130 |
+
config.evals.student_val = get_eval(val_split)
|
131 |
+
config.evals.student_test = get_eval(test_split)
|
132 |
+
|
133 |
+
# Teacher is fixed, so rare evals.
|
134 |
+
teacher = dict(log_steps=100_000, pred='prof_m_fwd')
|
135 |
+
config.evals.teacher_train = {**config.evals.student_train, **teacher}
|
136 |
+
config.evals.teacher_val = {**config.evals.student_val, **teacher}
|
137 |
+
config.evals.teacher_test = {**config.evals.student_test, **teacher}
|
138 |
+
|
139 |
+
# Could in principle also look at agreement on other datasets!
|
140 |
+
def get_dist(split):
|
141 |
+
return dict(
|
142 |
+
type='proj.distill.distance',
|
143 |
+
pred='student_prof_m_fwd',
|
144 |
+
data=dict(name=config.input.data.name, split=split),
|
145 |
+
pp_fn=ppv + '|keep("image")',
|
146 |
+
log_steps=1000,
|
147 |
+
distances=({'kind': 'kl'}, {'kind': 'euclidean'},
|
148 |
+
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
|
149 |
+
)
|
150 |
+
config.evals.dist_train = get_dist(minitrain_split)
|
151 |
+
config.evals.dist_val = get_dist(val_split)
|
152 |
+
config.evals.dist_test = get_dist(test_split)
|
153 |
+
|
154 |
+
# Make a few things much smaller for quick local debugging testruns.
|
155 |
+
if arg.runlocal:
|
156 |
+
config.input.shuffle_buffer_size = 10
|
157 |
+
config.input.batch_size = 8
|
158 |
+
|
159 |
+
return config
|
160 |
+
|
161 |
+
|
162 |
+
def get_hyper(hyper):
|
163 |
+
"""Hyper sweep."""
|
164 |
+
# TODO: update, similar to flowers_pet sweep.
|
165 |
+
# By default, not running the MASSIVE sweep, just the recommended setting
|
166 |
+
# across durations. However, code for sweep is left for reference/convenience.
|
167 |
+
return hyper.zipit([
|
168 |
+
hyper.sweep('config.total_epochs', [100, 1_000]),
|
169 |
+
hyper.sweep('config.mixup.p', [0.0, 1.0]),
|
170 |
+
hyper.sweep('config.weight_decay', [1e-3, 1e-5]),
|
171 |
+
])
|
172 |
+
|
173 |
+
# pylint: disable=unreachable
|
174 |
+
|
175 |
+
def fix(**kw):
|
176 |
+
return hyper.product([hyper.fixed(f'config.{k}', v, length=1)
|
177 |
+
for k, v in kw.items()])
|
178 |
+
|
179 |
+
def setting(p, l, m, crop, pp_end=None, **extra):
|
180 |
+
pp_end = pp_end or (
|
181 |
+
f'|value_range(-1, 1, key="image")'
|
182 |
+
f'|onehot({NCLS}, key="label", key_result="labels")'
|
183 |
+
f'|keep("image", "labels")'
|
184 |
+
)
|
185 |
+
return hyper.product([
|
186 |
+
fix(**{'mixup.p': p}),
|
187 |
+
fix(l=l, m=m, crop=crop),
|
188 |
+
fix(pp_train=f'decode|{crop}|flip_lr|randaug({l},{m})' + pp_end),
|
189 |
+
fix(**extra)
|
190 |
+
])
|
191 |
+
|
192 |
+
# Mixup, Layers and Mag in randaug.
|
193 |
+
plm = [(0.0, 0, 0), (0.1, 0, 0), (0.5, 0, 0), (1.0, 0, 0)]
|
194 |
+
return hyper.product([
|
195 |
+
hyper.sweep('config.total_epochs', [100, 1000, 3000]),
|
196 |
+
hyper.sweep('config.lr.base', [0.001, 0.003, 0.01]),
|
197 |
+
hyper.sweep('config.distance_kw.t', [1.0, 2.0, 5.0, 10.0]),
|
198 |
+
hyper.sweep('config.weight_decay', [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]),
|
199 |
+
hyper.chainit(
|
200 |
+
[setting(p=p, l=l, m=m,
|
201 |
+
crop=(f'resize({H})'
|
202 |
+
f'|inception_crop({L}, outkey="student")'
|
203 |
+
f'|central_crop({L}, outkey="teacher")'),
|
204 |
+
pp_end=(
|
205 |
+
f'|value_range(-1, 1, key="student")'
|
206 |
+
f'|value_range(-1, 1, key="teacher")'
|
207 |
+
f'|onehot({NCLS}, key="label", key_result="labels")'
|
208 |
+
f'|keep("student", "teacher", "labels")'))
|
209 |
+
for p, l, m in plm] +
|
210 |
+
[setting(p=p, l=l, m=m, crop=f'inception_crop({L})') for
|
211 |
+
p, l, m in plm],
|
212 |
+
)
|
213 |
+
])
|
big_vision/configs/proj/distill/bit_i1k.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Distilling BiT-R152x2 into BiT-R50x1 on ILSVRC-2012 as in https://arxiv.org/abs/2106.05237
|
17 |
+
|
18 |
+
Note that as per paper title, good results require many epochs and thus
|
19 |
+
a lot of _patience_. For experimentation/exploration, consider
|
20 |
+
using the smaller datasets.
|
21 |
+
|
22 |
+
300ep take about 15h on a v3-32 TPU, an example log is available at:
|
23 |
+
Example logs at gs://big_vision/distill/bit_i1k_300ep_06-16/big_vision_metrics.txt
|
24 |
+
|
25 |
+
big_vision.trainers.proj.distill.distill \
|
26 |
+
--config big_vision/configs/proj/distill/bit_i1k.py \
|
27 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
28 |
+
--config.total_epochs 1200
|
29 |
+
"""
|
30 |
+
|
31 |
+
import big_vision.configs.common as bvcc
|
32 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
33 |
+
import big_vision.configs.proj.distill.common as cd
|
34 |
+
import ml_collections as mlc
|
35 |
+
|
36 |
+
|
37 |
+
def get_config(arg=None):
|
38 |
+
"""Config for distilling on ImageNet."""
|
39 |
+
arg = bvcc.parse_arg(arg, runlocal=False)
|
40 |
+
config = mlc.ConfigDict()
|
41 |
+
|
42 |
+
config.input = {}
|
43 |
+
config.input.data = dict(name='imagenet2012', split='train[:98%]')
|
44 |
+
config.input.batch_size = 4096
|
45 |
+
config.input.shuffle_buffer_size = 250_000
|
46 |
+
|
47 |
+
config.num_classes = 1000
|
48 |
+
config.total_epochs = 1200 # A good middle-ground
|
49 |
+
|
50 |
+
config.log_training_steps = 50
|
51 |
+
config.ckpt_steps = 1000
|
52 |
+
config.keep_ckpt_steps = 20000
|
53 |
+
|
54 |
+
# Model section
|
55 |
+
config.student_name = 'bit_paper'
|
56 |
+
config.student = dict(depth=50, width=1)
|
57 |
+
|
58 |
+
config.teachers = ['prof_m'] # You could even add multiple.
|
59 |
+
|
60 |
+
# TODO: use public checkpoint name.
|
61 |
+
config.prof_m_name = 'bit_paper'
|
62 |
+
config.prof_m_init = cd.inits['BiT-M R152x2 imagenet2012 ic224']
|
63 |
+
config.prof_m = dict(depth=152, width=2)
|
64 |
+
|
65 |
+
pp_common = (
|
66 |
+
'|value_range(-1, 1)'
|
67 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
68 |
+
'|keep("image", "labels")'
|
69 |
+
)
|
70 |
+
config.input.pp = (
|
71 |
+
'decode_jpeg_and_inception_crop(224)|flip_lr' +
|
72 |
+
pp_common.format(lbl='label')
|
73 |
+
)
|
74 |
+
ppv = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
75 |
+
|
76 |
+
config.mixup = dict(p=1.0)
|
77 |
+
|
78 |
+
# Distillation settings
|
79 |
+
config.distance = 'kl'
|
80 |
+
config.distance_kw = dict(t=1.0)
|
81 |
+
|
82 |
+
# Optimizer section
|
83 |
+
config.grad_clip_norm = 1.0
|
84 |
+
config.optax_name = 'scale_by_adam'
|
85 |
+
config.optax = dict(mu_dtype='bfloat16')
|
86 |
+
|
87 |
+
config.lr = 0.03
|
88 |
+
config.wd = 0.0003
|
89 |
+
config.schedule = dict(warmup_steps=5000, decay_type='cosine')
|
90 |
+
|
91 |
+
# Eval section
|
92 |
+
minitrain_split = 'train[:2%]' if not arg.runlocal else 'train[:16]'
|
93 |
+
minival_split = 'train[99%:]' if not arg.runlocal else 'train[:16]'
|
94 |
+
val_split = 'validation' if not arg.runlocal else 'validation[:16]'
|
95 |
+
real_split = 'validation' if not arg.runlocal else 'validation[:16]'
|
96 |
+
v2_split = 'test' if not arg.runlocal else 'test[:16]'
|
97 |
+
|
98 |
+
def get_eval(split, dataset='imagenet2012'):
|
99 |
+
return dict(
|
100 |
+
type='classification',
|
101 |
+
pred='student_fwd',
|
102 |
+
data=dict(name=dataset, split=split),
|
103 |
+
pp_fn=ppv.format(lbl='label'),
|
104 |
+
loss_name='softmax_xent',
|
105 |
+
log_steps=1000,
|
106 |
+
)
|
107 |
+
|
108 |
+
config.evals = {}
|
109 |
+
config.evals.student_train = get_eval(minitrain_split)
|
110 |
+
config.evals.student_minival = get_eval(minival_split)
|
111 |
+
config.evals.student_val = get_eval(val_split)
|
112 |
+
config.evals.student_v2 = get_eval(v2_split, dataset='imagenet_v2')
|
113 |
+
config.evals.student_real = get_eval(real_split, dataset='imagenet2012_real')
|
114 |
+
config.evals.student_real.pp_fn = ppv.format(lbl='real_label')
|
115 |
+
|
116 |
+
config.evals.student_fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
|
117 |
+
config.evals.student_fewshot.pred = 'student_fwd'
|
118 |
+
config.evals.student_fewshot.log_steps = 10_000
|
119 |
+
|
120 |
+
teacher_eval = dict(
|
121 |
+
log_steps=100_000, # Teacher is fixed, so rare evals.
|
122 |
+
pred='prof_m_fwd',
|
123 |
+
)
|
124 |
+
config.evals.teacher_train = {**config.evals.student_train, **teacher_eval}
|
125 |
+
config.evals.teacher_minival = {**config.evals.student_minival, **teacher_eval}
|
126 |
+
config.evals.teacher_val = {**config.evals.student_val, **teacher_eval}
|
127 |
+
config.evals.teacher_v2 = {**config.evals.student_v2, **teacher_eval}
|
128 |
+
config.evals.teacher_real = {**config.evals.student_real, **teacher_eval}
|
129 |
+
config.evals.teacher_fewshot = {**config.evals.student_fewshot, **teacher_eval}
|
130 |
+
config.evals.teacher_fewshot.prefix = 'z_teacher/'
|
131 |
+
|
132 |
+
# Could in principle also look at agreement on other datasets!
|
133 |
+
def get_dist(split, dataset='imagenet2012'):
|
134 |
+
return dict(
|
135 |
+
type='proj.distill.distance',
|
136 |
+
pred='student_prof_m_fwd',
|
137 |
+
data=dict(name=dataset, split=split),
|
138 |
+
pp_fn=ppv.format(lbl='label') + '|keep("image")',
|
139 |
+
log_steps=1000,
|
140 |
+
distances=({'kind': 'kl'}, {'kind': 'euclidean'},
|
141 |
+
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
|
142 |
+
)
|
143 |
+
config.evals.dist_train = get_dist(minitrain_split)
|
144 |
+
config.evals.dist_minival = get_dist(minival_split)
|
145 |
+
config.evals.dist_val = get_dist(val_split)
|
146 |
+
config.evals.dist_v2 = get_dist(v2_split, dataset='imagenet_v2')
|
147 |
+
|
148 |
+
# NOTE: CKA evaluator does not work with batch padding, so the size of the
|
149 |
+
# split must be a multiple of the batch size.
|
150 |
+
def get_cka(split):
|
151 |
+
return dict(
|
152 |
+
type='proj.distill.cka',
|
153 |
+
pred='student_prof_m_fwd',
|
154 |
+
data=dict(name='imagenet2012', split=split),
|
155 |
+
pp_fn=ppv.format(lbl='label') + '|keep("image")',
|
156 |
+
log_steps=1000,
|
157 |
+
)
|
158 |
+
config.evals.cka_train = get_cka('train[:24576]' if not arg.runlocal else 'train[:16]')
|
159 |
+
config.evals.cka_minival = get_cka('train[-24576:]' if not arg.runlocal else 'train[:16]')
|
160 |
+
config.evals.cka_val = get_cka('validation[:49152]' if not arg.runlocal else 'validation[:16]')
|
161 |
+
|
162 |
+
# Make a few things much smaller for quick local debugging testruns.
|
163 |
+
if arg.runlocal:
|
164 |
+
config.input.shuffle_buffer_size = 10
|
165 |
+
config.input.batch_size = 8
|
166 |
+
|
167 |
+
return config
|
big_vision/configs/proj/distill/common.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Most common teachers for distillation."""
|
16 |
+
|
17 |
+
# pylint: disable=line-too-long
|
18 |
+
inits = { # pylint: disable=duplicate-key Internally, we override some paths for convenience.
|
19 |
+
'BiT-M R152x2 imagenet2012 ic224': 'gs://bit_models/distill/R152x2_T_224.npz',
|
20 |
+
'BiT-M R152x2 imagenet2012 rc384': 'gs://bit_models/distill/R152x2_T_384.npz',
|
21 |
+
'BiT-M R152x2 flowers rc128': 'gs://bit_models/distill/R152x2_T_flowers128.npz',
|
22 |
+
'BiT-M R152x2 pet rc128': 'gs://bit_models/distill/R152x2_T_pet128.npz',
|
23 |
+
'BiT-M R152x2 food rc128': 'gs://bit_models/distill/R152x2_T_food128.npz',
|
24 |
+
'BiT-M R152x2 sun rc128': 'gs://bit_models/distill/R152x2_T_sun128.npz',
|
25 |
+
|
26 |
+
}
|
27 |
+
# pylint: enable=line-too-long
|
big_vision/configs/proj/flexivit/README.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FlexiViT: One Model for All Patch Sizes
|
2 |
+
*by Lucas Beyer, Pavel Izmailov, Alexander Kolesnikov, Mathilde Caron, Simon Kornblith, Xiaohua Zhai, Matthias Minderer, Michael Tschannen, Ibrahim Alabdulmohsin, Filip Pavetic*
|
3 |
+
|
4 |
+
## Introduction
|
5 |
+
We publish all pre-trained FlexiViT models, and configurations for training
|
6 |
+
those, as well as training logs for one run.
|
7 |
+
|
8 |
+
Please read the main [big_vision README](/README.md) to learn how to run
|
9 |
+
configs, and remember that each config file contains an example invocation in
|
10 |
+
the top-level comment.
|
11 |
+
|
12 |
+
## Pre-trained paper models
|
13 |
+
|
14 |
+
Here are the models that we used as backbones in the paper. See Tables in the
|
15 |
+
appendix of the paper for expected scores at various patch-sizes and on various
|
16 |
+
datasets.
|
17 |
+
|
18 |
+
First, the recommended models we used for all experiments.
|
19 |
+
Remember that the input is 240px, not 224px:
|
20 |
+
|
21 |
+
| Dataset | Model | Download link | Notes |
|
22 |
+
| :--- | :---: | :---: | :---: |
|
23 |
+
| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz) | 1200ep version |
|
24 |
+
| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz) | 1200ep version |
|
25 |
+
| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz) | 1200ep version |
|
26 |
+
| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz) | 300ep version. 1000ep version below is better but was not used in the paper for fair comparison to baselines. |
|
27 |
+
| ImageNet-21k | ViT-B/16 | [link](https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz) | Apples-to-apples non-flexi baseline used throughout the paper. |
|
28 |
+
| ImageNet-21k | ViT-B/30 | [link](https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz) | Apples-to-apples non-flexi baseline used throughout the paper. |
|
29 |
+
|
30 |
+
These models can be used directly in our codebase by specifying
|
31 |
+
`model_name = "proj.flexi.vit"` and `model_init = "FlexiViT-L i1k"` for example.
|
32 |
+
See the file `models/proj/flexi/vit.py` for more names.
|
33 |
+
|
34 |
+
*Important detail:* When further re-using these models with a flexible patch
|
35 |
+
size, it is recommended to keep the patch-embedding parameter buffer at its
|
36 |
+
original size, and change patch-size on the fly using pi-resize, as opposed to
|
37 |
+
changing the parameter buffer's size at load-time.
|
38 |
+
For re-using the models with a fixed patch size, either way is fine.
|
39 |
+
(The reason is that it is impossible to chain multiple resizes without loss,
|
40 |
+
eg doing 32->8->32 does not result in the original weights.)
|
41 |
+
|
42 |
+
Second, the list of all released models for completeness:
|
43 |
+
|
44 |
+
| Dataset | Model | Download link | Notes |
|
45 |
+
| :--- | :---: | :---: | :---: |
|
46 |
+
| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz) | 1000ep version. Should be the best available -B model. |
|
47 |
+
| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_90ep.npz) | 90ep version |
|
48 |
+
| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz) | 600ep version |
|
49 |
+
| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz) | 300ep version |
|
50 |
+
| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_90ep.npz) | 90ep version |
|
51 |
+
| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz) | 600ep version |
|
52 |
+
| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz) | 300ep version |
|
53 |
+
| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_90ep.npz) | 90ep version |
|
54 |
+
| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz) | 600ep version |
|
55 |
+
| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz) | 300ep version |
|
56 |
+
| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_90ep.npz) | 90ep version |
|
57 |
+
|
58 |
+
## Results
|
59 |
+
|
60 |
+
We provide full training logs for a run with this public code on Cloud that
|
61 |
+
reproduces the FlexiViT-S 90ep on i1k results:
|
62 |
+
- [metrics](https://storage.googleapis.com/big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254/big_vision_metrics.txt)
|
63 |
+
- [config](https://storage.googleapis.com/big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254/config.json)
|
64 |
+
- or `gs://big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254`.
|
big_vision/configs/proj/flexivit/i1k_deit3_distill.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Distillation of ViT models into FlexiViT on ImageNet1k.
|
17 |
+
|
18 |
+
Run training of the -S variant for 90ep:
|
19 |
+
|
20 |
+
big_vision.trainers.proj.flexi.distill \
|
21 |
+
--config big_vision/configs/proj/flexivit/i1k_deit3_distill.py \
|
22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
23 |
+
--config.total_epochs 90 --config.variant S
|
24 |
+
|
25 |
+
Logdir for one reproduction run:
|
26 |
+
- gs://big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254
|
27 |
+
|
28 |
+
Timing on Cloud:
|
29 |
+
- S on v3-32: Walltime:10h16m (4h39m eval)
|
30 |
+
|
31 |
+
Note that we did not optimize the input for Cloud,
|
32 |
+
with tuned caching and prefetching, we should be able to get:
|
33 |
+
- S on v3-32: Walltime: ~6h30m (~1h30m eval)
|
34 |
+
- B on v3-32: Walltime: ~16h00m (~2h30m eval)
|
35 |
+
"""
|
36 |
+
|
37 |
+
import big_vision.configs.common as bvcc
|
38 |
+
|
39 |
+
|
40 |
+
def get_config(arg=None):
|
41 |
+
"""Config for distilling ViT on ImageNet1k."""
|
42 |
+
c = bvcc.parse_arg(arg, runlocal=False, res=240)
|
43 |
+
|
44 |
+
c.seed = 0
|
45 |
+
c.total_epochs = 90
|
46 |
+
c.num_classes = 1000
|
47 |
+
c.loss = 'softmax_xent'
|
48 |
+
|
49 |
+
c.input = {}
|
50 |
+
c.input.data = dict(
|
51 |
+
name='imagenet2012',
|
52 |
+
split='train[:99%]',
|
53 |
+
)
|
54 |
+
c.input.batch_size = 1024 if not c.runlocal else 8
|
55 |
+
c.input.cache_raw = False # Needs up to 120GB of RAM!
|
56 |
+
c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 10
|
57 |
+
|
58 |
+
c.log_training_steps = 50
|
59 |
+
c.ckpt_steps = 1000
|
60 |
+
|
61 |
+
# Model section
|
62 |
+
c.variant = 'B'
|
63 |
+
init = bvcc.format_str('deit3_{variant}_384_1k', c)
|
64 |
+
c.student_name = 'proj.flexi.vit'
|
65 |
+
c.student_init = init
|
66 |
+
c.student = dict(variant=c.get_ref('variant'), pool_type='tok', patch_size=(16, 16))
|
67 |
+
|
68 |
+
c.teachers = ['prof'] # You could even add multiple.
|
69 |
+
c.prof_name = 'vit'
|
70 |
+
c.prof_init = init
|
71 |
+
c.prof = dict(variant=c.get_ref('variant'), pool_type='tok', patch_size=(16, 16))
|
72 |
+
|
73 |
+
pp_label = '|onehot(1000, key="{lbl}", key_result="labels")|keep("image", "prof", "labels")'
|
74 |
+
c.input.pp = (
|
75 |
+
f'decode|inception_crop|flip_lr'
|
76 |
+
'|copy("image", "prof")'
|
77 |
+
f'|resize({c.res})|value_range'
|
78 |
+
'|resize(384, key="prof")|value_range(key="prof")'
|
79 |
+
+ pp_label.format(lbl='label')
|
80 |
+
)
|
81 |
+
pp_eval_both = (
|
82 |
+
'decode|copy("image", "prof")|'
|
83 |
+
f'|resize({c.res//7*8})|central_crop({c.res})|value_range'
|
84 |
+
f'|resize({384//7*8}, key="prof")|central_crop(384, key="prof")|value_range(key="prof")|'
|
85 |
+
)
|
86 |
+
pp_eval_student = (
|
87 |
+
f'decode|resize({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)'
|
88 |
+
)
|
89 |
+
pp_eval_prof = (
|
90 |
+
f'decode|resize({384//7*8})|central_crop(384)|value_range(outkey="prof")'
|
91 |
+
)
|
92 |
+
|
93 |
+
c.mixup = dict(p=1.0, n=2)
|
94 |
+
|
95 |
+
# Distillation settings
|
96 |
+
c.distance = 'kl'
|
97 |
+
c.distance_kw = dict(t=1.0)
|
98 |
+
|
99 |
+
# Optimizer section
|
100 |
+
c.grad_clip_norm = 1.0
|
101 |
+
c.optax_name = 'scale_by_adam'
|
102 |
+
c.optax = dict(mu_dtype='bfloat16')
|
103 |
+
|
104 |
+
c.lr = 1e-4
|
105 |
+
c.wd = 1e-5
|
106 |
+
c.schedule = dict(warmup_steps=5000, decay_type='cosine')
|
107 |
+
|
108 |
+
# Define the model parameters which are flexible:
|
109 |
+
c.flexi = dict()
|
110 |
+
c.flexi.seqhw = dict(
|
111 |
+
# The settings to sample from. Corresponding patch-sizes at 240px:
|
112 |
+
# 48, 40, 30, 24, 20, 16, 15, 12, 10, 8
|
113 |
+
v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30),
|
114 |
+
# The probabilities/weights of them. Default uniform.
|
115 |
+
p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
116 |
+
)
|
117 |
+
|
118 |
+
# Eval section
|
119 |
+
def mksplit(split):
|
120 |
+
if c.runlocal:
|
121 |
+
return split.split('[')[0] + '[:16]'
|
122 |
+
return split
|
123 |
+
|
124 |
+
minitrain_split = mksplit('train[:2%]')
|
125 |
+
minival_split = mksplit('train[99%:]')
|
126 |
+
val_split = mksplit('validation')
|
127 |
+
test_split = mksplit('test')
|
128 |
+
c.aggressive_cache = False
|
129 |
+
|
130 |
+
def get_eval(s, split, dataset='imagenet2012'):
|
131 |
+
return dict(
|
132 |
+
type='classification',
|
133 |
+
pred=f'student_seqhw={s}',
|
134 |
+
data=dict(name=dataset, split=split),
|
135 |
+
pp_fn=pp_eval_student + pp_label.format(lbl='label'),
|
136 |
+
loss_name='sigmoid_xent',
|
137 |
+
log_percent=0.05,
|
138 |
+
cache_final=False,
|
139 |
+
)
|
140 |
+
|
141 |
+
c.evals = {}
|
142 |
+
for s in c.flexi.seqhw.v:
|
143 |
+
c.evals[f'student_minitrain_{s:02d}'] = get_eval(s, minitrain_split)
|
144 |
+
c.evals[f'student_minival_{s:02d}'] = get_eval(s, minival_split)
|
145 |
+
c.evals[f'student_val_{s:02d}'] = get_eval(s, val_split)
|
146 |
+
c.evals[f'student_v2_{s:02d}'] = get_eval(s, test_split, 'imagenet_v2')
|
147 |
+
c.evals[f'student_a_{s:02d}'] = get_eval(s, test_split, 'imagenet_a')
|
148 |
+
c.evals[f'student_r_{s:02d}'] = get_eval(s, test_split, 'imagenet_r')
|
149 |
+
c.evals[f'student_real_{s:02d}'] = get_eval(s, val_split, 'imagenet2012_real')
|
150 |
+
c.evals[f'student_real_{s:02d}'].pp_fn = pp_eval_student + pp_label.format(lbl='real_label')
|
151 |
+
|
152 |
+
def get_eval_t(split, dataset='imagenet2012'):
|
153 |
+
return dict(
|
154 |
+
type='classification',
|
155 |
+
pred='prof',
|
156 |
+
data=dict(name=dataset, split=split),
|
157 |
+
pp_fn=pp_eval_prof + pp_label.format(lbl='label'),
|
158 |
+
loss_name='sigmoid_xent',
|
159 |
+
log_percent=0.5, # Teacher is fixed, so eval just for plots.
|
160 |
+
cache_final=False,
|
161 |
+
)
|
162 |
+
c.evals.teacher_minitrain = get_eval_t(minitrain_split)
|
163 |
+
c.evals.teacher_minival = get_eval_t(minival_split)
|
164 |
+
c.evals.teacher_val = get_eval_t(val_split)
|
165 |
+
c.evals.teacher_v2 = get_eval_t(test_split, 'imagenet_v2')
|
166 |
+
c.evals.teacher_a = get_eval_t(test_split, 'imagenet_a')
|
167 |
+
c.evals.teacher_r = get_eval_t(test_split, 'imagenet_r')
|
168 |
+
c.evals.teacher_real = get_eval_t(val_split, 'imagenet2012_real')
|
169 |
+
c.evals.teacher_real.pp_fn = pp_eval_prof + pp_label.format(lbl='real_label')
|
170 |
+
|
171 |
+
# Distance evaluators
|
172 |
+
def get_dist(split, s):
|
173 |
+
return dict(
|
174 |
+
type='proj.distill.distance',
|
175 |
+
pred=f'student_seqhw={s}_prof',
|
176 |
+
data=dict(name='imagenet2012', split=split),
|
177 |
+
pp_fn=pp_eval_both + '|keep("image", "prof")',
|
178 |
+
log_percent=0.05,
|
179 |
+
distances=({'kind': 'kl'}, {'kind': 'logsoftmax_euclidean'},
|
180 |
+
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
|
181 |
+
cache_final=False,
|
182 |
+
)
|
183 |
+
for s in c.flexi.seqhw.v:
|
184 |
+
c.evals[f'dist_minitrain_{s:02d}'] = get_dist(minitrain_split, s)
|
185 |
+
c.evals[f'dist_val_{s:02d}'] = get_dist(val_split, s)
|
186 |
+
|
187 |
+
return c
|
big_vision/configs/proj/flexivit/i21k_distill.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Distill flexible-seqlen ViT on ImageNet-21k from (internal link) B/8.
|
17 |
+
|
18 |
+
This config is for reference, we never ran it on public infrastructure.
|
19 |
+
|
20 |
+
big_vision.trainers.proj.flexi.distill \
|
21 |
+
--config big_vision/configs/proj/flexivit/i21k_distill.py \
|
22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
23 |
+
--config.total_epochs 90
|
24 |
+
"""
|
25 |
+
|
26 |
+
import big_vision.configs.common as bvcc
|
27 |
+
|
28 |
+
|
29 |
+
def get_config(arg=None):
|
30 |
+
"""Config for training."""
|
31 |
+
# 240px is nice because it's divisible by
|
32 |
+
# [240, 120, 80, 60, 48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1]
|
33 |
+
c = bvcc.parse_arg(arg, runlocal=False, res=240)
|
34 |
+
|
35 |
+
c.seed = 0
|
36 |
+
c.total_epochs = 90
|
37 |
+
c.num_classes = 21843
|
38 |
+
c.init_head_bias = -10.0
|
39 |
+
c.loss = 'sigmoid_xent'
|
40 |
+
|
41 |
+
c.input = dict()
|
42 |
+
c.input.data = dict(
|
43 |
+
name='imagenet21k',
|
44 |
+
split='full[51200:]',
|
45 |
+
)
|
46 |
+
c.input.batch_size = 4096 if not c.runlocal else 8
|
47 |
+
c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 25
|
48 |
+
|
49 |
+
pp_label_i21k = f'|onehot({c.num_classes})|keep("image", "prof", "labels")'
|
50 |
+
pp_label_i1k = '|onehot(1000, key="{lbl}", key_result="labels")|keep("image", "prof", "labels")'
|
51 |
+
c.input.pp = (
|
52 |
+
f'decode|inception_crop|flip_lr|copy("image", "prof")'
|
53 |
+
f'|resize({c.res})|value_range(-1, 1)'
|
54 |
+
f'|resize(224, outkey="prof")|value_range(-1, 1, key="prof")'
|
55 |
+
+ pp_label_i21k
|
56 |
+
)
|
57 |
+
pp_eval_both = (
|
58 |
+
'decode|copy("image", "prof")|'
|
59 |
+
f'|resize_small({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)'
|
60 |
+
f'|resize_small(256, key="prof")|central_crop(224, key="prof")|value_range(-1, 1, key="prof")|'
|
61 |
+
)
|
62 |
+
pp_eval_student = (
|
63 |
+
f'decode|resize({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)'
|
64 |
+
)
|
65 |
+
pp_eval_prof = (
|
66 |
+
'decode|resize(256)|central_crop(224)|value_range(-1, 1, outkey="prof")'
|
67 |
+
)
|
68 |
+
|
69 |
+
# Aggressive pre-fetching because our models here are small, so we not only
|
70 |
+
# can afford it, but we also need it for the smallest models to not be
|
71 |
+
# bottle-necked by the input pipeline. Play around with it for -L models tho.
|
72 |
+
c.input.prefetch = 8
|
73 |
+
c.prefetch_to_device = 4
|
74 |
+
|
75 |
+
c.log_training_steps = 50
|
76 |
+
c.ckpt_steps = 1000
|
77 |
+
|
78 |
+
# Model section
|
79 |
+
init = 'howto-i21k-B/8'
|
80 |
+
c.student_name = 'proj.flexi.vit'
|
81 |
+
c.student_init = init
|
82 |
+
c.student = dict(variant='B', pool_type='tok', patch_size=(8, 8))
|
83 |
+
|
84 |
+
c.teachers = ['prof'] # You could even add multiple.
|
85 |
+
c.prof_name = 'vit'
|
86 |
+
c.prof_init = init
|
87 |
+
c.prof = dict(variant='B/8', pool_type='tok')
|
88 |
+
|
89 |
+
# Define the model parameters which are flexible:
|
90 |
+
c.flexi = dict()
|
91 |
+
c.flexi.seqhw = dict(
|
92 |
+
# The settings to sample from. Corresponding patch-sizes at 240px:
|
93 |
+
# 48, 40, 30, 24, 20, 16, 15, 12, 10, 8
|
94 |
+
v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30),
|
95 |
+
# The probabilities/weights of them. Default uniform.
|
96 |
+
p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
97 |
+
)
|
98 |
+
|
99 |
+
# Distillation settings
|
100 |
+
c.distance = 'kl'
|
101 |
+
c.distance_kw = dict(t=1.0)
|
102 |
+
|
103 |
+
# Optimizer section
|
104 |
+
c.optax_name = 'scale_by_adam'
|
105 |
+
c.optax = dict(mu_dtype='bfloat16')
|
106 |
+
c.grad_clip_norm = 1.0
|
107 |
+
|
108 |
+
c.lr = 1e-4
|
109 |
+
c.wd = 1e-5
|
110 |
+
c.schedule = dict(warmup_steps=5000, decay_type='cosine')
|
111 |
+
|
112 |
+
c.mixup = dict(p=1.0)
|
113 |
+
|
114 |
+
####
|
115 |
+
# Preparing for evals
|
116 |
+
c.evals = {}
|
117 |
+
def mksplit(split):
|
118 |
+
if c.runlocal:
|
119 |
+
return split.split('[')[0] + '[:16]'
|
120 |
+
return split
|
121 |
+
|
122 |
+
####
|
123 |
+
# Student evals
|
124 |
+
|
125 |
+
# Evaluations on i21k itself.
|
126 |
+
def eval_i21k(s, split):
|
127 |
+
return dict(
|
128 |
+
type='classification',
|
129 |
+
pred=f'student_seqhw={s}',
|
130 |
+
data={**c.input.data, 'split': mksplit(split)},
|
131 |
+
pp_fn=pp_eval_student + pp_label_i21k,
|
132 |
+
loss_name=c.loss,
|
133 |
+
log_steps=5000, # Very fast O(seconds) so it's fine to run it often.
|
134 |
+
)
|
135 |
+
|
136 |
+
for s in c.flexi.seqhw.v:
|
137 |
+
c.evals[f'student_test{s:02d}'] = eval_i21k(s, 'full[:25_600]')
|
138 |
+
c.evals[f'student_val{s:02d}'] = eval_i21k(s, 'full[25_600:51_200]')
|
139 |
+
c.evals[f'student_minitrain{s:02d}'] = eval_i21k(s, 'full[51_200:76_800]')
|
140 |
+
|
141 |
+
# Evaluations on ImageNet1k variants by label-mapping.
|
142 |
+
def eval_i1k(s, dataset, split, lblmap):
|
143 |
+
return dict(
|
144 |
+
type='classification_with_labelmap',
|
145 |
+
pred=f'student_seqhw={s}',
|
146 |
+
data=dict(name=dataset, split=mksplit(split)),
|
147 |
+
pp_fn=pp_eval_student + pp_label_i1k.format(lbl='label'),
|
148 |
+
loss_name=c.loss,
|
149 |
+
log_steps=5000, # Very fast O(seconds) so it's fine to run it often.
|
150 |
+
label_mapping=lblmap,
|
151 |
+
)
|
152 |
+
for s in c.flexi.seqhw.v:
|
153 |
+
c.evals[f'student_i1k_val{s:02d}'] = eval_i1k(s, 'imagenet2012', 'validation', 'i1k_i21k')
|
154 |
+
c.evals[f'student_i1k_v2{s:02d}'] = eval_i1k(s, 'imagenet_v2', 'test', 'i1k_i21k')
|
155 |
+
c.evals[f'student_i1k_a{s:02d}'] = eval_i1k(s, 'imagenet_a', 'test', 'i1ka_i21k')
|
156 |
+
c.evals[f'student_i1k_r{s:02d}'] = eval_i1k(s, 'imagenet_r', 'test', 'i1kr_i21k')
|
157 |
+
c.evals[f'student_i1k_real{s:02d}'] = eval_i1k(s, 'imagenet2012_real', 'validation', 'i1k_i21k')
|
158 |
+
c.evals[f'student_i1k_real{s:02d}'].pp_fn = pp_eval_student + pp_label_i1k.format(lbl='real_label')
|
159 |
+
# TODO: add objectnet.
|
160 |
+
|
161 |
+
####
|
162 |
+
# Teacher evals
|
163 |
+
|
164 |
+
# Evaluations on i21k itself.
|
165 |
+
def eval_i21k_t(split):
|
166 |
+
return dict(
|
167 |
+
type='classification',
|
168 |
+
pred='prof',
|
169 |
+
data={**c.input.data, 'split': mksplit(split)},
|
170 |
+
pp_fn=pp_eval_prof + pp_label_i21k,
|
171 |
+
loss_name=c.loss,
|
172 |
+
log_steps=5000, # Very fast O(seconds) so it's fine to run it often.
|
173 |
+
)
|
174 |
+
|
175 |
+
c.evals.teacher_test = eval_i21k_t('full[:25_600]')
|
176 |
+
c.evals.teacher_val = eval_i21k_t('full[25_600:51_200]')
|
177 |
+
c.evals.teacher_minitrain = eval_i21k_t('full[51_200:76_800]')
|
178 |
+
|
179 |
+
# Evaluations on ImageNet1k variants by label-mapping.
|
180 |
+
def eval_i1k_t(dataset, split, lblmap):
|
181 |
+
return dict(
|
182 |
+
type='classification_with_labelmap',
|
183 |
+
pred='prof',
|
184 |
+
data=dict(name=dataset, split=mksplit(split)),
|
185 |
+
pp_fn=pp_eval_prof + pp_label_i1k.format(lbl='label'),
|
186 |
+
loss_name=c.loss,
|
187 |
+
log_percent=0.5, # Teacher is fixed, so eval just for plots.
|
188 |
+
label_mapping=lblmap,
|
189 |
+
)
|
190 |
+
c.evals.teacher_i1k_val = eval_i1k_t('imagenet2012', 'validation', 'i1k_i21k')
|
191 |
+
c.evals.teacher_i1k_v2 = eval_i1k_t('imagenet_v2', 'test', 'i1k_i21k')
|
192 |
+
c.evals.teacher_i1k_a = eval_i1k_t('imagenet_a', 'test', 'i1ka_i21k')
|
193 |
+
c.evals.teacher_i1k_r = eval_i1k_t('imagenet_r', 'test', 'i1kr_i21k')
|
194 |
+
c.evals.teacher_i1k_real = eval_i1k_t('imagenet2012_real', 'validation', 'i1k_i21k')
|
195 |
+
c.evals.teacher_i1k_real.pp_fn = pp_eval_prof + pp_label_i1k.format(lbl='real_label')
|
196 |
+
# TODO: add objectnet.
|
197 |
+
|
198 |
+
####
|
199 |
+
# Combined evals
|
200 |
+
|
201 |
+
def get_dist(split, s):
|
202 |
+
return dict(
|
203 |
+
type='proj.distill.distance',
|
204 |
+
pred=f'student_seqhw={s}_prof',
|
205 |
+
data=dict(name='imagenet2012', split=mksplit(split)),
|
206 |
+
pp_fn=pp_eval_both + '|keep("image", "prof")',
|
207 |
+
log_percent=0.05,
|
208 |
+
distances=({'kind': 'kl'}, {'kind': 'logsoftmax_euclidean'},
|
209 |
+
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
|
210 |
+
)
|
211 |
+
for s in c.flexi.seqhw.v:
|
212 |
+
c.evals[f'dist_minitrain_{s:02d}'] = get_dist('full[51_200:76_800]', s)
|
213 |
+
c.evals[f'dist_val_{s:02d}'] = get_dist('full[25_600:51_200]', s)
|
214 |
+
|
215 |
+
# Few-shot evaluators not added for overkill reasons for now.
|
216 |
+
return c
|
big_vision/configs/proj/flexivit/i21k_sup.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Pre-training flexible-seqlen ViT on ImageNet-21k following (internal link).
|
17 |
+
|
18 |
+
This config is for reference, we never ran it on public infrastructure.
|
19 |
+
|
20 |
+
big_vision.trainers.proj.flexi.train \
|
21 |
+
--config big_vision/configs/proj/flexivit/i21k_sup.py \
|
22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
23 |
+
--config.total_epochs 90
|
24 |
+
"""
|
25 |
+
|
26 |
+
import big_vision.configs.common as bvcc
|
27 |
+
|
28 |
+
|
29 |
+
def get_config(arg=None):
|
30 |
+
"""Config for training."""
|
31 |
+
# 240px is nice because it's divisible by
|
32 |
+
# [240, 120, 80, 60, 48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1]
|
33 |
+
c = bvcc.parse_arg(arg, runlocal=False, res=240)
|
34 |
+
|
35 |
+
c.seed = 0
|
36 |
+
c.total_epochs = 90
|
37 |
+
c.num_classes = 21843
|
38 |
+
c.init_head_bias = -10.0
|
39 |
+
c.loss = 'sigmoid_xent'
|
40 |
+
|
41 |
+
c.input = dict()
|
42 |
+
c.input.data = dict(
|
43 |
+
name='imagenet21k',
|
44 |
+
split='full[51200:]',
|
45 |
+
)
|
46 |
+
c.input.batch_size = 4096 if not c.runlocal else 8
|
47 |
+
c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 25
|
48 |
+
|
49 |
+
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
|
50 |
+
pp_common_i21k = pp_common.format(onehot_args=f'{c.num_classes}')
|
51 |
+
pp_common_i1k = pp_common.format(onehot_args='1000, key="{lbl}", key_result="labels"')
|
52 |
+
c.input.pp = f'decode_jpeg_and_inception_crop({c.res})|flip_lr|randaug(2,10)' + pp_common_i21k
|
53 |
+
def pp_eval(res=c.res):
|
54 |
+
return f'decode|resize_small({res//7*8})|central_crop({res})'
|
55 |
+
|
56 |
+
# To continue using the near-defunct randaug op.
|
57 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
58 |
+
|
59 |
+
# Aggressive pre-fetching because our models here are small, so we not only
|
60 |
+
# can afford it, but we also need it for the smallest models to not be
|
61 |
+
# bottle-necked by the input pipeline. Play around with it for -L models tho.
|
62 |
+
c.input.prefetch = 8
|
63 |
+
c.prefetch_to_device = 4
|
64 |
+
|
65 |
+
c.log_training_steps = 50
|
66 |
+
c.ckpt_steps = 1000
|
67 |
+
|
68 |
+
# Model section
|
69 |
+
c.model_name = 'proj.flexi.vit'
|
70 |
+
c.model = dict(
|
71 |
+
variant='B',
|
72 |
+
pool_type='tok',
|
73 |
+
posemb='learn',
|
74 |
+
# patch_size=(32, 32),
|
75 |
+
patch_size=(8, 8),
|
76 |
+
posemb_size=(7, 7),
|
77 |
+
seqhw=None, # Dynamic!
|
78 |
+
)
|
79 |
+
|
80 |
+
# Define the model parameters which are flexible:
|
81 |
+
c.flexi = dict()
|
82 |
+
c.flexi.seqhw = dict(
|
83 |
+
# The settings to sample from. Corresponding patch-sizes at 240px:
|
84 |
+
# 48, 40, 30, 24, 20, 16, 15, 12, 10, 8
|
85 |
+
v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30),
|
86 |
+
# The probabilities/weights of them. Default uniform.
|
87 |
+
p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
88 |
+
)
|
89 |
+
|
90 |
+
# Optimizer section
|
91 |
+
c.optax_name = 'scale_by_adam'
|
92 |
+
c.optax = dict(mu_dtype='bfloat16')
|
93 |
+
c.grad_clip_norm = 1.0
|
94 |
+
|
95 |
+
c.lr = 0.001
|
96 |
+
c.wd = 0.0001
|
97 |
+
c.schedule = dict(warmup_steps=10_000, decay_type='cosine')
|
98 |
+
|
99 |
+
c.mixup = dict(p=0.2, fold_in=None)
|
100 |
+
|
101 |
+
def mksplit(split):
|
102 |
+
if c.runlocal:
|
103 |
+
return split.split('[')[0] + '[:16]'
|
104 |
+
return split
|
105 |
+
|
106 |
+
# Evaluations on i21k itself.
|
107 |
+
def eval_i21k(s, split):
|
108 |
+
return dict(
|
109 |
+
type='classification',
|
110 |
+
pred=f'predict_seqhw={s}',
|
111 |
+
data={**c.input.data, 'split': mksplit(split)},
|
112 |
+
pp_fn=pp_eval() + pp_common_i21k,
|
113 |
+
loss_name=c.loss,
|
114 |
+
log_steps=5000, # Very fast O(seconds) so it's fine to run it often.
|
115 |
+
)
|
116 |
+
|
117 |
+
c.evals = {}
|
118 |
+
for s in c.flexi.seqhw.v:
|
119 |
+
c.evals[f'test{s:02d}'] = eval_i21k(s, 'full[:25_600]')
|
120 |
+
c.evals[f'val{s:02d}'] = eval_i21k(s, 'full[25_600:51_200]')
|
121 |
+
c.evals[f'train{s:02d}'] = eval_i21k(s, 'full[51_200:76_800]')
|
122 |
+
|
123 |
+
# Evaluations on ImageNet1k variants by label-mapping.
|
124 |
+
def eval_i1k(s, dataset, split, lblmap):
|
125 |
+
return dict(
|
126 |
+
type='classification_with_labelmap',
|
127 |
+
pred=f'predict_seqhw={s}',
|
128 |
+
data=dict(name=dataset, split=mksplit(split)),
|
129 |
+
pp_fn=pp_eval() + pp_common_i1k.format(lbl='label'),
|
130 |
+
loss_name=c.loss,
|
131 |
+
log_steps=5000, # Very fast O(seconds) so it's fine to run it often.
|
132 |
+
label_mapping=lblmap,
|
133 |
+
)
|
134 |
+
for s in c.flexi.seqhw.v:
|
135 |
+
c.evals[f'i1k_val{s:02d}'] = eval_i1k(s, 'imagenet2012', 'validation', 'i1k_i21k')
|
136 |
+
c.evals[f'i1k_v2{s:02d}'] = eval_i1k(s, 'imagenet_v2', 'test', 'i1k_i21k')
|
137 |
+
c.evals[f'i1k_a{s:02d}'] = eval_i1k(s, 'imagenet_a', 'test', 'i1ka_i21k')
|
138 |
+
c.evals[f'i1k_r{s:02d}'] = eval_i1k(s, 'imagenet_r', 'test', 'i1kr_i21k')
|
139 |
+
c.evals[f'i1k_real{s:02d}'] = eval_i1k(s, 'imagenet2012_real', 'validation', 'i1k_i21k')
|
140 |
+
c.evals[f'i1k_real{s:02d}'].pp_fn = pp_eval() + pp_common_i1k.format(lbl='real_label')
|
141 |
+
# TODO: add objectnet.
|
142 |
+
|
143 |
+
# Few-shot evaluators not added for overkill reasons for now.
|
144 |
+
return c
|
big_vision/configs/proj/flexivit/timing.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long,missing-function-docstring
|
16 |
+
r"""A config to run timing for FlexiViT (only inference, no I/O etc.).
|
17 |
+
|
18 |
+
big_vision.tools.eval_only \
|
19 |
+
--config big_vision/configs/proj/flexivit/timing.py \
|
20 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
21 |
+
--config.total_epochs 90
|
22 |
+
"""
|
23 |
+
|
24 |
+
from ml_collections import ConfigDict
|
25 |
+
|
26 |
+
|
27 |
+
def get_config():
|
28 |
+
c = ConfigDict()
|
29 |
+
|
30 |
+
shape = (240, 240, 3)
|
31 |
+
c.batch_size = 8 # swept
|
32 |
+
c.init_shapes = [(1, *shape)]
|
33 |
+
c.representation_layer = 'pre_logits'
|
34 |
+
|
35 |
+
# Creating complete model using all params, the sweep will go over variants.
|
36 |
+
c.model_name = 'xp.flexivit.vit'
|
37 |
+
c.model = dict(
|
38 |
+
variant='B',
|
39 |
+
pool_type='tok',
|
40 |
+
patch_size=(10, 10), # Like deit@384
|
41 |
+
seqhw=(24, 24),
|
42 |
+
)
|
43 |
+
c.num_classes = 0
|
44 |
+
|
45 |
+
c.evals = {}
|
46 |
+
c.evals.timing = dict(
|
47 |
+
type='timing',
|
48 |
+
input_shapes=[shape],
|
49 |
+
timing=True,
|
50 |
+
pred_kw=dict(outputs=('pre_logits',)),
|
51 |
+
)
|
52 |
+
|
53 |
+
return c
|
big_vision/configs/proj/givt/README.md
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GIVT: Generative Infinite-Vocabulary Transformers
|
2 |
+
|
3 |
+
*by Michael Tschannen, Cian Eastwood, Fabian Mentzer* [[arxiv]](https://arxiv.org/abs/2312.02116) [[colab]](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/givt/givt_demo_colab.ipynb)
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
|
8 |
+
### Summary
|
9 |
+
|
10 |
+
We introduce generative infinite-vocabulary transformers (GIVT) which generate vector sequences with real-valued entries, instead of discrete tokens from a finite vocabulary.
|
11 |
+
To this end, we propose two surprisingly simple modifications to decoder-only transformers: 1) at the input, we replace the finite-vocabulary lookup table with a linear projection of the input vectors; and 2) at the output, we replace the logits prediction (usually mapped to a categorical distribution) with the parameters of a multivariate Gaussian mixture model.
|
12 |
+
Inspired by the image-generation paradigm of VQ-GAN and MaskGIT, where transformers are used to model the discrete latent sequences of a VQ-VAE, we use GIVT to model the unquantized real-valued latent sequences of a β-VAE.
|
13 |
+
In class-conditional image generation GIVT outperforms VQ-GAN (and improved variants thereof) as well as MaskGIT, and achieves performance competitive with recent latent diffusion models.
|
14 |
+
Finally, we obtain strong results outside of image generation when applying GIVT to panoptic segmentation and depth estimation with a VAE variant of the UViM framework.
|
15 |
+
|
16 |
+
### Checkpoints
|
17 |
+
|
18 |
+
We provide model checkpoints for a subset of the models from the paper.
|
19 |
+
These are meant as small-scale baselines for researchers interested in exploring GIVT, and are not optimized to provide the best possible visual quality (e.g. scaling the model size can substantially improve visual quality as shown in the paper).
|
20 |
+
See below for instructions to train your own models.
|
21 |
+
|
22 |
+
**ImageNet 2012 VAEs**
|
23 |
+
|
24 |
+
| β | 1e-5 | 2.5e-5 | 5e-5 | 1e-4 | 2e-4 |
|
25 |
+
|:-----------|:------:|:----:|:----:|:----:|:----:|
|
26 |
+
| checkpoint | [link][vae_i1k_0] | [link][vae_i1k_1] | [link][vae_i1k_2] | [link][vae_i1k_3] | [link][vae_i1k_4] |
|
27 |
+
|
28 |
+
[vae_i1k_0]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_1e-5_params
|
29 |
+
[vae_i1k_1]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_2p5e-5_params
|
30 |
+
[vae_i1k_2]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_5e-5_params
|
31 |
+
[vae_i1k_3]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_1e-4_params
|
32 |
+
[vae_i1k_4]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_2e-4_params
|
33 |
+
|
34 |
+
**Class-conditional ImageNet 2012 generative models**
|
35 |
+
|
36 |
+
| model | resolution | β | inference | FID | checkpoint |
|
37 |
+
|:------|:----------:|:------:|:-------------|:---:|:-----------|
|
38 |
+
| GIVT-Causal | 256 x 256 | 5e-5 | t=0.95, DB-CFG=0.4 | 3.35 | [link][givt_i1k_1] |
|
39 |
+
| GIVT-MaskGIT | 256 x 256 | 5e-5 | t_C=35, DB-CFG=0.1 | 4.53 | [link][givt_i1k_2] |
|
40 |
+
| GIVT-MaskGIT | 512 x 512 | 5e-5 | t_C=140 | 4.86 | [link][givt_i1k_3] |
|
41 |
+
|
42 |
+
[givt_i1k_1]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_causal_params.npz
|
43 |
+
[givt_i1k_2]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_maskgit_params.npz
|
44 |
+
[givt_i1k_3]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_maskgit_512_params.npz
|
45 |
+
|
46 |
+
|
47 |
+
**UViM**
|
48 |
+
|
49 |
+
| task | model | dataset | accuracy | checkpoint |
|
50 |
+
|:-----|:------|:--------|---------:|:-----------|
|
51 |
+
| Panoptic segmentation | VAE (stage 1) | [COCO (2017)] | 71.0 (PQ) | [link][vae_coco_panoptic] |
|
52 |
+
| Panoptic segmentation | GIVT (stage 2) | [COCO (2017)] | 40.2 (PQ) | [link][givt_coco_panoptic] |
|
53 |
+
| Depth estimation | VAE (stage 1) | [NYU Depth v2] | 0.195 (RMSE) | [link][vae_nyu_depth] |
|
54 |
+
| Depth estimation | GIVT (stage 2) | [NYU Depth v2] | 0.474 (RMSE) | [link][givt_nyu_depth] |
|
55 |
+
|
56 |
+
[NYU Depth v2]: https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html
|
57 |
+
[COCO (2017)]: https://cocodataset.org/#home
|
58 |
+
[vae_coco_panoptic]: https://storage.googleapis.com/big_vision/givt/vae_coco_panoptic_params.npz
|
59 |
+
[givt_coco_panoptic]: https://storage.googleapis.com/big_vision/givt/givt_coco_panoptic_params.npz
|
60 |
+
[vae_nyu_depth]: https://storage.googleapis.com/big_vision/givt/vae_nyu_depth_params.npz
|
61 |
+
[givt_nyu_depth]: https://storage.googleapis.com/big_vision/givt/givt_nyu_depth_params.npz
|
62 |
+
|
63 |
+
### Training models
|
64 |
+
|
65 |
+
This directory contains configs to train GIVT models as well as VAEs (for the UViM variants).
|
66 |
+
For training the ImageNet 2012 VAE models we used a modified version of the [MaskGIT code](https://github.com/google-research/maskgit).
|
67 |
+
|
68 |
+
The `big_vision` input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets)
|
69 |
+
which supports some data sets out-of-the-box, whereas others require manual download of the data
|
70 |
+
(for example ImageNet and COCO (2017), see the `big_vision` [main README](../../../../#cloud-tpu-vm-setup) and the [UViM README](../uvim), respectively, for details).
|
71 |
+
|
72 |
+
After setting up `big_vision` as described in the [main README](../../../../#cloud-tpu-vm-setup), training can be launched locally as follows
|
73 |
+
|
74 |
+
```
|
75 |
+
python -m big_vision.trainers.proj.givt.generative \
|
76 |
+
--config big_vision/configs/proj/givt/givt_imagenet2012.py \
|
77 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
|
78 |
+
```
|
79 |
+
|
80 |
+
Add the suffix `:key1=value1,key2=value2,...` to the config path in the launch
|
81 |
+
command to modify the config with predefined arguments (see config for details). For example:
|
82 |
+
`--config big_vision/configs/proj/givt/givt_imagenet_2012.py:model_size=large`.
|
83 |
+
Note that `givt_imagenet2012.py` uses [Imagenette](https://github.com/fastai/imagenette) to ensure that the config is runnable without manual ImageNet download.
|
84 |
+
This is only meant for testing and will overfit immediately. Please download ImageNet to reproduce the paper results.
|
85 |
+
|
86 |
+
VAE trainings for the GIVT variant of UViM can be launched as
|
87 |
+
|
88 |
+
```
|
89 |
+
python -m big_vision.trainers.proj.givt.vae \
|
90 |
+
--config big_vision/configs/proj/givt/vae_nyu_depth.py \
|
91 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
|
92 |
+
```
|
93 |
+
|
94 |
+
Please refer to the [main README](../../../../#cloud-tpu-vm-setup)
|
95 |
+
for details on how to launch training on a (multi-host) TPU setup.
|
96 |
+
|
97 |
+
|
98 |
+
### Disclaimer
|
99 |
+
|
100 |
+
This is not an official Google Product.
|
101 |
+
|
102 |
+
|
103 |
+
### Citation
|
104 |
+
```
|
105 |
+
@article{tschannen2023givt,
|
106 |
+
title={GIVT: Generative Infinite-Vocabulary Transformers},
|
107 |
+
author={Tschannen, Michael and Eastwood, Cian and Mentzer, Fabian},
|
108 |
+
journal={arXiv:2312.02116},
|
109 |
+
year={2023}
|
110 |
+
}
|
111 |
+
```
|
big_vision/configs/proj/givt/givt_coco_panoptic.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Train a GIVT encoder-decoder model on COCO panoptic."""
|
17 |
+
|
18 |
+
import itertools
|
19 |
+
import ml_collections
|
20 |
+
|
21 |
+
ConfigDict = ml_collections.ConfigDict
|
22 |
+
|
23 |
+
VTT_MODELS = {
|
24 |
+
'base': dict(num_layers=12, num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
|
25 |
+
'large': dict(num_layers=24, num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
|
26 |
+
}
|
27 |
+
|
28 |
+
RES = 512
|
29 |
+
PATCH_SIZE = 16
|
30 |
+
LABEL_RES = 512
|
31 |
+
LABEL_PATCH_SIZE = 16
|
32 |
+
|
33 |
+
|
34 |
+
def get_config(runlocal=False):
|
35 |
+
"""Config for training."""
|
36 |
+
config = ConfigDict()
|
37 |
+
|
38 |
+
config.input = {}
|
39 |
+
config.input.pp = (
|
40 |
+
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
|
41 |
+
f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
|
42 |
+
f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
|
43 |
+
f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
|
44 |
+
f'value_range(-1, 1)|make_canonical|'
|
45 |
+
f'copy("image", "cond_image")|copy("labels", "image")|'
|
46 |
+
f'keep("image", "cond_image")'
|
47 |
+
)
|
48 |
+
pp_eval = (
|
49 |
+
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
|
50 |
+
f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
|
51 |
+
f'value_range(-1, 1)|make_canonical|'
|
52 |
+
f'copy("image", "cond_image")|copy("labels", "image")|'
|
53 |
+
f'keep("image", "cond_image")'
|
54 |
+
)
|
55 |
+
pp_predict = (
|
56 |
+
f'decode|resize({RES})|value_range(-1, 1)|copy("image", "cond_image")|'
|
57 |
+
f'keep("cond_image", "image/id")' # image/id used for rng seeds.
|
58 |
+
)
|
59 |
+
|
60 |
+
config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]')
|
61 |
+
config.input.batch_size = 512
|
62 |
+
config.input.shuffle_buffer_size = 50_000
|
63 |
+
|
64 |
+
config.total_epochs = 200
|
65 |
+
|
66 |
+
config.log_training_steps = 50
|
67 |
+
config.ckpt_steps = 1000
|
68 |
+
config.keep_ckpt_steps = None
|
69 |
+
config.prefetch_to_device = 2
|
70 |
+
config.seed = 0
|
71 |
+
|
72 |
+
# Optimizer section
|
73 |
+
config.optax_name = 'big_vision.scale_by_adafactor'
|
74 |
+
config.optax = dict(beta2_cap=0.95)
|
75 |
+
|
76 |
+
config.ar_generation_config = ml_collections.ConfigDict()
|
77 |
+
config.ar_generation_config.temp = 0.85
|
78 |
+
config.ar_generation_config.temp_probs = 1.0
|
79 |
+
config.ar_generation_config.beam_size = 4
|
80 |
+
config.ar_generation_config.fan_size = 8
|
81 |
+
config.ar_generation_config.rand_top_k = False
|
82 |
+
config.ar_generation_config.rand_top_k_temp = 1.0
|
83 |
+
|
84 |
+
config.lr = 0.001
|
85 |
+
config.wd = 0.000001
|
86 |
+
config.lr_mults = [
|
87 |
+
('pos_embedding_encoder.*', 0.1),
|
88 |
+
('EmbedPatches.*', 0.1),
|
89 |
+
('encoder.*', 0.1),
|
90 |
+
('decoder.*', 1.0)
|
91 |
+
]
|
92 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=4_000)
|
93 |
+
|
94 |
+
# Oracle section
|
95 |
+
config.vae = ConfigDict()
|
96 |
+
config.vae.model_name = 'proj.givt.vit'
|
97 |
+
config.vae.model = ConfigDict()
|
98 |
+
config.vae.model.input_size = (RES, RES)
|
99 |
+
config.vae.model.patch_size = (PATCH_SIZE, PATCH_SIZE)
|
100 |
+
config.vae.model.code_len = 256
|
101 |
+
config.vae.model.width = 768
|
102 |
+
config.vae.model.enc_depth = 6
|
103 |
+
config.vae.model.dec_depth = 12
|
104 |
+
config.vae.model.mlp_dim = 3072
|
105 |
+
config.vae.model.num_heads = 12
|
106 |
+
config.vae.model.codeword_dim = 16
|
107 |
+
config.vae.model.code_dropout = 'none'
|
108 |
+
config.vae.model.bottleneck_resize = True
|
109 |
+
# values: (channel index in source image, number of classes)
|
110 |
+
config.vae.model.inout_specs = {
|
111 |
+
'semantics': (0, 133 + 1), # +1 for void label
|
112 |
+
'instances': (1, 100), # COCO: actually 98 train/78 validation.
|
113 |
+
}
|
114 |
+
config.vae.model_init = 'gs://big_vision/givt/vae_coco_panoptic_params.npz'
|
115 |
+
|
116 |
+
# Model section
|
117 |
+
config.model_name = 'proj.givt.givt'
|
118 |
+
# # Base model (for exploration)
|
119 |
+
# config.model_init = {'encoder': 'howto-i21k-B/16'}
|
120 |
+
# config.model = ConfigDict(VTT_MODELS['base'])
|
121 |
+
# Large model
|
122 |
+
config.model_init = {'encoder': 'howto-i21k-L/16'}
|
123 |
+
config.model_load = dict(dont_load=('cls', 'head/bias', 'head/kernel'))
|
124 |
+
config.model = ConfigDict(VTT_MODELS['large'])
|
125 |
+
config.model.patches = (PATCH_SIZE, PATCH_SIZE)
|
126 |
+
config.model.input_size = (RES, RES)
|
127 |
+
config.model.posemb_type = 'learn'
|
128 |
+
config.model.seq_len = config.vae.model.code_len
|
129 |
+
config.model.num_labels = None
|
130 |
+
config.model.num_mixtures = 1
|
131 |
+
config.model.fix_square_plus = True
|
132 |
+
config.model.out_dim = config.vae.model.codeword_dim
|
133 |
+
config.model.scale_tol = 1e-6
|
134 |
+
config.model.dec_dropout_rate = 0.0
|
135 |
+
|
136 |
+
# Evaluation section
|
137 |
+
config.evals = {}
|
138 |
+
config.evals.val = ConfigDict()
|
139 |
+
config.evals.val.type = 'mean'
|
140 |
+
config.evals.val.pred = 'validation'
|
141 |
+
config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]')
|
142 |
+
config.evals.val.pp_fn = pp_eval
|
143 |
+
config.evals.val.log_steps = 1000
|
144 |
+
|
145 |
+
config.eval_only = False
|
146 |
+
|
147 |
+
base = {
|
148 |
+
'type': 'proj.givt.coco_panoptic',
|
149 |
+
'data': {**config.input.data},
|
150 |
+
'pp_fn': pp_predict,
|
151 |
+
'log_steps': 10_000,
|
152 |
+
'pred': 'sample_panoptic',
|
153 |
+
# Filters objects that occupy less than 0.03^2 fraction of all pixels.
|
154 |
+
# 'pred_kw': {'min_fraction': 0.03 ** 2},
|
155 |
+
}
|
156 |
+
config.evals.coco_panoptic_train = dict(base)
|
157 |
+
config.evals.coco_panoptic_train.data.split = 'train[4096:8192]'
|
158 |
+
config.evals.coco_panoptic_holdout = dict(base)
|
159 |
+
config.evals.coco_panoptic_holdout.data.split = 'train[:4096]'
|
160 |
+
config.evals.coco_panoptic = dict(base)
|
161 |
+
config.evals.coco_panoptic.data.split = 'validation'
|
162 |
+
|
163 |
+
config.evals.save_pred = dict(type='proj.givt.save_predictions')
|
164 |
+
config.evals.save_pred.pred = 'sample_panoptic'
|
165 |
+
config.evals.save_pred.pp_fn = pp_eval
|
166 |
+
config.evals.save_pred.log_steps = 100_000
|
167 |
+
config.evals.save_pred.data = dict(config.input.data)
|
168 |
+
config.evals.save_pred.data.split = 'validation[:1024]'
|
169 |
+
config.evals.save_pred.outfile = 'inference.npz'
|
170 |
+
|
171 |
+
if runlocal:
|
172 |
+
config.input.batch_size = 4
|
173 |
+
config.input.shuffle_buffer_size = 10
|
174 |
+
config.evals.val.data.split = 'train[:16]'
|
175 |
+
config.evals.val.log_steps = 20
|
176 |
+
config.model.num_layers = 1
|
177 |
+
config.model.num_decoder_layers = 1
|
178 |
+
del config.model_init
|
179 |
+
config.evals.val.data.split = 'validation[:4]'
|
180 |
+
config.evals.coco_panoptic.data.split = 'validation[:4]'
|
181 |
+
config.evals.save_pred.data.split = 'validation[:4]'
|
182 |
+
for k in config.evals.keys():
|
183 |
+
if k not in ['val', 'coco_panoptic', 'save_pred']:
|
184 |
+
del config.evals[k]
|
185 |
+
|
186 |
+
return config
|
big_vision/configs/proj/givt/givt_demo_colab.ipynb
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"gpuType": "T4"
|
8 |
+
},
|
9 |
+
"kernelspec": {
|
10 |
+
"name": "python3",
|
11 |
+
"display_name": "Python 3"
|
12 |
+
},
|
13 |
+
"language_info": {
|
14 |
+
"name": "python"
|
15 |
+
},
|
16 |
+
"accelerator": "GPU"
|
17 |
+
},
|
18 |
+
"cells": [
|
19 |
+
{
|
20 |
+
"cell_type": "markdown",
|
21 |
+
"source": [
|
22 |
+
"# GIVT Demo colab\n",
|
23 |
+
"\n",
|
24 |
+
"[[paper]](https://arxiv.org/abs/2312.02116) [[GitHub]](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/givt/README.md)\n",
|
25 |
+
"\n",
|
26 |
+
"This colab implements class-conditional image generation using GIVT-Causal and GIVT-MaskGIT for the 1k ImageNet2012 classes.\n",
|
27 |
+
"\n",
|
28 |
+
"The available model checkpoints are meant as small-scale baselines (~300M parameters) for researchers interested in exploring GIVT, and are not optimized to provide the best possible visual quality (e.g. scaling the model size can substantially improve visual quality as shown in the paper).\n",
|
29 |
+
"\n",
|
30 |
+
"The colab was tested with the CPU and T4 GPU runtimes. We recommend the T4 GPU runtime (the CPU rutime is very slow).\n",
|
31 |
+
"\n",
|
32 |
+
"_Disclaimer: This is not an official Google Product._"
|
33 |
+
],
|
34 |
+
"metadata": {
|
35 |
+
"id": "botgo-GZiWI_"
|
36 |
+
}
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "markdown",
|
40 |
+
"source": [
|
41 |
+
"### `big_vision` setup"
|
42 |
+
],
|
43 |
+
"metadata": {
|
44 |
+
"id": "jQxc9UZ-mVrQ"
|
45 |
+
}
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"source": [
|
50 |
+
"#@markdown Clone and set up repository\n",
|
51 |
+
"!git clone --branch=main --depth=1 https://github.com/google-research/big_vision\n",
|
52 |
+
"!cd big_vision && git pull\n",
|
53 |
+
"\n",
|
54 |
+
"# Install dependencies - pin TensorFlow-related packages to ensure compatibility\n",
|
55 |
+
"# which might not be needed in in the future\n",
|
56 |
+
"!echo -e \"keras==3.0.5\\ntensorflow==2.16.1\\ntensorflow-probability==0.24.0\" > big_vision/big_vision/constraints.txt\n",
|
57 |
+
"!pip install -r big_vision/big_vision/requirements.txt -c big_vision/big_vision/constraints.txt\n",
|
58 |
+
"%cd big_vision"
|
59 |
+
],
|
60 |
+
"metadata": {
|
61 |
+
"id": "ZAXiVta3n2jL",
|
62 |
+
"cellView": "form"
|
63 |
+
},
|
64 |
+
"execution_count": null,
|
65 |
+
"outputs": []
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {
|
71 |
+
"id": "qYS7JNups4MU",
|
72 |
+
"cellView": "form"
|
73 |
+
},
|
74 |
+
"outputs": [],
|
75 |
+
"source": [
|
76 |
+
"#@markdown Imports\n",
|
77 |
+
"import jax\n",
|
78 |
+
"from functools import partial\n",
|
79 |
+
"import ml_collections\n",
|
80 |
+
"import matplotlib.pyplot as plt\n",
|
81 |
+
"\n",
|
82 |
+
"from big_vision.configs.proj.givt import givt_imagenet2012\n",
|
83 |
+
"from big_vision.datasets.imagenet import class_names as imagenet_class_names\n",
|
84 |
+
"from big_vision.models.proj.givt import givt, cnn, decode, parallel_decode\n",
|
85 |
+
"\n",
|
86 |
+
"jnp = jax.numpy"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "markdown",
|
91 |
+
"source": [
|
92 |
+
"### Select and download model\n",
|
93 |
+
"\n"
|
94 |
+
],
|
95 |
+
"metadata": {
|
96 |
+
"id": "MaCM_PIcd2Rb"
|
97 |
+
}
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"source": [
|
102 |
+
"model = \"GIVT-Causal 256x256\" #@param [\"GIVT-Causal 256x256\", \"GIVT-MaskGIT 256x256\", \"GIVT-MaskGIT 512x512\"]\n",
|
103 |
+
"\n",
|
104 |
+
"givt_ckpt_path, cfg_w, temp, is_ar, res = {\n",
|
105 |
+
" \"GIVT-Causal 256x256\": (\n",
|
106 |
+
" \"gs://big_vision/givt/givt_imagenet_2012_causal_params.npz\", 0.4, 0.95, True, 256),\n",
|
107 |
+
" \"GIVT-MaskGIT 256x256\": (\n",
|
108 |
+
" \"gs://big_vision/givt/givt_imagenet_2012_maskgit_params.npz\", 0.0, 35.0, False, 256),\n",
|
109 |
+
" \"GIVT-MaskGIT 512x512\": (\n",
|
110 |
+
" \"gs://big_vision/givt/givt_imagenet_2012_maskgit_512_params.npz\", 0.0, 140.0, False, 512),\n",
|
111 |
+
"}[model]\n",
|
112 |
+
"\n",
|
113 |
+
"config = givt_imagenet2012.get_config(arg=f\"res={res},style={'ar' if is_ar else 'masked'}\")\n",
|
114 |
+
"\n",
|
115 |
+
"print(\"Loading VAE model...\")\n",
|
116 |
+
"vae_model = cnn.Model(**config.vae.model)\n",
|
117 |
+
"vae_params = cnn.load(None, config.vae.model_init, **config.vae.model_load)\n",
|
118 |
+
"\n",
|
119 |
+
"print(\"Loading GIVT model...\")\n",
|
120 |
+
"givt_model = givt.Model(**config.model)\n",
|
121 |
+
"givt_params = jax.device_put(\n",
|
122 |
+
" givt.load(None, givt_ckpt_path), jax.devices()[0])"
|
123 |
+
],
|
124 |
+
"metadata": {
|
125 |
+
"id": "7l6QIjdyN3dg",
|
126 |
+
"cellView": "form"
|
127 |
+
},
|
128 |
+
"execution_count": null,
|
129 |
+
"outputs": []
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "markdown",
|
133 |
+
"source": [
|
134 |
+
"### VAE encode/decode and sampling loop"
|
135 |
+
],
|
136 |
+
"metadata": {
|
137 |
+
"id": "SUj5k1bxd6wr"
|
138 |
+
}
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"source": [
|
143 |
+
"@jax.jit\n",
|
144 |
+
"def vae_encode(images, rng):\n",
|
145 |
+
" \"\"\"Encode image with VAE encoder.\"\"\"\n",
|
146 |
+
" mu, logvar = vae_model.apply(\n",
|
147 |
+
" {\"params\": vae_params}, images, method=vae_model.encode,\n",
|
148 |
+
" )\n",
|
149 |
+
" return vae_model.apply(\n",
|
150 |
+
" {\"params\": vae_params},\n",
|
151 |
+
" mu,\n",
|
152 |
+
" logvar,\n",
|
153 |
+
" method=vae_model.reparametrize,\n",
|
154 |
+
" rngs={\"dropout\": rng},\n",
|
155 |
+
" )\n",
|
156 |
+
"\n",
|
157 |
+
"@jax.jit\n",
|
158 |
+
"def vae_decode(z):\n",
|
159 |
+
" \"\"\"Reconstruct image with VAE decoder from latent code z.\"\"\"\n",
|
160 |
+
" return vae_model.apply({\"params\": vae_params}, z, method=vae_model.decode)\n",
|
161 |
+
"\n",
|
162 |
+
"### jit-compilation seems to go OOM (RAM) on the free tier GPU colab, but might\n",
|
163 |
+
"### lead to speedups on machines with more resources\n",
|
164 |
+
"# @partial(jax.jit, static_argnums=(2, 3))\n",
|
165 |
+
"def sample(labels, rng, ar_generation_config=None, masked_generation_config=None):\n",
|
166 |
+
" \"\"\"Sample from GIVT-Causal or GIVT-MaskGIT.\"\"\"\n",
|
167 |
+
" print(f\"Sampling, style={givt_model.style}\")\n",
|
168 |
+
" shared_kwargs = dict(\n",
|
169 |
+
" labels=labels,\n",
|
170 |
+
" model=givt_model,\n",
|
171 |
+
" seq_len=config.model.seq_len,\n",
|
172 |
+
" feature_dim=config.model.out_dim,\n",
|
173 |
+
" )\n",
|
174 |
+
"\n",
|
175 |
+
" match givt_model.style:\n",
|
176 |
+
" case \"ar\":\n",
|
177 |
+
" sampled_codes, _ = decode.generate(\n",
|
178 |
+
" params={\"params\": givt_params},\n",
|
179 |
+
" seed=rng,\n",
|
180 |
+
" config=dict(ar_generation_config),\n",
|
181 |
+
" **shared_kwargs,\n",
|
182 |
+
" )\n",
|
183 |
+
" info = sampled_codes\n",
|
184 |
+
" case \"masked\":\n",
|
185 |
+
" masked_out = parallel_decode.decode_masked(\n",
|
186 |
+
" rng=rng,\n",
|
187 |
+
" variables={\"params\": givt_params},\n",
|
188 |
+
" config=masked_generation_config,\n",
|
189 |
+
" **shared_kwargs,\n",
|
190 |
+
" )\n",
|
191 |
+
" sampled_codes = masked_out.current_inputs_q\n",
|
192 |
+
" info = masked_out\n",
|
193 |
+
" case _:\n",
|
194 |
+
" raise NotImplementedError\n",
|
195 |
+
" return sampled_codes, info"
|
196 |
+
],
|
197 |
+
"metadata": {
|
198 |
+
"id": "vSn7Si2FS1zi"
|
199 |
+
},
|
200 |
+
"execution_count": null,
|
201 |
+
"outputs": []
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "markdown",
|
205 |
+
"source": [
|
206 |
+
"### Generate images for class label"
|
207 |
+
],
|
208 |
+
"metadata": {
|
209 |
+
"id": "tOnWaJZVeOIX"
|
210 |
+
}
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"source": [
|
215 |
+
"rng = 0 #@param = 'int'\n",
|
216 |
+
"label = 'goldfish' #@param [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n",
|
217 |
+
"label_int = dict(\n",
|
218 |
+
" zip(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,\n",
|
219 |
+
" range(len(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES))))[label]"
|
220 |
+
],
|
221 |
+
"metadata": {
|
222 |
+
"cellView": "form",
|
223 |
+
"id": "_CiyXD_6nQbu"
|
224 |
+
},
|
225 |
+
"execution_count": null,
|
226 |
+
"outputs": []
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"source": [
|
231 |
+
"%%capture --no-display\n",
|
232 |
+
"batch_size = 8\n",
|
233 |
+
"\n",
|
234 |
+
"target_labels = jnp.full((batch_size,), label_int, jnp.int32)\n",
|
235 |
+
"\n",
|
236 |
+
"if is_ar:\n",
|
237 |
+
" ar_generation_config = dict(cfg_inference_weight=cfg_w, temp=temp)\n",
|
238 |
+
" masked_generation_config = None\n",
|
239 |
+
"else:\n",
|
240 |
+
" ar_generation_config = {}\n",
|
241 |
+
" masked_generation_config = parallel_decode.MaskedGenerationConfig(\n",
|
242 |
+
" cfg_inference_weight=cfg_w,\n",
|
243 |
+
" choice_temperature = temp,\n",
|
244 |
+
" num_steps = 16,\n",
|
245 |
+
" ordering = \"maskgit\",\n",
|
246 |
+
" schedule = \"cosine\",\n",
|
247 |
+
" )\n",
|
248 |
+
"\n",
|
249 |
+
"# Sample from GIVT and decode\n",
|
250 |
+
"sampled_codes, _ = sample(\n",
|
251 |
+
" target_labels, jax.random.PRNGKey(rng),\n",
|
252 |
+
" tuple(ar_generation_config.items()), masked_generation_config)\n",
|
253 |
+
"\n",
|
254 |
+
"generated_images = vae_decode(sampled_codes)"
|
255 |
+
],
|
256 |
+
"metadata": {
|
257 |
+
"id": "sCcGB0m1oQY1"
|
258 |
+
},
|
259 |
+
"execution_count": null,
|
260 |
+
"outputs": []
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"cell_type": "code",
|
264 |
+
"source": [
|
265 |
+
"#@markdown Visualize images\n",
|
266 |
+
"ncols = 4\n",
|
267 |
+
"nrows = generated_images.shape[0] // ncols\n",
|
268 |
+
"fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))\n",
|
269 |
+
"\n",
|
270 |
+
"for idx, (ax, img) in enumerate(zip(axes.flat, generated_images)):\n",
|
271 |
+
" ax.imshow(img * .5 + .5)\n",
|
272 |
+
" if idx == 0:\n",
|
273 |
+
" ax.set_title(f'Label: {label} ({label_int})', fontsize=10, ha='left', loc='left')\n",
|
274 |
+
" ax.set_axis_off()"
|
275 |
+
],
|
276 |
+
"metadata": {
|
277 |
+
"id": "4FWgfAghuh8P",
|
278 |
+
"cellView": "form"
|
279 |
+
},
|
280 |
+
"execution_count": null,
|
281 |
+
"outputs": []
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"cell_type": "code",
|
285 |
+
"source": [
|
286 |
+
"#@markdown Visualize latent codes\n",
|
287 |
+
"nrows = sampled_codes.shape[0]\n",
|
288 |
+
"ncols = sampled_codes.shape[-1] + 1\n",
|
289 |
+
"fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows))\n",
|
290 |
+
"\n",
|
291 |
+
"for r, (row_ax, code) in enumerate(zip(axes, sampled_codes)):\n",
|
292 |
+
" code_norm = (code - code.min()) / (code.max() - code.min())\n",
|
293 |
+
" for c, ax in enumerate(row_ax):\n",
|
294 |
+
" if c == 0:\n",
|
295 |
+
" cc = generated_images[r] * .5 + .5\n",
|
296 |
+
" else:\n",
|
297 |
+
" cc = code_norm[..., c - 1].reshape(res // 16, res // 16)\n",
|
298 |
+
" ax.imshow(cc)\n",
|
299 |
+
" ax.set_axis_off()"
|
300 |
+
],
|
301 |
+
"metadata": {
|
302 |
+
"id": "zGPPeXONy0Am",
|
303 |
+
"cellView": "form"
|
304 |
+
},
|
305 |
+
"execution_count": null,
|
306 |
+
"outputs": []
|
307 |
+
}
|
308 |
+
]
|
309 |
+
}
|
big_vision/configs/proj/givt/givt_imagenet2012.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Train Generative Infinite Vocabulary Transformer (GIVT) on ImageNet.
|
17 |
+
|
18 |
+
Example launch command (local; see main README for launching on TPU servers):
|
19 |
+
|
20 |
+
python -m big_vision.trainers.proj.givt.generative \
|
21 |
+
--config big_vision/configs/proj/givt/givt_imagenet2012.py \
|
22 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
|
23 |
+
|
24 |
+
Add the suffix `:key1=value1,key2=value2,...` to the config path in the launch
|
25 |
+
command to modify the the config with the arguments below. For example:
|
26 |
+
`--config big_vision/configs/proj/givt/givt_imagenet_2012.py:model_size=large`
|
27 |
+
"""
|
28 |
+
|
29 |
+
import big_vision.configs.common as bvcc
|
30 |
+
import ml_collections
|
31 |
+
|
32 |
+
|
33 |
+
RES = 256
|
34 |
+
PATCH_SIZE = 16
|
35 |
+
|
36 |
+
GIVT_MODELS = {
|
37 |
+
'base': dict(num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768, dec_dropout_rate=0.1),
|
38 |
+
'default': dict(num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024, dec_dropout_rate=0.2),
|
39 |
+
'large': dict(num_decoder_layers=48, num_heads=16, mlp_dim=8192, emb_dim=1536, dec_dropout_rate=0.3)
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def get_config(arg=None):
|
44 |
+
"""A config for training a simple VAE on imagenet2012."""
|
45 |
+
arg = bvcc.parse_arg(arg, res=RES, patch_size=PATCH_SIZE, style='ar', # 'ar' or 'masked'
|
46 |
+
model_size='default', runlocal=False, singlehost=False,
|
47 |
+
adaptor=False)
|
48 |
+
config = ml_collections.ConfigDict()
|
49 |
+
|
50 |
+
config.input = {}
|
51 |
+
### Using Imagenette here to ensure this config is runnable without manual
|
52 |
+
### download of ImageNet. This is only meant for testing and will overfit
|
53 |
+
### immediately. Please download ImageNet to reproduce the paper results.
|
54 |
+
# config.input.data = dict(name='imagenet2012', split='train[4096:]')
|
55 |
+
config.input.data = dict(name='imagenette', split='train')
|
56 |
+
|
57 |
+
config.input.batch_size = 8 * 1024 if not arg.runlocal else 8
|
58 |
+
config.input.shuffle_buffer_size = 25_000 if not arg.runlocal else 10
|
59 |
+
|
60 |
+
config.total_epochs = 500
|
61 |
+
|
62 |
+
config.input.pp = (
|
63 |
+
f'decode_jpeg_and_inception_crop({arg.res},'
|
64 |
+
f'area_min=80, area_max=100, ratio_min=1.0, ratio_max=1.0,'
|
65 |
+
f'method="bicubic", antialias=True)'
|
66 |
+
f'|flip_lr'
|
67 |
+
f'|value_range(-1, 1, key="image")'
|
68 |
+
f'|copy("label", "labels")'
|
69 |
+
f'|keep("image", "labels")')
|
70 |
+
|
71 |
+
pp_eval = (
|
72 |
+
f'decode'
|
73 |
+
f'|resize_small({arg.res}, inkey="image", outkey="image",'
|
74 |
+
f'method="bicubic", antialias=True)'
|
75 |
+
f'|central_crop({arg.res})'
|
76 |
+
f'|value_range(-1, 1, key="image")'
|
77 |
+
f'|copy("label", "labels")'
|
78 |
+
f'|keep("image", "labels")')
|
79 |
+
|
80 |
+
config.log_training_steps = 50
|
81 |
+
config.ckpt_steps = 1000
|
82 |
+
config.keep_ckpt_steps = None
|
83 |
+
|
84 |
+
# Flags for AR model.
|
85 |
+
config.ar_generation_config = ml_collections.ConfigDict()
|
86 |
+
config.ar_generation_config.temp = 0.95
|
87 |
+
config.ar_generation_config.temp_probs = 1.0
|
88 |
+
config.ar_generation_config.beam_size = 1
|
89 |
+
config.ar_generation_config.fan_size = 1
|
90 |
+
config.ar_generation_config.rand_top_k = False
|
91 |
+
config.ar_generation_config.rand_top_k_temp = 1.0
|
92 |
+
config.ar_generation_config.cfg_inference_weight = 0.4
|
93 |
+
|
94 |
+
# Flags for masked model.
|
95 |
+
config.masked_generation_config = ml_collections.ConfigDict()
|
96 |
+
config.masked_generation_config.choice_temperature = 35.0
|
97 |
+
config.masked_generation_config.ordering = 'maskgit'
|
98 |
+
config.masked_generation_config.cfg_inference_weight = 0.0
|
99 |
+
config.masked_generation_config.schedule = 'cosine'
|
100 |
+
|
101 |
+
# Used for eval sweep.
|
102 |
+
config.eval_only = False
|
103 |
+
|
104 |
+
# VAE section
|
105 |
+
config.vae = {}
|
106 |
+
config.vae.model = ml_collections.ConfigDict()
|
107 |
+
config.vae.model.code_len = (arg.res // arg.patch_size) ** 2
|
108 |
+
config.vae.model_name = 'proj.givt.cnn'
|
109 |
+
config.vae.model.codeword_dim = 16
|
110 |
+
config.vae.model.filters = 128
|
111 |
+
config.vae.model.num_res_blocks = 2
|
112 |
+
config.vae.model.channel_multipliers = (1, 1, 2, 2, 4)
|
113 |
+
config.vae.model.conv_downsample = False
|
114 |
+
config.vae.model.activation_fn = 'swish'
|
115 |
+
config.vae.model.norm_type = 'GN'
|
116 |
+
if arg.model_size == 'large':
|
117 |
+
config.vae.model_init = 'gs://big_vision/givt/vae_imagenet_2012_beta_1e-5_params'
|
118 |
+
else:
|
119 |
+
config.vae.model_init = 'gs://big_vision/givt/vae_imagenet_2012_beta_5e-5_params'
|
120 |
+
config.vae.model.malib_ckpt = True
|
121 |
+
config.vae.model_load = {}
|
122 |
+
config.vae.model_load.malib_ckpt = config.vae.model.malib_ckpt
|
123 |
+
config.vae.model_load.use_ema_params = True
|
124 |
+
|
125 |
+
# GIVT section
|
126 |
+
config.model_name = 'proj.givt.givt'
|
127 |
+
config.model_init = ''
|
128 |
+
assert arg.model_size in GIVT_MODELS, f'Unknown model size: {arg.model_size}'
|
129 |
+
config.model = ml_collections.ConfigDict(GIVT_MODELS[arg.model_size])
|
130 |
+
config.model.num_layers = 0
|
131 |
+
config.model.num_labels = 1000 # None
|
132 |
+
config.model.seq_len = config.vae.model.code_len
|
133 |
+
config.model.out_dim = config.vae.model.codeword_dim
|
134 |
+
config.model.num_mixtures = 16
|
135 |
+
config.model.posemb_type = 'learn'
|
136 |
+
config.model.scale_tol = 1e-6
|
137 |
+
config.model.style = arg.style
|
138 |
+
config.model.min_masking_rate_training = 0.3
|
139 |
+
config.model.mask_style = 'concat'
|
140 |
+
config.model.drop_labels_probability = 0.1
|
141 |
+
config.model.fix_square_plus = True
|
142 |
+
config.model.per_channel_mixtures = False
|
143 |
+
config.model_init = ''
|
144 |
+
# Required for model sharding
|
145 |
+
config.model.scan = True
|
146 |
+
config.model.remat_policy = 'nothing_saveable'
|
147 |
+
|
148 |
+
# Adaptor section
|
149 |
+
config.adaptor_name = 'proj.givt.adaptor' if arg.adaptor else ''
|
150 |
+
config.adaptor = {}
|
151 |
+
config.adaptor.model = ml_collections.ConfigDict()
|
152 |
+
config.adaptor.model.num_blocks = 8
|
153 |
+
config.adaptor.model.num_channels_bottleneck = 4 * config.model.out_dim
|
154 |
+
|
155 |
+
config.optax_name = 'scale_by_adam'
|
156 |
+
config.optax = dict(b2=0.95)
|
157 |
+
config.grad_clip_norm = 1.0
|
158 |
+
|
159 |
+
# FSDP training by default
|
160 |
+
config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
161 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
162 |
+
|
163 |
+
# Standard schedule
|
164 |
+
config.lr = 0.001
|
165 |
+
config.wd = 0.0001
|
166 |
+
config.schedule = dict(decay_type='cosine', warmup_percent=0.1)
|
167 |
+
|
168 |
+
# MaskGIT-specific parameters
|
169 |
+
if arg.style == 'masked':
|
170 |
+
config.model.dec_dropout_rate = 0.4
|
171 |
+
config.wd = 0.0
|
172 |
+
if arg.res == 512:
|
173 |
+
config.masked_generation_config.choice_temperature = 140
|
174 |
+
# GIVT-Causal 512px specific parameters
|
175 |
+
elif arg.res == 512 and arg.model_size == 'large':
|
176 |
+
config.model.dec_dropout_rate = 0.1
|
177 |
+
# Set up space-to-depth/pixel shuffle
|
178 |
+
config.vae.model.code_len //= 2
|
179 |
+
config.vae.model.pixel_shuffle_patch_size = (1, 2)
|
180 |
+
config.model.seq_len //= 2
|
181 |
+
config.model.out_dim = config.vae.model.codeword_dim * 2
|
182 |
+
config.model.num_mixtures = 32
|
183 |
+
config.adaptor.model.num_channels_bottleneck = 8 * config.model.out_dim
|
184 |
+
config.adaptor.model.pixel_shuffle_patch_size = (1, 2)
|
185 |
+
# Update sampling config
|
186 |
+
config.ar_generation_config.temp = 0.9
|
187 |
+
config.ar_generation_config.cfg_inference_weight = 0.9
|
188 |
+
|
189 |
+
### Evaluation section
|
190 |
+
config.evals = {}
|
191 |
+
config.evals.val = ml_collections.ConfigDict()
|
192 |
+
config.evals.val.type = 'mean'
|
193 |
+
config.evals.val.pred = 'validation'
|
194 |
+
config.evals.val.data = {**config.input.data}
|
195 |
+
config.evals.val.data.split = f'train[:{4096 if not arg.runlocal else 8}]'
|
196 |
+
config.evals.val.pp_fn = pp_eval
|
197 |
+
config.evals.val.log_steps = 1_000 if not arg.runlocal else 20
|
198 |
+
|
199 |
+
config.evals.save_pred_sampling = dict(
|
200 |
+
type='proj.givt.save_predictions',
|
201 |
+
pp_fn=pp_eval,
|
202 |
+
log_steps=10_000,
|
203 |
+
pred='sample',
|
204 |
+
batch_size=512,
|
205 |
+
data=dict(name=config.input.data.name, split='validation[:512]'),
|
206 |
+
outfile='inference_sampled.npz',
|
207 |
+
)
|
208 |
+
|
209 |
+
config.seed = 0
|
210 |
+
|
211 |
+
config.ckpt_timeout = 30
|
212 |
+
|
213 |
+
if arg.runlocal:
|
214 |
+
config.input.batch_size = 4
|
215 |
+
config.input.shuffle_buffer_size = 10
|
216 |
+
config.log_training_steps = 5
|
217 |
+
config.model.num_decoder_layers = 2
|
218 |
+
|
219 |
+
config.evals.val.data.split = 'validation[:16]'
|
220 |
+
config.evals.val.log_steps = 20
|
221 |
+
|
222 |
+
return config
|
big_vision/configs/proj/givt/givt_nyu_depth.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Train a GIVT encoder-decoder model for NYU depth prediction."""
|
17 |
+
|
18 |
+
import itertools
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
import ml_collections
|
21 |
+
|
22 |
+
ConfigDict = ml_collections.ConfigDict
|
23 |
+
|
24 |
+
VTT_MODELS = {
|
25 |
+
'base': dict(num_layers=12, num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
|
26 |
+
'large': dict(num_layers=24, num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
|
27 |
+
}
|
28 |
+
|
29 |
+
RES = 512
|
30 |
+
PATCH_SIZE = 16
|
31 |
+
LABEL_RES = 512
|
32 |
+
LABEL_PATCH_SIZE = 16
|
33 |
+
QUANTIZATION_BINS = 256
|
34 |
+
MIN_DEPTH = 0.001
|
35 |
+
MAX_DEPTH = 10.0
|
36 |
+
|
37 |
+
|
38 |
+
def get_config(arg='split=sweep'):
|
39 |
+
"""Config for training."""
|
40 |
+
arg = bvcc.parse_arg(arg, split='sweep', runlocal=False, singlehost=False)
|
41 |
+
config = ConfigDict()
|
42 |
+
|
43 |
+
config.input = {}
|
44 |
+
config.input.pp = (
|
45 |
+
f'decode|nyu_depth|'
|
46 |
+
f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
|
47 |
+
f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
|
48 |
+
f'resize({RES})|'
|
49 |
+
f'resize({LABEL_RES},key="labels",method="nearest")|'
|
50 |
+
f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|'
|
51 |
+
f'value_range(-1,1)|'
|
52 |
+
f'copy("image", "cond_image")|copy("labels", "image")|'
|
53 |
+
f'keep("image", "cond_image")'
|
54 |
+
)
|
55 |
+
pp_eval = (
|
56 |
+
f'decode|nyu_depth|'
|
57 |
+
f'nyu_eval_crop|'
|
58 |
+
f'resize({RES})|'
|
59 |
+
f'resize({LABEL_RES},key="labels",method="nearest")|'
|
60 |
+
f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|'
|
61 |
+
f'value_range(-1,1)|'
|
62 |
+
f'copy("image", "cond_image")|copy("labels", "image")|'
|
63 |
+
f'keep("image", "cond_image")'
|
64 |
+
)
|
65 |
+
pp_predict = (
|
66 |
+
f'decode|nyu_depth|'
|
67 |
+
f'nyu_eval_crop|copy("labels","ground_truth")|'
|
68 |
+
f'resize({RES})|'
|
69 |
+
f'value_range(-1,1)|'
|
70 |
+
f'copy("image", "cond_image")|'
|
71 |
+
f'strong_hash(inkey="tfds_id", outkey="image/id")|'
|
72 |
+
f'keep("cond_image", "ground_truth", "image/id")'
|
73 |
+
)
|
74 |
+
|
75 |
+
config.input.data = dict(name='nyu_depth_v2', split='train')
|
76 |
+
config.input.batch_size = 512
|
77 |
+
config.input.shuffle_buffer_size = 50_000
|
78 |
+
|
79 |
+
config.total_epochs = 50
|
80 |
+
|
81 |
+
config.log_training_steps = 50
|
82 |
+
config.ckpt_steps = 1000
|
83 |
+
config.keep_ckpt_steps = None
|
84 |
+
config.prefetch_to_device = 2
|
85 |
+
config.seed = 0
|
86 |
+
|
87 |
+
# Optimizer section
|
88 |
+
config.optax_name = 'big_vision.scale_by_adafactor'
|
89 |
+
config.optax = dict(beta2_cap=0.95)
|
90 |
+
|
91 |
+
config.ar_generation_config = ConfigDict()
|
92 |
+
config.ar_generation_config.temp = 0.9
|
93 |
+
config.ar_generation_config.temp_probs = 1.0
|
94 |
+
config.ar_generation_config.beam_size = 2
|
95 |
+
config.ar_generation_config.fan_size = 8
|
96 |
+
config.ar_generation_config.rand_top_k = False
|
97 |
+
config.ar_generation_config.rand_top_k_temp = 1.0
|
98 |
+
|
99 |
+
config.lr = 0.001
|
100 |
+
config.wd = 0.000001
|
101 |
+
config.lr_mults = [
|
102 |
+
('pos_embedding_encoder.*', 0.1),
|
103 |
+
('EmbedPatches.*', 0.1),
|
104 |
+
('encoder.*', 0.1),
|
105 |
+
('decoder.*', 1.0)
|
106 |
+
]
|
107 |
+
config.schedule = dict(decay_type='cosine', warmup_percent=0.1)
|
108 |
+
|
109 |
+
# Oracle section
|
110 |
+
config.min_depth = MIN_DEPTH
|
111 |
+
config.max_depth = MAX_DEPTH
|
112 |
+
config.vae = ConfigDict()
|
113 |
+
config.vae.model_name = 'proj.givt.vit'
|
114 |
+
config.vae.model = ConfigDict()
|
115 |
+
config.vae.model.input_size = (RES, RES)
|
116 |
+
config.vae.model.patch_size = (PATCH_SIZE, PATCH_SIZE)
|
117 |
+
config.vae.model.code_len = 256
|
118 |
+
config.vae.model.width = 768
|
119 |
+
config.vae.model.enc_depth = 6
|
120 |
+
config.vae.model.dec_depth = 12
|
121 |
+
config.vae.model.mlp_dim = 3072
|
122 |
+
config.vae.model.num_heads = 12
|
123 |
+
config.vae.model.codeword_dim = 16
|
124 |
+
config.vae.model.code_dropout = 'none'
|
125 |
+
config.vae.model.bottleneck_resize = True
|
126 |
+
# values: (channel index in source image, number of classes)
|
127 |
+
config.vae.model.inout_specs = {
|
128 |
+
'depth': (0, QUANTIZATION_BINS),
|
129 |
+
}
|
130 |
+
config.vae.model_init = 'gs://big_vision/givt/vae_nyu_depth_params.npz'
|
131 |
+
|
132 |
+
# Model section
|
133 |
+
config.model_name = 'proj.givt.givt'
|
134 |
+
# # Base model (for exploration)
|
135 |
+
# config.model_init = {'encoder': 'howto-i21k-B/16'}
|
136 |
+
# config.model = ConfigDict(VTT_MODELS['base'])
|
137 |
+
# Large model
|
138 |
+
config.model_init = {'encoder': 'howto-i21k-L/16'}
|
139 |
+
config.model_load = dict(dont_load=('cls', 'head/bias', 'head/kernel'))
|
140 |
+
config.model = ConfigDict(VTT_MODELS['large'])
|
141 |
+
config.model.patches = (PATCH_SIZE, PATCH_SIZE)
|
142 |
+
config.model.input_size = (RES, RES)
|
143 |
+
config.model.posemb_type = 'learn'
|
144 |
+
config.model.seq_len = config.vae.model.code_len
|
145 |
+
config.model.num_labels = None
|
146 |
+
config.model.num_mixtures = 1
|
147 |
+
config.model.fix_square_plus = True
|
148 |
+
config.model.out_dim = config.vae.model.codeword_dim
|
149 |
+
config.model.scale_tol = 1e-6
|
150 |
+
config.model.dec_dropout_rate = 0.0
|
151 |
+
|
152 |
+
# Evaluation section
|
153 |
+
config.evals = {}
|
154 |
+
config.evals.val = ConfigDict()
|
155 |
+
config.evals.val.type = 'mean'
|
156 |
+
config.evals.val.pred = 'validation'
|
157 |
+
config.evals.val.data = {**config.input.data}
|
158 |
+
config.evals.val.data.split = 'validation'
|
159 |
+
config.evals.val.pp_fn = pp_eval
|
160 |
+
config.evals.val.log_steps = 250
|
161 |
+
|
162 |
+
base = {
|
163 |
+
'type': 'proj.givt.nyu_depth',
|
164 |
+
'data': {**config.input.data},
|
165 |
+
'pp_fn': pp_predict,
|
166 |
+
'pred': 'sample_depth',
|
167 |
+
'log_steps': 2000,
|
168 |
+
'min_depth': MIN_DEPTH,
|
169 |
+
'max_depth': MAX_DEPTH,
|
170 |
+
}
|
171 |
+
|
172 |
+
config.evals.nyu_depth_val = dict(base)
|
173 |
+
config.evals.nyu_depth_val.data.split = 'validation'
|
174 |
+
|
175 |
+
config.evals.save_pred = dict(base)
|
176 |
+
config.evals.save_pred.type = 'proj.givt.save_predictions'
|
177 |
+
del config.evals.save_pred.min_depth, config.evals.save_pred.max_depth
|
178 |
+
config.evals.save_pred.log_steps = 100_000
|
179 |
+
config.evals.save_pred.data.split = 'validation[:128]'
|
180 |
+
config.evals.save_pred.outfile = 'inference.npz'
|
181 |
+
|
182 |
+
config.eval_only = False
|
183 |
+
config.seed = 0
|
184 |
+
|
185 |
+
if arg.runlocal:
|
186 |
+
config.input.batch_size = 4
|
187 |
+
config.input.shuffle_buffer_size = 10
|
188 |
+
config.evals.val.log_steps = 20
|
189 |
+
config.evals.val.data.split = 'validation[:4]'
|
190 |
+
config.evals.nyu_depth_val.data.split = 'validation[:4]'
|
191 |
+
config.evals.save_pred.data.split = 'validation[:4]'
|
192 |
+
config.model.update(VTT_MODELS['base'])
|
193 |
+
del config.model_init
|
194 |
+
for k in config.evals.keys():
|
195 |
+
if k not in ['val', 'nyu_depth_val', 'save_pred']:
|
196 |
+
del config.evals[k]
|
197 |
+
|
198 |
+
return config
|
big_vision/configs/proj/givt/givt_overview.png
ADDED
![]() |
big_vision/configs/proj/givt/vae_coco_panoptic.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Train VAE for GIVT-based UViM COCO panoptic task.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
import ml_collections as mlc
|
21 |
+
|
22 |
+
|
23 |
+
def get_config(arg='res=512,patch_size=16'):
|
24 |
+
"""Config for training label compression on COCO-panoptic."""
|
25 |
+
arg = bvcc.parse_arg(arg, res=512, patch_size=16,
|
26 |
+
runlocal=False, singlehost=False)
|
27 |
+
config = mlc.ConfigDict()
|
28 |
+
|
29 |
+
config.input = {}
|
30 |
+
config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]')
|
31 |
+
|
32 |
+
config.input.batch_size = 1024
|
33 |
+
config.input.shuffle_buffer_size = 25_000
|
34 |
+
|
35 |
+
config.total_epochs = 500
|
36 |
+
|
37 |
+
config.input.pp = (
|
38 |
+
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
|
39 |
+
f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
|
40 |
+
f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
|
41 |
+
f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
|
42 |
+
f'value_range(-1, 1)|make_canonical|copy("labels","image")|keep("image")'
|
43 |
+
)
|
44 |
+
pp_eval = (
|
45 |
+
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
|
46 |
+
f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
|
47 |
+
f'value_range(-1, 1)|make_canonical|copy("labels","image")|keep("image", "image/id")'
|
48 |
+
)
|
49 |
+
|
50 |
+
config.log_training_steps = 50
|
51 |
+
config.ckpt_steps = 1000
|
52 |
+
config.keep_ckpt_steps = None
|
53 |
+
|
54 |
+
# Model section
|
55 |
+
config.model_name = 'proj.givt.vit'
|
56 |
+
config.model = mlc.ConfigDict()
|
57 |
+
config.model.input_size = (arg.res, arg.res)
|
58 |
+
config.model.patch_size = (arg.patch_size, arg.patch_size)
|
59 |
+
config.model.code_len = 256
|
60 |
+
config.model.width = 768
|
61 |
+
config.model.enc_depth = 6
|
62 |
+
config.model.dec_depth = 12
|
63 |
+
config.model.mlp_dim = 3072
|
64 |
+
config.model.num_heads = 12
|
65 |
+
config.model.codeword_dim = 32
|
66 |
+
config.model.code_dropout = 'none'
|
67 |
+
config.model.bottleneck_resize = True
|
68 |
+
config.model.scan = True
|
69 |
+
config.model.remat_policy = 'nothing_saveable'
|
70 |
+
|
71 |
+
config.rec_loss_fn = 'xent' # xent, l2
|
72 |
+
# values: (index in source image, number of classes)
|
73 |
+
config.model.inout_specs = {
|
74 |
+
'semantics': (0, 133 + 1), # +1 for void label
|
75 |
+
'instances': (1, 100), # COCO: actually 98 train/78 validation.
|
76 |
+
}
|
77 |
+
|
78 |
+
config.beta = 2.5e-4
|
79 |
+
config.beta_percept = 0.0
|
80 |
+
|
81 |
+
config.optax_name = 'scale_by_adam'
|
82 |
+
config.optax = dict(b2=0.95)
|
83 |
+
config.grad_clip_norm = 1.0
|
84 |
+
|
85 |
+
# FSDP training by default
|
86 |
+
config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
87 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
88 |
+
|
89 |
+
config.lr = 1e-3
|
90 |
+
config.wd = 1e-4
|
91 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=0.1)
|
92 |
+
config.grad_clip_norm = 1.0
|
93 |
+
|
94 |
+
# Evaluation section
|
95 |
+
config.evals = {}
|
96 |
+
config.evals.val = mlc.ConfigDict()
|
97 |
+
config.evals.val.type = 'mean'
|
98 |
+
config.evals.val.pred = 'validation'
|
99 |
+
config.evals.val.data = {**config.input.data}
|
100 |
+
config.evals.val.data.split = 'train[:4096]'
|
101 |
+
config.evals.val.pp_fn = pp_eval
|
102 |
+
config.evals.val.log_steps = 250
|
103 |
+
|
104 |
+
base = {
|
105 |
+
'type': 'proj.givt.coco_panoptic',
|
106 |
+
'pp_fn': pp_eval,
|
107 |
+
'log_steps': 5_000,
|
108 |
+
'pred': 'predict_panoptic',
|
109 |
+
# Filters objects that occupy less than 0.03^2 fraction of all pixels.
|
110 |
+
# 'pred_kw': {'min_fraction': 0.03 ** 2},
|
111 |
+
}
|
112 |
+
config.evals.coco_panoptic_train = dict(**base, data={'split': 'train[4096:8192]'})
|
113 |
+
config.evals.coco_panoptic_holdout = dict(**base, data={'split': 'train[:4096]'})
|
114 |
+
config.evals.coco_panoptic = dict(**base, data={'split': 'validation'})
|
115 |
+
|
116 |
+
config.evals.save_pred = dict(type='proj.givt.save_predictions')
|
117 |
+
config.evals.save_pred.pp_fn = pp_eval
|
118 |
+
config.evals.save_pred.log_steps = 100_000
|
119 |
+
config.evals.save_pred.pred = 'predict_panoptic'
|
120 |
+
config.evals.save_pred.data = {**config.input.data}
|
121 |
+
config.evals.save_pred.data.split = 'validation[:1024]'
|
122 |
+
config.evals.save_pred.outfile = 'inference.npz'
|
123 |
+
|
124 |
+
config.seed = 0
|
125 |
+
|
126 |
+
if arg.singlehost:
|
127 |
+
config.input.batch_size = 128
|
128 |
+
config.num_epochs = 100
|
129 |
+
elif arg.runlocal:
|
130 |
+
config.input.batch_size = 16
|
131 |
+
config.input.shuffle_buffer_size = 10
|
132 |
+
config.log_training_steps = 5
|
133 |
+
config.model.enc_depth = 1
|
134 |
+
config.model.dec_depth = 1
|
135 |
+
|
136 |
+
return config
|
big_vision/configs/proj/givt/vae_nyu_depth.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Train VAE on NYU depth data for GIVT-based UViM.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
import ml_collections as mlc
|
21 |
+
|
22 |
+
|
23 |
+
QUANTIZATION_BINS = 256
|
24 |
+
MIN_DEPTH = 0.001
|
25 |
+
MAX_DEPTH = 10.0
|
26 |
+
|
27 |
+
|
28 |
+
def get_config(arg='res=512,patch_size=16'):
|
29 |
+
"""Config for training label compression on NYU depth."""
|
30 |
+
arg = bvcc.parse_arg(arg, res=512, patch_size=16,
|
31 |
+
runlocal=False, singlehost=False)
|
32 |
+
config = mlc.ConfigDict()
|
33 |
+
|
34 |
+
config.input = {}
|
35 |
+
config.input.data = dict(name='nyu_depth_v2', split='train')
|
36 |
+
|
37 |
+
config.input.batch_size = 1024
|
38 |
+
config.input.shuffle_buffer_size = 25_000
|
39 |
+
|
40 |
+
config.total_epochs = 200
|
41 |
+
|
42 |
+
config.input.pp = (
|
43 |
+
f'decode|nyu_depth|'
|
44 |
+
f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
|
45 |
+
f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
|
46 |
+
f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
|
47 |
+
f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|'
|
48 |
+
f'value_range(-1, 1)|copy("labels", "image")|keep("image")'
|
49 |
+
)
|
50 |
+
pp_eval = (
|
51 |
+
f'decode|nyu_depth|nyu_eval_crop|'
|
52 |
+
f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
|
53 |
+
f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|'
|
54 |
+
f'value_range(-1, 1)|copy("labels", "image")|keep("image")'
|
55 |
+
)
|
56 |
+
pp_pred = (
|
57 |
+
f'decode|nyu_depth|nyu_eval_crop|copy("labels","ground_truth")|'
|
58 |
+
f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
|
59 |
+
f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|'
|
60 |
+
f'value_range(-1, 1)|copy("labels", "image")|'
|
61 |
+
f'keep("image", "ground_truth")'
|
62 |
+
)
|
63 |
+
|
64 |
+
config.log_training_steps = 50
|
65 |
+
config.ckpt_steps = 1000
|
66 |
+
config.keep_ckpt_steps = None
|
67 |
+
|
68 |
+
# Model section
|
69 |
+
config.min_depth = MIN_DEPTH
|
70 |
+
config.max_depth = MAX_DEPTH
|
71 |
+
config.model_name = 'proj.givt.vit'
|
72 |
+
config.model = mlc.ConfigDict()
|
73 |
+
config.model.input_size = (arg.res, arg.res)
|
74 |
+
config.model.patch_size = (arg.patch_size, arg.patch_size)
|
75 |
+
config.model.code_len = 256
|
76 |
+
config.model.width = 768
|
77 |
+
config.model.enc_depth = 6
|
78 |
+
config.model.dec_depth = 12
|
79 |
+
config.model.mlp_dim = 3072
|
80 |
+
config.model.num_heads = 12
|
81 |
+
config.model.codeword_dim = 16
|
82 |
+
config.model.code_dropout = 'none'
|
83 |
+
config.model.bottleneck_resize = True
|
84 |
+
config.model.scan = True
|
85 |
+
config.model.remat_policy = 'nothing_saveable'
|
86 |
+
config.model_init = ''
|
87 |
+
|
88 |
+
config.rec_loss_fn = 'xent' # xent, l2
|
89 |
+
config.mask_zero_target = True
|
90 |
+
# values: (index in source image, number of classes)
|
91 |
+
config.model.inout_specs = {
|
92 |
+
'depth': (0, QUANTIZATION_BINS),
|
93 |
+
}
|
94 |
+
|
95 |
+
config.beta = 2e-4
|
96 |
+
config.beta_percept = 0.0
|
97 |
+
|
98 |
+
# Optimizer section
|
99 |
+
config.optax_name = 'scale_by_adam'
|
100 |
+
config.optax = dict(b2=0.95)
|
101 |
+
|
102 |
+
# FSDP training by default
|
103 |
+
config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
104 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
105 |
+
|
106 |
+
config.lr = 1e-3
|
107 |
+
config.wd = 1e-4
|
108 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=0.1)
|
109 |
+
config.grad_clip_norm = 1.0
|
110 |
+
|
111 |
+
# Evaluation section
|
112 |
+
config.evals = {}
|
113 |
+
config.evals.val = mlc.ConfigDict()
|
114 |
+
config.evals.val.type = 'mean'
|
115 |
+
config.evals.val.pred = 'validation'
|
116 |
+
config.evals.val.data = {**config.input.data}
|
117 |
+
config.evals.val.data.split = 'validation'
|
118 |
+
config.evals.val.pp_fn = pp_eval
|
119 |
+
config.evals.val.log_steps = 250
|
120 |
+
|
121 |
+
base = {
|
122 |
+
'type': 'proj.givt.nyu_depth',
|
123 |
+
'data': {**config.input.data},
|
124 |
+
'pp_fn': pp_pred,
|
125 |
+
'pred': 'predict_depth',
|
126 |
+
'log_steps': 2000,
|
127 |
+
'min_depth': MIN_DEPTH,
|
128 |
+
'max_depth': MAX_DEPTH,
|
129 |
+
}
|
130 |
+
config.evals.nyu_depth_val = {**base}
|
131 |
+
config.evals.nyu_depth_val.data.split = 'validation'
|
132 |
+
|
133 |
+
# ### Uses a lot of memory
|
134 |
+
# config.evals.save_pred = dict(type='proj.givt.save_predictions')
|
135 |
+
# config.evals.save_pred.pp_fn = pp_eval
|
136 |
+
# config.evals.save_pred.log_steps = 100_000
|
137 |
+
# config.evals.save_pred.data = {**config.input.data}
|
138 |
+
# config.evals.save_pred.data.split = 'validation[:64]'
|
139 |
+
# config.evals.save_pred.batch_size = 64
|
140 |
+
# config.evals.save_pred.outfile = 'inference.npz'
|
141 |
+
|
142 |
+
config.eval_only = False
|
143 |
+
config.seed = 0
|
144 |
+
|
145 |
+
if arg.singlehost:
|
146 |
+
config.input.batch_size = 128
|
147 |
+
config.num_epochs = 50
|
148 |
+
elif arg.runlocal:
|
149 |
+
config.input.batch_size = 16
|
150 |
+
config.input.shuffle_buffer_size = 10
|
151 |
+
config.log_training_steps = 5
|
152 |
+
config.model.enc_depth = 1
|
153 |
+
config.model.dec_depth = 1
|
154 |
+
config.evals.val.data.split = 'validation[:16]'
|
155 |
+
config.evals.val.log_steps = 20
|
156 |
+
config.evals.nyu_depth_val.data.split = 'validation[:16]'
|
157 |
+
|
158 |
+
return config
|
big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Pre-training ViT on ILSVRC-2012 with GSAM in https://arxiv.org/abs/2203.08065
|
17 |
+
|
18 |
+
Run training of a B/32 model:
|
19 |
+
|
20 |
+
big_vision.trainers.proj.gsam.train \
|
21 |
+
--config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \
|
22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
|
23 |
+
|
24 |
+
"""
|
25 |
+
|
26 |
+
import big_vision.configs.common as bvcc
|
27 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
28 |
+
import ml_collections as mlc
|
29 |
+
|
30 |
+
def get_config(arg=None):
|
31 |
+
"""Config for training."""
|
32 |
+
arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False)
|
33 |
+
config = mlc.ConfigDict()
|
34 |
+
|
35 |
+
config.dataset = 'imagenet2012'
|
36 |
+
config.train_split = 'train[:99%]'
|
37 |
+
config.cache_raw = not arg.runlocal # Needs up to 120GB of RAM!
|
38 |
+
config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
|
39 |
+
config.num_classes = 1000
|
40 |
+
config.loss = 'sigmoid_xent'
|
41 |
+
config.batch_size = 4096
|
42 |
+
config.num_epochs = 300
|
43 |
+
|
44 |
+
pp_common = (
|
45 |
+
'|value_range(-1, 1)'
|
46 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
47 |
+
'|keep("image", "labels")'
|
48 |
+
)
|
49 |
+
config.pp_train = (
|
50 |
+
'decode_jpeg_and_inception_crop(224)|flip_lr|' +
|
51 |
+
pp_common.format(lbl='label')
|
52 |
+
)
|
53 |
+
pp = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
54 |
+
|
55 |
+
# Aggressive pre-fetching because our models here are small, so we not only
|
56 |
+
# can afford it, but we also need it for the smallest models to not be
|
57 |
+
# bottle-necked by the input pipeline. Play around with it for -L models tho.
|
58 |
+
config.prefetch_to_host = 8
|
59 |
+
config.prefetch_to_device = 4
|
60 |
+
|
61 |
+
config.log_training_steps = 50
|
62 |
+
config.checkpoint_steps = 1000
|
63 |
+
|
64 |
+
# Model section
|
65 |
+
config.model_name = 'vit'
|
66 |
+
config.model = dict(
|
67 |
+
variant=arg.variant,
|
68 |
+
rep_size=False,
|
69 |
+
pool_type='gap',
|
70 |
+
)
|
71 |
+
config.init_head_bias = -10.0
|
72 |
+
|
73 |
+
# Optimizer section
|
74 |
+
config.grad_clip_norm = 1.0
|
75 |
+
config.optax_name = 'scale_by_adam'
|
76 |
+
config.optax = dict(mu_dtype='float32')
|
77 |
+
# The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
|
78 |
+
# almost always behaves exactly like adam, but at a fraction of the memory
|
79 |
+
# cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
|
80 |
+
# good idea to try it when you are memory-bound!
|
81 |
+
# config.optax_name = 'big_vision.scale_by_adafactor'
|
82 |
+
# A good flag to play with when hitting instabilities, is the following:
|
83 |
+
# config.optax = dict(beta2_cap=0.95)
|
84 |
+
|
85 |
+
config.lr = 0.003
|
86 |
+
config.wd = 0.001 # default is 0.0001; paper used 0.3, effective wd=0.3*lr
|
87 |
+
config.schedule = dict(
|
88 |
+
warmup_steps=10_000,
|
89 |
+
decay_type='linear',
|
90 |
+
linear_end=0.01,
|
91 |
+
)
|
92 |
+
|
93 |
+
# GSAM settings.
|
94 |
+
# Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM.
|
95 |
+
config.gsam = dict(
|
96 |
+
rho_max=0.6,
|
97 |
+
rho_min=0.1,
|
98 |
+
alpha=0.6,
|
99 |
+
lr_max=config.get_ref('lr'),
|
100 |
+
lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
|
101 |
+
)
|
102 |
+
|
103 |
+
# Eval section
|
104 |
+
eval_common = dict(
|
105 |
+
type='classification',
|
106 |
+
dataset='imagenet2012',
|
107 |
+
pp_fn=pp.format(lbl='label'),
|
108 |
+
loss_name=config.loss,
|
109 |
+
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
|
110 |
+
)
|
111 |
+
config.evals = {}
|
112 |
+
config.evals.train = {**eval_common, 'split': 'train[:2%]'}
|
113 |
+
config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
|
114 |
+
config.evals.val = {**eval_common, 'split': 'validation'}
|
115 |
+
config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}
|
116 |
+
|
117 |
+
config.evals.real = {**eval_common}
|
118 |
+
config.evals.real.dataset = 'imagenet2012_real'
|
119 |
+
config.evals.real.split = 'validation'
|
120 |
+
config.evals.real.pp_fn = pp.format(lbl='real_label')
|
121 |
+
|
122 |
+
config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
|
123 |
+
config.fewshot.log_steps = 10_000
|
124 |
+
|
125 |
+
# Make a few things much smaller for quick local debugging testruns.
|
126 |
+
if arg.runlocal:
|
127 |
+
config.shuffle_buffer_size = 10
|
128 |
+
config.batch_size = 8
|
129 |
+
config.minival.split = 'train[:16]'
|
130 |
+
config.val.split = 'validation[:16]'
|
131 |
+
config.real.split = 'validation[:16]'
|
132 |
+
config.v2.split = 'test[:16]'
|
133 |
+
|
134 |
+
return config
|
big_vision/configs/proj/image_text/README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Image/text models
|
2 |
+
|
3 |
+
## LiT: Zero-Shot Transfer with Locked-image text Tuning
|
4 |
+
|
5 |
+
*by Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, Lucas Beyer*
|
6 |
+
|
7 |
+
https://arxiv.org/abs/2111.07991
|
8 |
+
|
9 |
+
```
|
10 |
+
@article{zhai2022lit,
|
11 |
+
title={LiT: Zero-Shot Transfer with Locked-image Text Tuning},
|
12 |
+
author={Zhai, Xiaohua and Wang, Xiao and Mustafa, Basil and Steiner, Andreas and Keysers, Daniel and Kolesnikov, Alexander and Beyer, Lucas},
|
13 |
+
journal={CVPR},
|
14 |
+
year={2022}
|
15 |
+
}
|
16 |
+
```
|
17 |
+
|
18 |
+
Model card:
|
19 |
+
https://github.com/google-research/vision_transformer/blob/main/model_cards/lit.md
|
20 |
+
|
21 |
+
Colabs:
|
22 |
+
|
23 |
+
- https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb
|
24 |
+
- https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb
|
25 |
+
|
26 |
+
### Results
|
27 |
+
|
28 |
+
| Model | Download link | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` |
|
29 |
+
| :--- | :---: | :---: | :---: | :---: | :--- |
|
30 |
+
| mixed_L16L | [link](https://storage.googleapis.com/vit_models/lit/LiT-L16L.npz) | 75.7 | 48.5 | 31.2 | `txt=bert_large,img=L/16` |
|
31 |
+
| mixed_B16B | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 72.1 | 49.4 | 31.1 | `txt=bert_base,img=B/16,img_head` |
|
32 |
+
| mixed_B16B_2 | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 73.9 | 51.5 | 31.8 | `txt=bert_base,img=B/16` |
|
33 |
+
| coco_B16B | [link](https://storage.googleapis.com/vit_models/lit/big_vision/coco_B16B/checkpoint.npz) | 20.7 | 47.2 | 32.1 | `txt=bert_base,img=B/16` |
|
34 |
+
|
35 |
+
The first three rows are the best available models trained on open source data,
|
36 |
+
originally published in the [`google-research/vision_transformer`] repository.
|
37 |
+
These models were re-evaluated with this codebase using the following commands:
|
38 |
+
|
39 |
+
```bash
|
40 |
+
big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16,img_head,init=gs://vit_models/lit/LiT-B16B.npz
|
41 |
+
|
42 |
+
big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16_2,init=gs://vit_models/lit/LiT-B16B_2.npz
|
43 |
+
|
44 |
+
big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_large,img=L/16,init=gs://vit_models/lit/LiT-L16L.npz
|
45 |
+
```
|
46 |
+
|
47 |
+
Unfortunately, the public multi-modal datasets [`CC12M`] and [`YFCC100M`] are
|
48 |
+
not yet available in [`tfds`], so these models cannot be reproduced with the
|
49 |
+
codebase. For this reason we provide the much weaker model `coco_B16B` in the
|
50 |
+
third row, which was trained on the small `tfds` dataset [`coco_captions`], and
|
51 |
+
can be used to verify correctness of the codebase
|
52 |
+
([workdir](https://console.cloud.google.com/storage/browser/vit_models/lit/big_vision/coco_B16B/)).
|
53 |
+
|
54 |
+
[`google-research/vision_transformer`]: https://github.com/google-research/vision_transformer
|
55 |
+
[`CC12M`]: https://arxiv.org/abs/2102.08981
|
56 |
+
[`YFCC100M`]: https://arxiv.org/abs/1503.01817
|
57 |
+
[`tfds`]: https://www.tensorflow.org/datasets/api_docs/python/tfds
|
58 |
+
[`coco_captions`]: https://www.tensorflow.org/datasets/catalog/coco_captions
|
59 |
+
|
60 |
+
|
61 |
+
### Changelog
|
62 |
+
|
63 |
+
- 2022-08-18: Added LiT-B16B_2 model that was trained for 60k steps
|
64 |
+
(LiT_B16B: 30k) without linear head on the image side (LiT_B16B: 768) and has
|
65 |
+
better performance.
|
big_vision/configs/proj/image_text/SigLIP_demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
big_vision/configs/proj/image_text/common.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Snippets and constants used a lot in image-text configs."""
|
16 |
+
|
17 |
+
import ml_collections
|
18 |
+
|
19 |
+
|
20 |
+
# pylint: disable=line-too-long
|
21 |
+
inits = {
|
22 |
+
# Downloaded & extracted from original repo:
|
23 |
+
# https://github.com/google-research/bert
|
24 |
+
'bert_base': ('base', 'gs://vit_models/lit/bert/uncased_L-12_H-768_A-12'),
|
25 |
+
'bert_large': ('large', 'gs://vit_models/lit/bert/uncased_L-uncased_L-24_H-1024_A-16'),
|
26 |
+
# Recommended "How to train your ViT..." checkpoints from
|
27 |
+
# https://github.com/google-research/vision_transformer#available-vit-models
|
28 |
+
'B/32': ('B/32', 'gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz'),
|
29 |
+
'B/16': ('B/16', 'gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz'),
|
30 |
+
'L/16': ('L/16', 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz'),
|
31 |
+
}
|
32 |
+
# pylint: enable=line-too-long
|
33 |
+
|
34 |
+
|
35 |
+
def _square875(sz):
|
36 |
+
return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)'
|
37 |
+
|
38 |
+
|
39 |
+
def _aspect75(sz):
|
40 |
+
return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)'
|
41 |
+
|
42 |
+
|
43 |
+
def _drop_no_real_label(f):
|
44 |
+
return len(f['real_label']) > 0
|
45 |
+
|
46 |
+
|
47 |
+
def _drop_no_imagenet(f):
|
48 |
+
return len(f['labels_imagenet']) > 0
|
49 |
+
|
50 |
+
|
51 |
+
DISCLF_DATASET_OVERRIDES = {
|
52 |
+
'imagenet2012': {'class_names': 'clip', 'split': 'validation'},
|
53 |
+
'imagenet2012_minival': {
|
54 |
+
'dataset_name': 'imagenet2012',
|
55 |
+
'class_names': 'clip',
|
56 |
+
'split': 'train[99%:]',
|
57 |
+
},
|
58 |
+
'imagenet2012_real': {
|
59 |
+
'split': 'validation',
|
60 |
+
'class_names': 'clip',
|
61 |
+
'class_names_dataset_name': 'imagenet2012',
|
62 |
+
'pp_img': lambda sz: (
|
63 |
+
_square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'), # pylint: disable=line-too-long
|
64 |
+
'pre_filter_fn': _drop_no_real_label,
|
65 |
+
},
|
66 |
+
'imagenet_v2': {'class_names': 'clip'},
|
67 |
+
'imagenet_a': {
|
68 |
+
'class_names': 'clip',
|
69 |
+
'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")',
|
70 |
+
},
|
71 |
+
'imagenet_r': {
|
72 |
+
'class_names': 'clip',
|
73 |
+
'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")',
|
74 |
+
},
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
def get_disclf(sz, *, pp_txt=None, dataset_names=('imagenet2012',), **kw):
|
79 |
+
"""Returns config for discriminative_classifier of specified datasets."""
|
80 |
+
config = ml_collections.ConfigDict(dict(
|
81 |
+
dataset_names=list(dataset_names),
|
82 |
+
type='proj.image_text.discriminative_classifier',
|
83 |
+
prefix='z/0shot/',
|
84 |
+
pp_img=_square875(sz),
|
85 |
+
dataset_overrides={},
|
86 |
+
cache_final=True,
|
87 |
+
**kw,
|
88 |
+
))
|
89 |
+
if pp_txt:
|
90 |
+
config.pp_txt = pp_txt
|
91 |
+
for name in dataset_names:
|
92 |
+
if name in DISCLF_DATASET_OVERRIDES:
|
93 |
+
config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]}
|
94 |
+
d = config.dataset_overrides[name]
|
95 |
+
if 'pp_img' in d and callable(d['pp_img']):
|
96 |
+
with d.ignore_type():
|
97 |
+
d['pp_img'] = d['pp_img'](sz)
|
98 |
+
return config
|
99 |
+
|
100 |
+
|
101 |
+
def get_coco(
|
102 |
+
*,
|
103 |
+
pp_img='resize(224)|value_range(-1, 1)',
|
104 |
+
pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)',
|
105 |
+
prefix='z/retr/coco_',
|
106 |
+
**kw):
|
107 |
+
"""Returns config for mscoco retrieval zero-shot.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
pp_img: Pre-processing string for "image" feature.
|
111 |
+
pp_txt: Pre-processing string for texts (expected to tokenize "texts" to
|
112 |
+
"labels").
|
113 |
+
prefix: Prefix to use for metrics.
|
114 |
+
**kw: Other config settings, most notably log_{steps,percent,...}.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
`ConfigDict` that can be used as a retrieval evaluator configuration.
|
118 |
+
"""
|
119 |
+
return ml_collections.ConfigDict({
|
120 |
+
'type': 'proj.image_text.retrieval',
|
121 |
+
'pp_txt': pp_txt,
|
122 |
+
'pp_img': pp_img,
|
123 |
+
'prefix': prefix,
|
124 |
+
'dataset': 'coco_captions',
|
125 |
+
'txt_name': ('captions', 'text'),
|
126 |
+
**kw,
|
127 |
+
})
|
big_vision/configs/proj/image_text/lit.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
big_vision/configs/proj/image_text/siglip_lit_coco.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""Minimal SigLIP (https://arxiv.org/abs/2303.15343) example.
|
17 |
+
|
18 |
+
Example training:
|
19 |
+
|
20 |
+
big_vision.trainers.proj.image_text.siglip \
|
21 |
+
--config big_vision/configs/proj/image_text/lit_coco.py:batch_size=512 \
|
22 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`
|
23 |
+
"""
|
24 |
+
|
25 |
+
import big_vision.configs.common as bvcc
|
26 |
+
from big_vision.configs.proj.image_text import common
|
27 |
+
from ml_collections import ConfigDict
|
28 |
+
|
29 |
+
|
30 |
+
def get_config(arg=None):
|
31 |
+
"""The base configuration."""
|
32 |
+
arg = bvcc.parse_arg(
|
33 |
+
arg, res=224, runlocal=False, token_len=16, txt='bert_base', img='B/16',
|
34 |
+
init='', img_head=False, batch_size=512)
|
35 |
+
img_name, img_init = common.inits[arg.img]
|
36 |
+
txt_name, txt_init = common.inits[arg.txt]
|
37 |
+
config = ConfigDict()
|
38 |
+
|
39 |
+
config.input = {}
|
40 |
+
config.input.data = dict(name='coco_captions', split='train')
|
41 |
+
config.input.batch_size = arg.batch_size if not arg.runlocal else 32
|
42 |
+
config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
|
43 |
+
|
44 |
+
config.total_steps = 5_000 if not arg.runlocal else 1
|
45 |
+
|
46 |
+
config.init_shapes = [(1, arg.res, arg.res, 3), (1, arg.token_len,)]
|
47 |
+
config.init_types = ['float32', 'int32']
|
48 |
+
|
49 |
+
if arg.init:
|
50 |
+
vocab_path = arg.init.rsplit('.', 1)[0] + '.txt'
|
51 |
+
else:
|
52 |
+
vocab_path = f'{txt_init}/vocab.txt'
|
53 |
+
tokenizer = lambda inkey: (
|
54 |
+
f'bert_tokenize(inkey="{inkey}", max_len={arg.token_len}, '
|
55 |
+
f'vocab_path="{vocab_path}")')
|
56 |
+
config.input.pp = (
|
57 |
+
f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)'
|
58 |
+
f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'
|
59 |
+
)
|
60 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text',
|
61 |
+
'proj.flaxformer.bert_ops', 'archive.randaug']
|
62 |
+
|
63 |
+
config.log_training_steps = 50
|
64 |
+
config.ckpt_steps = 1000
|
65 |
+
|
66 |
+
# Model section
|
67 |
+
config.model_name = 'proj.image_text.two_towers'
|
68 |
+
config.model_load = {}
|
69 |
+
if arg.init:
|
70 |
+
config.model_init = arg.init
|
71 |
+
else:
|
72 |
+
config.model_init = {'image': img_init, 'text': txt_init}
|
73 |
+
config.model_load['txt_load_kw'] = {'dont_load': ['head/kernel', 'head/bias']}
|
74 |
+
if not arg.img_head:
|
75 |
+
config.model_load['img_load_kw'] = {'dont_load': ['head/kernel', 'head/bias']}
|
76 |
+
config.model = ConfigDict()
|
77 |
+
config.model.image_model = 'vit'
|
78 |
+
config.model.text_model = 'proj.flaxformer.bert'
|
79 |
+
config.model.image = ConfigDict({
|
80 |
+
'variant': img_name,
|
81 |
+
'pool_type': 'tok',
|
82 |
+
'head_zeroinit': False,
|
83 |
+
})
|
84 |
+
config.model.text = ConfigDict({
|
85 |
+
'config': txt_name,
|
86 |
+
'head_zeroinit': False,
|
87 |
+
})
|
88 |
+
config.model.temperature_init = 10.0
|
89 |
+
dim = {'B': 768, 'L': 1024}[arg.img[0]]
|
90 |
+
config.model.out_dim = (dim if arg.img_head else None, dim) # (image_out_dim, text_out_dim)
|
91 |
+
config.model.bias_init = -2.71
|
92 |
+
|
93 |
+
if txt_name == 'base':
|
94 |
+
config.optax_name = 'scale_by_adam'
|
95 |
+
else:
|
96 |
+
config.optax_name = 'big_vision.scale_by_adafactor'
|
97 |
+
|
98 |
+
config.lr = 0.001
|
99 |
+
config.wd = 0.01
|
100 |
+
warmup_steps = max(int(0.03 * config.total_steps), 100)
|
101 |
+
config.schedule = [
|
102 |
+
('img/.*', None), # Freezes image tower.
|
103 |
+
('.*', dict(decay_type='cosine', warmup_steps=warmup_steps)),
|
104 |
+
]
|
105 |
+
|
106 |
+
config.grad_clip_norm = 1.0
|
107 |
+
|
108 |
+
config.evals = {}
|
109 |
+
config.evals.retrieval_coco = common.get_coco(
|
110 |
+
pp_img=f'resize({arg.res})|value_range(-1, 1)',
|
111 |
+
pp_txt=tokenizer('texts'),
|
112 |
+
log_steps=1000,
|
113 |
+
)
|
114 |
+
|
115 |
+
return config
|
big_vision/configs/proj/paligemma/README.md
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PaliGemma model README
|
2 |
+
|
3 |
+
PaliGemma is an open vision-language model (VLM) inspired by PaLI-3, built with
|
4 |
+
open components, such as
|
5 |
+
the [SigLIP vision model](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb)
|
6 |
+
and
|
7 |
+
the [Gemma language model](https://ai.google.dev/gemma).
|
8 |
+
PaliGemma is designed as a versatile model for transfer to a wide range of
|
9 |
+
vision-language tasks such as image and short video caption, visual question
|
10 |
+
answering, text reading, object detection and object segmentation. Together with
|
11 |
+
the pretrained and transfer checkpoints at multiple resolutions, we provide a
|
12 |
+
checkpoint transferred to a mixture of tasks that can be used for off-the-shelf
|
13 |
+
exploration.
|
14 |
+
|
15 |
+
## Quick Reference
|
16 |
+
|
17 |
+
This is the reference repository of the model, you may also want to check out the resources on
|
18 |
+
|
19 |
+
- [ArXiv](https://arxiv.org/abs/2407.07726): Technical report.
|
20 |
+
- [Kaggle](https://www.kaggle.com/models/google/paligemma):
|
21 |
+
All pre-trained / mix checkpoints and model card.
|
22 |
+
- [Kaggle-FT](https://www.kaggle.com/models/google/paligemma-ft):
|
23 |
+
All fine-tuned checkpoints and model card.
|
24 |
+
- [VertexAI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363):
|
25 |
+
Paligemma models on GCP.
|
26 |
+
- [Hugging Face](https://huggingface.co/google/paligemma-3b-pt-224):
|
27 |
+
PyTorch port of paligemma models.
|
28 |
+
- [Light finetuning colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb):
|
29 |
+
Lightweight colab for fine-tuning PaliGemma. It can be run on a single T4 GPU (16GB)
|
30 |
+
available on free Colab.
|
31 |
+
- [HuggingFace demo](https://hf.co/spaces/google/paligemma): live demo.
|
32 |
+
|
33 |
+
### Citation BibTeX
|
34 |
+
|
35 |
+
```
|
36 |
+
@article{beyer2024paligemma,
|
37 |
+
title={{PaliGemma: A versatile 3B VLM for transfer}},
|
38 |
+
author={Lucas Beyer and Andreas Steiner and André Susano Pinto and Alexander Kolesnikov and Xiao Wang and Daniel Salz and Maxim Neumann and Ibrahim Alabdulmohsin and Michael Tschannen and Emanuele Bugliarello and Thomas Unterthiner and Daniel Keysers and Skanda Koppula and Fangyu Liu and Adam Grycner and Alexey Gritsenko and Neil Houlsby and Manoj Kumar and Keran Rong and Julian Eisenschlos and Rishabh Kabra and Matthias Bauer and Matko Bošnjak and Xi Chen and Matthias Minderer and Paul Voigtlaender and Ioana Bica and Ivana Balazevic and Joan Puigcerver and Pinelopi Papalampidi and Olivier Henaff and Xi Xiong and Radu Soricut and Jeremiah Harmsen and Xiaohua Zhai},
|
39 |
+
year={2024},
|
40 |
+
journal={arXiv preprint arXiv:2407.07726}
|
41 |
+
}
|
42 |
+
```
|
43 |
+
|
44 |
+
## Model description
|
45 |
+
|
46 |
+
### Overview
|
47 |
+
|
48 |
+
PaliGemma-3B is Vision-Language model that was inspired by the PaLI-3 recipe.
|
49 |
+
It is built on SigLIP visual encoder (specifically, SigLIP-So400m/14) and the
|
50 |
+
Gemma 2B language model. PaliGemma takes as input one or more images,
|
51 |
+
which are turned into "soft tokens" by the SigLIP encoder, and input text
|
52 |
+
(codenamed the "prefix") that is tokenized by Gemma's tokenizer. The image
|
53 |
+
tokens and prefix tokens are concatenated (in this order) and passed to the
|
54 |
+
Gemma decoder with full block-attention, which then generates an output text
|
55 |
+
(the "suffix") auto-regressively with masked attention.
|
56 |
+
|
57 |
+

|
58 |
+
|
59 |
+
### Training stages
|
60 |
+
|
61 |
+
Similar to PaLI-3, PaliGemma's training consists of multiple stages:
|
62 |
+
|
63 |
+
- Stage 0: the unimodal pre-training. We use publicly available off-the-shelf
|
64 |
+
SigLIP and Gemma models which have been pre-trained unimodally by their
|
65 |
+
respective authors.
|
66 |
+
- Stage 1: multimodal pre-training. The combined PaliGemma model is now
|
67 |
+
pre-trained on a fully multimodal training dataset, this at a low resolution
|
68 |
+
of 224px² and prefix+suffix sequence length of 128 tokens. This results in
|
69 |
+
the first base model that we release.
|
70 |
+
- Stage 2: high-resolution pre-training. We continue pre-training of the
|
71 |
+
Stage 1 model at resolution 448px² with sequence length 512 tokens for a short
|
72 |
+
duration on the same multimodal training data, but re-weighted with more
|
73 |
+
emphasis on examples that make use of higher resolution or longer sequence
|
74 |
+
length. We repeat this once more at resolution 896px². This results in two
|
75 |
+
further "high res" base models that we also release.
|
76 |
+
- Stage 3: fine-tune. The base models are transferred to
|
77 |
+
specific tasks by fine-tuning. To facilitate further research and
|
78 |
+
reproducibility, we release checkpoints fine-tuned on most of the benchmarks
|
79 |
+
we evaluate on. We also provide a "mix" transfer model, fine-tuned on a wide
|
80 |
+
variety of data, for use in interactive demos.
|
81 |
+
|
82 |
+
Most of the code examples, use-cases, and code release are about Stage 3:
|
83 |
+
transferring to a task or dataset of interest to the user.
|
84 |
+
|
85 |
+
### Tokenizer
|
86 |
+
|
87 |
+
PaliGemma uses the Gemma tokenizer with 256'000 tokens, but we further extend
|
88 |
+
its vocabulary with 1024 entries that represent coordinates in normalized
|
89 |
+
image-space (\<loc0000>...\<loc1023>), and another with 128 entries
|
90 |
+
(\<seg000>...\<seg127>) that are codewords used by a lightweight
|
91 |
+
referring-expression segmentation vector-quantized variational auto-encoder
|
92 |
+
(VQ-VAE) with the architecture of [Ning et al. (2023)](https://arxiv.org/abs/2301.02229) and trained on OpenImages
|
93 |
+
as in PaLI-3. While the `big_vision` codebase is flexible enough to extend
|
94 |
+
tokenizers on-the-fly, we also provide a SentencePiece model file of the Gemma
|
95 |
+
tokenizer with these additional tokens baked in, for the convenience of
|
96 |
+
other codebases.
|
97 |
+
|
98 |
+
## Checkpoints
|
99 |
+
|
100 |
+
The PaliGemma models are released under the same open license as the Gemma
|
101 |
+
models, and hence require manual acknowledgement of the license terms on kaggle:
|
102 |
+
https://www.kaggle.com/models/google/paligemma. The reference checkpoints are
|
103 |
+
available on
|
104 |
+
[Kaggle](https://www.kaggle.com/models/google/paligemma),
|
105 |
+
[VertexAI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) and
|
106 |
+
[Hugging Face](https://huggingface.co/google/paligemma-3b-pt-224).
|
107 |
+
|
108 |
+
### Pretrained checkpoints
|
109 |
+
|
110 |
+
Use one of these checkpoints as initialization for fine-tuning:
|
111 |
+
|
112 |
+
- pt-224: Versatile pretrained model for tasks that do not require seeing
|
113 |
+
small details in the image.
|
114 |
+
Examples: natural image captioning and question-answering, detection and
|
115 |
+
segmentation of medium-large objects. This model was trained with
|
116 |
+
sequence length 128.
|
117 |
+
- pt-448: Versatile base model for mid/higher resolution tasks with access
|
118 |
+
to smaller details. Besides higher resolution, it has gotten more weight on
|
119 |
+
text reading, detection, and segmentation during its pre-training. Examples:
|
120 |
+
as above, plus detection, segmentation, text/diagram reading. This model was
|
121 |
+
trained with sequence length 512.
|
122 |
+
- pt-896: Further scaled-up version of pt-448, especially good at reading
|
123 |
+
very small texts as often found in documents and infographics. This model
|
124 |
+
was trained with sequence length 512.
|
125 |
+
|
126 |
+
Besides the reference float32 checkpoint (11GB), we further provide
|
127 |
+
bfloat16 and float16 variants of each, to reduce download and storage time.
|
128 |
+
These are good for inference and frozen transfers, but full fine-tuning
|
129 |
+
should happen in float32 or mixed precision.
|
130 |
+
|
131 |
+
### Mixture checkpoint
|
132 |
+
|
133 |
+
This checkpoint is trained on a mixture of all our transfer tasks,
|
134 |
+
with a balancing intended to make it "nice to use" out of the box for
|
135 |
+
predictions. This model is multilingual and should
|
136 |
+
understand prompts in various languages, although English
|
137 |
+
is still its "mother tongue".
|
138 |
+
Questions can be asked in a natural way (including asking for a caption or
|
139 |
+
reading the text), and detection and segmentation should still work with the
|
140 |
+
structured `detect {things}` and `segment {things}` prompts as in the base model.
|
141 |
+
|
142 |
+
- mix-224: Similarly to pt-224, this model is good at many natural image
|
143 |
+
tasks that do not require high resolution. Unlike the raw pre-trained model,
|
144 |
+
however, it can be interacted with more freely. For example, ask it to
|
145 |
+
"describe this image in great detail, please" or "How many coins do you see
|
146 |
+
in the picture?". This model was trained with sequence length 256.
|
147 |
+
- mix-448: As above, but it is better at tasks that require higher-resolution
|
148 |
+
input. For example, one could ask it "what is written in the "sum" field?",
|
149 |
+
to "describe this figure", or to "what is the GDP of France?" when shown an
|
150 |
+
infographic of countries' GDPs. This model was trained with
|
151 |
+
sequence length 512.
|
152 |
+
|
153 |
+
### Transfers results and checkpoints
|
154 |
+
|
155 |
+
We provide checkpoints transferred to most of the tasks we evaluated
|
156 |
+
transfer on, see the [kaggle page](https://www.kaggle.com/models/google/paligemma).
|
157 |
+
These are intended for use when a specialised model corresponding
|
158 |
+
to one of the tasks is needed, for academic research purposes only.
|
159 |
+
Depending on the task, they may require a specialised preprocessing format.
|
160 |
+
|
161 |
+
The transfer setup is reasonably unified, with the main factors of variation
|
162 |
+
being the training duration, learning-rate, and whether or not to use dropout
|
163 |
+
and label-smoothing. Details can be found in the corresponding config files or
|
164 |
+
in an upcoming tech report.
|
165 |
+
|
166 |
+
Importantly, none of these tasks or datasets are part of the pre-training data
|
167 |
+
mixture, and their images are explicitly removed from the web-scale
|
168 |
+
pretraining data.
|
169 |
+
|
170 |
+
#### Captioning
|
171 |
+
|
172 |
+
Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
|
173 |
+
-----------------------|----------------|--------|--------|--------
|
174 |
+
[COCO captions](https://cocodataset.org/#home) (train+restval) | CIDEr (val) | 141.92 | 144.60 |
|
175 |
+
[NoCaps](https://nocaps.org/) (Eval of COCO captions transfer) | CIDEr (val) | 121.72 | 123.58 |
|
176 |
+
[COCO-35L](https://arxiv.org/abs/2205.12522) (train) | CIDEr dev (en / avg-34 / avg) | 139.2 / 115.8 / 116.4 | 141.2 / 118.0 / 118.6 |
|
177 |
+
[XM3600](https://arxiv.org/abs/2205.12522) (Eval of COCO-35L transfer) | CIDEr test (en / avg-35 / avg) | 78.1 / 41.3 / 42.4 | 80.0 / 41.9 / 42.9 |
|
178 |
+
[TextCaps](https://textvqa.org/textcaps/) (train) | CIDEr (val) | 127.48 | 153.94 |
|
179 |
+
[SciCap](https://arxiv.org/abs/2110.11624) (first sentence, no subfigure) (train+val) | CIDEr / BLEU-4 (test) | 162.25 / 0.192 | 181.49 / 0.211 |
|
180 |
+
[Screen2words](https://arxiv.org/abs/2108.03353) (train+dev) | CIDEr (test) | 117.57 | 119.59 |
|
181 |
+
[Widget Captioning](https://arxiv.org/abs/2010.04295) (train+dev) | CIDEr (test) | 136.07 | 148.36 |
|
182 |
+
|
183 |
+
#### Question Answering
|
184 |
+
|
185 |
+
Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
|
186 |
+
-----------------------|----------------|--------|--------|--------
|
187 |
+
[VQAv2](https://visualqa.org/index.html) (train+validation) | Accuracy (Test server - std) | 83.19 | 85.64 |
|
188 |
+
[MMVP](https://arxiv.org/abs/2401.06209) (Eval of VQAv2 transfer) | Paired Accuracy | 47.33 | 45.33 |
|
189 |
+
[POPE](https://arxiv.org/abs/2305.10355) (Eval of VQAv2 transfer) | Accuracy (random / popular / adversarial) | 87.80 / 85.87 / 84.27 | 88.23 / 86.77 / 85.90 |
|
190 |
+
[Objaverse Multiview](https://arxiv.org/abs/2311.17851) (Eval of VQAv2 transfer) | Cosine Similarity (USEv4) | 62.7 | 62.8 |
|
191 |
+
[OKVQA](https://okvqa.allenai.org/) (train) | Accuracy (val) | 63.54 | 63.15 |
|
192 |
+
[A-OKVQA](https://allenai.org/project/a-okvqa/home) (MC) (train+val) | Accuracy (Test server) | 76.37 | 76.90 |
|
193 |
+
[A-OKVQA](https://allenai.org/project/a-okvqa/home) (DA) (train+val) | Accuracy (Test server) | 61.85 | 63.22 |
|
194 |
+
[GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html) (train_balanced+val_balanced) | Accuracy (testdev balanced) | 65.61 | 67.03 |
|
195 |
+
[xGQA](https://aclanthology.org/2022.findings-acl.196/) (Eval of GQA transfer) | Mean Accuracy (bn,de,en,id,ko,pt,ru,zh) | 58.37 | 59.07 |
|
196 |
+
[NLVR2](https://lil.nlp.cornell.edu/nlvr/) (train+dev) | Accuracy (test) | 90.02 | 88.93 |
|
197 |
+
[MaRVL](https://marvl-challenge.github.io/) (Eval of NLVR2 transfer) | Mean Accuracy (test) (id,sw,ta,tr,zh) | 80.57 | 76.78 |
|
198 |
+
[AI2D](https://allenai.org/data/diagrams) (train) | Accuracy (test) | 72.12 | 73.28 |
|
199 |
+
[ScienceQA](https://scienceqa.github.io/) (Img subset, no CoT) (train+val) | Accuracy (test) | 95.39 | 95.93 |
|
200 |
+
[RSVQA-LR](https://zenodo.org/records/6344334) (Non numeric) (train+val) | Mean Accuracy (test) | 92.65 | 93.11 |
|
201 |
+
[RSVQA-HR](https://zenodo.org/records/6344367) (Non numeric) (train+val) | Mean Accuracy (test/test2) | 92.61 / 90.58 | 92.79 / 90.54 |
|
202 |
+
[ChartQA](https://arxiv.org/abs/2203.10244) (human+aug)x(train+val) | Mean Relaxed Accuracy (test_human, test_aug) | 57.08 | 71.36 |
|
203 |
+
[VizWiz](https://vizwiz.org/tasks-and-datasets/vqa/) VQA (train+val) | Accuracy (Test server - std) | 73.7 | 75.52 |
|
204 |
+
[TallyQA](https://arxiv.org/abs/1810.12440) (train) | Accuracy (test_simple/test_complex) | 81.72 / 69.56 | 84.86 / 72.27 |
|
205 |
+
[OCR-VQA](https://ocr-vqa.github.io/) (train+val) | Accuracy (test) | 73.24 | 75.60 | 75.90
|
206 |
+
[TextVQA](https://textvqa.org/) (train+val) | Accuracy (Test server - std) | 55.47 | 73.15 | 76.48
|
207 |
+
[DocVQA](https://www.docvqa.org/) (train+val) | ANLS (Test server) | 43.74 | 78.02 | 84.77
|
208 |
+
[Infographic VQA](https://openaccess.thecvf.com/content/WACV2022/papers/Mathew_InfographicVQA_WACV_2022_paper.pdf) (train+val) | ANLS (Test server) | 28.46 | 40.47 | 47.75
|
209 |
+
[SceneText VQA](https://arxiv.org/abs/1905.13648) (train+val) | ANLS (Test server) | 63.29 | 81.82 | 84.40
|
210 |
+
|
211 |
+
#### Segmentation
|
212 |
+
|
213 |
+
Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
|
214 |
+
-----------------------|----------------|--------|--------|--------
|
215 |
+
[RefCOCO](https://arxiv.org/abs/1608.00272) (combined refcoco, refcoco+, refcocog excluding val and test images) | MIoU (validation) refcoco / refcoco+ / refcocog | 73.40 / 68.32 / 67.65 | 75.57 / 69.76 / 70.17 | 76.94 / 72.18 / 72.22
|
216 |
+
|
217 |
+
#### Video tasks (Caption/QA)
|
218 |
+
|
219 |
+
Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
|
220 |
+
-----------------------|----------------|--------|--------|--------
|
221 |
+
[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (Captioning) | CIDEr (test) | 70.54 |
|
222 |
+
[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (QA) | Accuracy (test) | 50.09 |
|
223 |
+
[ActivityNet](http://activity-net.org/) (Captioning)] | CIDEr (test) | 34.62 |
|
224 |
+
[ActivityNet](http://activity-net.org/) (QA) | Accuracy (test) | 50.78 |
|
225 |
+
[VATEX](https://eric-xw.github.io/vatex-website/about.html) (Captioning) | CIDEr (test) | 79.73 |
|
226 |
+
[MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) (QA) | Accuracy (test) | 60.22 |
|
227 |
+
|
228 |
+
#### Mix model (finetune on mixture of transfer tasks)
|
229 |
+
|
230 |
+
Benchmark | Metric (split) | mix-224 | mix-448
|
231 |
+
----------|----------------|---------|---------
|
232 |
+
[MMVP](https://arxiv.org/abs/2401.06209) | Paired Accuracy | 46.00 | 45.33
|
233 |
+
[POPE](https://arxiv.org/abs/2305.10355) | Accuracy (random / popular / adversarial) | 88.00 / 86.63 / 85.67 | 89.37 / 88.40 / 87.47
|
234 |
+
|
235 |
+
|
236 |
+
## How to run PaliGemma fine-tuning
|
237 |
+
|
238 |
+
To run PaliGemma fine-tuning, set up the `big_vision` repository by following the
|
239 |
+
main README file. Here we provide PaliGemma-specific instructions.
|
240 |
+
|
241 |
+
Checkpoints can be downloaded from Kaggle. You need to create an account and acknowledge checkpoint usage policy. You can then download any checkpoint:
|
242 |
+
|
243 |
+
```
|
244 |
+
export KAGGLE_USERNAME=
|
245 |
+
export KAGGLE_KEY=
|
246 |
+
|
247 |
+
# See https://www.kaggle.com/models/google/paligemma for a full list of models.
|
248 |
+
export MODEL_NAME=paligemma-3b-pt-224
|
249 |
+
export CKPT_FILE=paligemma-3b-pt-224.npz
|
250 |
+
|
251 |
+
mkdir ckpts/
|
252 |
+
cd ckpts/
|
253 |
+
|
254 |
+
curl -L -u $KAGGLE_USERNAME:$KAGGLE_KEY\
|
255 |
+
-o pt_224.npz \
|
256 |
+
https://www.kaggle.com/api/v1/models/google/paligemma/jax/$MODEL_NAME/1/download/$CKPT_FILE
|
257 |
+
```
|
258 |
+
|
259 |
+
As an example, we provide the `forkme.py` config that is based on the easily-adjustable jsonl data source:
|
260 |
+
|
261 |
+
```
|
262 |
+
BV_GEMMA_DIR=ckpts/ python -m big_vision.trainers.proj.paligemma.train --config big_vision/configs/proj/paligemma/transfers/forkme.py --workdir workdirs/`date '+%m-%d_%H%M'`
|
263 |
+
```
|
264 |
+
|
265 |
+
If you want to use TFDS-based data, check out other transfer configs. Remember to set `TFDS_DATA_DIR` to point to the folder with data (can be GCP data bucket).
|
266 |
+
|
267 |
+
|
268 |
+
## Model Development Contributions
|
269 |
+
|
270 |
+
See the [technical report](https://arxiv.org/abs/2407.07726)'s Appendix.
|
big_vision/configs/proj/paligemma/finetune_paligemma.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
big_vision/configs/proj/paligemma/paligemma.png
ADDED
![]() |
big_vision/configs/proj/paligemma/transfers/activitynet_cap.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to ActivityNet Video captioning.
|
17 |
+
|
18 |
+
IMPORTANT: This config is based on an unreleased version of DeepMind Video
|
19 |
+
Readers (DMVR). Users can either set up DMVR using the open source code from
|
20 |
+
GitHub (see below for details), or add their own data loader of choice.
|
21 |
+
"""
|
22 |
+
|
23 |
+
import big_vision.configs.common as bvcc
|
24 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
25 |
+
|
26 |
+
TEXT_LEN = 64
|
27 |
+
DATASET_NAME = 'activitynet_captions_mr'
|
28 |
+
# Numbers might need to be updated due to wipeout. Current from 2024-04-28
|
29 |
+
SPLIT_SIZE = {'train': 30545, 'valid': 14338, 'test': 13982}
|
30 |
+
|
31 |
+
|
32 |
+
def training_data(res, *, final_split, num_frames=8, stride=None):
|
33 |
+
"""Creates training data config.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
res: The requested image resolution (eg 224).
|
37 |
+
final_split: Train on all train+valid data.
|
38 |
+
num_frames: number of sampled frames per video.
|
39 |
+
stride: stride at which the frames are sampled.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
The ConfigDict for the input section.
|
43 |
+
"""
|
44 |
+
pp = '|'.join([
|
45 |
+
# prepare the frames by decoding, resizing, replicating, sampling:
|
46 |
+
f'video_decode({res})|video_replicate_img({num_frames},{num_frames})',
|
47 |
+
f'video_ensure_shape("image", {(num_frames, res, res, 3)})',
|
48 |
+
# pick one caption at random during training (there is actually just one!)
|
49 |
+
'strfmt("caption en", outkey="prefix")',
|
50 |
+
'video_choice(inkey="caption/string", outkey="suffix")',
|
51 |
+
combine_and_keep_train(TEXT_LEN),
|
52 |
+
])
|
53 |
+
|
54 |
+
c = bvcc.parse_arg('')
|
55 |
+
c.data = {}
|
56 |
+
splits = ['train', 'valid'] if final_split else ['train']
|
57 |
+
raise NotImplementedError('Please implement a video reader of choice!')
|
58 |
+
# For example DMVR https://github.com/google-deepmind/dmvr
|
59 |
+
# The reader should support the following arguments:
|
60 |
+
# - name: Name of the reader.
|
61 |
+
# - dataset_name: Name of the data set.
|
62 |
+
# - split: Data set split.
|
63 |
+
# - num_frames: Number of frames sampled from the video.
|
64 |
+
# - stride: Stride at which the video frames are sampled.
|
65 |
+
# - deterministic_fs: Whether to sample the frames starting at the first
|
66 |
+
# frame or whether an offest should be chosen at random (if there are more
|
67 |
+
# frames than num_frames * stride)
|
68 |
+
# - first_k_shards: Whether to only use the first k shards of the data
|
69 |
+
# (optional but useful for speeding up intermediate evaluations).
|
70 |
+
for split in splits:
|
71 |
+
c.data[split] = SPLIT_SIZE[split]
|
72 |
+
c[split] = {'pp': pp}
|
73 |
+
c[split].data = dict(
|
74 |
+
# PLEASE ADD YOUR READER HERE:
|
75 |
+
name='<add_your_data_loader_here>',
|
76 |
+
dataset_name=DATASET_NAME, split=split,
|
77 |
+
num_frames=num_frames, stride=stride,
|
78 |
+
deterministic_fs=False)
|
79 |
+
return c
|
80 |
+
|
81 |
+
|
82 |
+
def add_eval(c, res, num_frames=8, stride=None): # pylint: disable=unused-argument
|
83 |
+
"""Captioning evaluator."""
|
84 |
+
c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride)
|
85 |
+
|
86 |
+
pp = '|'.join([
|
87 |
+
f'video_decode({res})|video_replicate_img({num_frames},{num_frames})',
|
88 |
+
f'video_ensure_shape("image", {(num_frames, res, res, 3)})',
|
89 |
+
'strfmt("caption en", outkey="prefix")',
|
90 |
+
'strfmt("{example/video_id}[{segment_start}-{segment_end}]", outkey="image/id")',
|
91 |
+
'copy("caption/string", "captions")',
|
92 |
+
combine_and_keep_eval(TEXT_LEN, keep=('image/id', 'captions')),
|
93 |
+
])
|
94 |
+
|
95 |
+
for freq, name, split, first_k_shards, skip_first_eval in [
|
96 |
+
(1/8, 'minitrain', 'train', 2, False), # To gauge memorization.
|
97 |
+
(1/4, 'minival', 'valid', 2, False), # To monitor val progress.
|
98 |
+
(1, 'val', 'valid', None, False), # To tune hparams.
|
99 |
+
(1, 'eval', 'test', None, False), # final metric
|
100 |
+
]:
|
101 |
+
c.evals[f'{DATASET_NAME}/{name}'] = dict(
|
102 |
+
type='proj.paligemma.transfers.coco_caption',
|
103 |
+
pred='decode', pred_kw={'max_decode_len': TEXT_LEN},
|
104 |
+
data={**c_train.train.data, 'split': split,
|
105 |
+
'first_k_shards': first_k_shards,
|
106 |
+
'deterministic_fs': True},
|
107 |
+
log_percent=freq, tokenizer=TOKENIZER,
|
108 |
+
pp_fn=pp, skip_first=skip_first_eval)
|
109 |
+
|
110 |
+
|
111 |
+
def add_eval_pplx(c, res, num_frames=8, stride=None):
|
112 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
113 |
+
c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride)
|
114 |
+
|
115 |
+
for name, split, first_k_shards in [
|
116 |
+
('minitrain', 'train', 2), # To gauge memorization.
|
117 |
+
]:
|
118 |
+
c.evals[f'{DATASET_NAME}/{name}/pplx'] = dict(
|
119 |
+
type='proj.paligemma.perplexity', pred='logits',
|
120 |
+
key='text', shift_labels=True,
|
121 |
+
log_percent=1/8, # Not too cheap, do 10x per run.
|
122 |
+
data={**c_train.train.data, 'split': split,
|
123 |
+
'first_k_shards': first_k_shards,
|
124 |
+
'deterministic_fs': True},
|
125 |
+
pp_fn=c_train.train.pp,
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
def sweep_best(add, arg=None):
|
130 |
+
"""Train with best hyper-params."""
|
131 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
132 |
+
add(lr=1e-5, wd=1e-6, total_epochs=1, **bvcc.arg(freeze_vit=True, res=224, **c))
|
133 |
+
|
134 |
+
|
135 |
+
sweep = sweep_best
|
136 |
+
|
137 |
+
|
138 |
+
def get_config(arg=None):
|
139 |
+
"""Config for training."""
|
140 |
+
c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=30, res=224,
|
141 |
+
freeze_vit=False, freeze_llm=False, final_split=False)
|
142 |
+
|
143 |
+
c.input = training_data(
|
144 |
+
c.res, final_split=c.final_split,
|
145 |
+
num_frames=c.num_frames, stride=c.stride)
|
146 |
+
|
147 |
+
c.total_epochs = 3
|
148 |
+
c.input.batch_size = 128
|
149 |
+
c.optax_name = 'scale_by_adam'
|
150 |
+
c.optax = dict(b2=0.999)
|
151 |
+
c.lr = 3e-6
|
152 |
+
c.wd = 3e-7
|
153 |
+
c.grad_clip_norm = 1.0
|
154 |
+
c.label_smoothing = 0.0
|
155 |
+
|
156 |
+
# Learning-rate schedule.
|
157 |
+
sched = dict(decay_type='cosine', warmup_percent=0.05)
|
158 |
+
c.schedule = [
|
159 |
+
('img/.*', None if c.freeze_vit else sched),
|
160 |
+
('llm/.*', None if c.freeze_llm else sched),
|
161 |
+
]
|
162 |
+
|
163 |
+
# Add evaluators.
|
164 |
+
c.evals = {}
|
165 |
+
add_eval(c, c.res, c.num_frames, c.stride)
|
166 |
+
add_eval_pplx(c, c.res, c.num_frames, c.stride)
|
167 |
+
|
168 |
+
# Model section.
|
169 |
+
c.model_name = 'proj.paligemma.paligemma'
|
170 |
+
c.model = {}
|
171 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
172 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
|
173 |
+
c.model_init = f'pt_{c.res}'
|
174 |
+
|
175 |
+
# FSDP strategy.
|
176 |
+
c.mesh = [('data', -1)]
|
177 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
178 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
179 |
+
|
180 |
+
for split in c.input.data.keys():
|
181 |
+
c.input[split].shuffle_buffer_size = 10_000
|
182 |
+
c.log_training_steps = 50
|
183 |
+
c.ckpt_steps = 1_000
|
184 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops',
|
185 |
+
'proj.paligemma.video']
|
186 |
+
|
187 |
+
# Update configs for quicker local runs and avoid swapping.
|
188 |
+
if c.mode in ('runlocal', 'mock'):
|
189 |
+
for split in c.input.data.keys():
|
190 |
+
c.input[split].shuffle_buffer_size = None
|
191 |
+
for ev in c.evals.values():
|
192 |
+
ev.data.first_k_shards = 1
|
193 |
+
|
194 |
+
if c.mode == 'runlocal':
|
195 |
+
c.log_training_steps = 1
|
196 |
+
c.input.batch_size = 2
|
197 |
+
|
198 |
+
c.seed = 0
|
199 |
+
return c
|
200 |
+
|
201 |
+
|
202 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
203 |
+
m = ['training_loss']
|
204 |
+
for split in ('minitrain', 'minival', 'val', 'eval'):
|
205 |
+
m.append(('epoch', f'{DATASET_NAME}/{split}/cider'))
|
206 |
+
for split in ('minitrain', 'minival'):
|
207 |
+
m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg'))
|
208 |
+
return m
|
209 |
+
|
big_vision/configs/proj/paligemma/transfers/activitynet_qa.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to ActivityNet Video QA.
|
17 |
+
|
18 |
+
IMPORTANT: This config is based on an unreleased version of DeepMind Video
|
19 |
+
Readers (DMVR). Users can either set up DMVR using the open source code from
|
20 |
+
GitHub (see below for details), or add their own data loader of choice.
|
21 |
+
"""
|
22 |
+
|
23 |
+
import big_vision.configs.common as bvcc
|
24 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
25 |
+
|
26 |
+
TEXT_LEN = 64
|
27 |
+
DATASET_NAME = 'activitynet_qa'
|
28 |
+
# Numbers might need to be updated due to wipeout. Current from 2024-04-28
|
29 |
+
SPLIT_SIZE = {'train': 27610, 'valid': 15760, 'test': 6900}
|
30 |
+
|
31 |
+
|
32 |
+
def training_data(res, *, final_split, num_frames, stride):
|
33 |
+
"""Creates training data config.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
res: The requested image resolution (eg 224).
|
37 |
+
final_split: Train on all train+valid data.
|
38 |
+
num_frames: number of sampled frames per video.
|
39 |
+
stride: stride at which the frames are sampled.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
The ConfigDict for the input section.
|
43 |
+
"""
|
44 |
+
pp = '|'.join([
|
45 |
+
# prepare the frames by decoding, resizing, replicating, sampling:
|
46 |
+
f'video_decode({res})|video_replicate_img({num_frames},{num_frames})',
|
47 |
+
f'video_ensure_shape("image", {(num_frames, res, res, 3)})',
|
48 |
+
# only one question/answer per example.
|
49 |
+
'reshape([], key="question")|reshape([], key="answer")',
|
50 |
+
'strfmt("answer en {question}", outkey="prefix")',
|
51 |
+
'copy("answer", "suffix")',
|
52 |
+
combine_and_keep_train(TEXT_LEN),
|
53 |
+
])
|
54 |
+
|
55 |
+
c = bvcc.parse_arg('')
|
56 |
+
c.data = {}
|
57 |
+
splits = ['train', 'valid'] if final_split else ['train']
|
58 |
+
raise NotImplementedError('Please implement a video reader of choice!')
|
59 |
+
# For example DMVR https://github.com/google-deepmind/dmvr
|
60 |
+
# The reader should support the following arguments:
|
61 |
+
# - name: Name of the reader.
|
62 |
+
# - dataset_name: Name of the data set.
|
63 |
+
# - split: Data set split.
|
64 |
+
# - num_frames: Number of frames sampled from the video.
|
65 |
+
# - stride: Stride at which the video frames are sampled.
|
66 |
+
# - deterministic_fs: Whether to sample the frames starting at the first
|
67 |
+
# frame or whether an offest should be chosen at random (if there are more
|
68 |
+
# frames than num_frames * stride)
|
69 |
+
# - first_k_shards: Whether to only use the first k shards of the data
|
70 |
+
# (optional but useful for speeding up intermediate evaluations).
|
71 |
+
for split in splits:
|
72 |
+
c.data[split] = SPLIT_SIZE[split]
|
73 |
+
c[split] = {'pp': pp}
|
74 |
+
c[split].data = dict(
|
75 |
+
# PLEASE ADD YOUR READER HERE:
|
76 |
+
name='<add_your_data_loader_here>',
|
77 |
+
dataset_name=DATASET_NAME, split=split,
|
78 |
+
num_frames=num_frames, stride=stride,
|
79 |
+
deterministic_fs=False)
|
80 |
+
return c
|
81 |
+
|
82 |
+
|
83 |
+
def add_eval(c, res, num_frames, stride): # pylint: disable=unused-argument
|
84 |
+
"""QA evaluator."""
|
85 |
+
c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride)
|
86 |
+
|
87 |
+
pp = '|'.join([
|
88 |
+
# prepare the frames by decoding, resizing, replicating, sampling:
|
89 |
+
f'video_decode({res})|video_replicate_img({num_frames},{num_frames})',
|
90 |
+
f'video_ensure_shape("image", {(num_frames, res, res, 3)})',
|
91 |
+
# only one question/answer per example.
|
92 |
+
'reshape([], key="question")|reshape([], key="answer")',
|
93 |
+
'strfmt("answer en {question}", outkey="prefix")',
|
94 |
+
'strfmt("{id}#{example/video_id}: {question}", "question_id")',
|
95 |
+
combine_and_keep_eval(TEXT_LEN, keep=('question_id', 'answer')),
|
96 |
+
])
|
97 |
+
|
98 |
+
for freq, name, split, first_k_shards, skip_first_eval in [
|
99 |
+
(1/8, 'minitrain', 'train', 2, False), # To gauge memorization.
|
100 |
+
(1/4, 'minival', 'valid', 2, False), # To monitor val progress.
|
101 |
+
(1, 'val', 'valid', None, True), # To tune hparams.
|
102 |
+
(1, 'eval', 'test', None, True), # final metric
|
103 |
+
]:
|
104 |
+
c.evals[f'activitynet_qa/{name}'] = dict(
|
105 |
+
type='proj.paligemma.transfers.vqa',
|
106 |
+
pred='decode', pred_kw={'max_decode_len': TEXT_LEN},
|
107 |
+
data={**c_train.train.data, 'split': split,
|
108 |
+
'first_k_shards': first_k_shards,
|
109 |
+
'deterministic_fs': True},
|
110 |
+
log_percent=freq, tokenizer=TOKENIZER,
|
111 |
+
pp_fn=pp, skip_first=skip_first_eval)
|
112 |
+
|
113 |
+
|
114 |
+
def add_eval_pplx(c, res, num_frames, stride):
|
115 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
116 |
+
c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride)
|
117 |
+
|
118 |
+
for name, split, first_k_shards in [
|
119 |
+
('minitrain', 'train', 2), # To gauge memorization.
|
120 |
+
('minival', 'valid', 2),
|
121 |
+
]:
|
122 |
+
c.evals[f'activitynet_qa/{name}/pplx'] = dict(
|
123 |
+
type='proj.paligemma.perplexity', pred='logits',
|
124 |
+
key='text', shift_labels=True,
|
125 |
+
log_percent=1/8, # Not too cheap, do 10x per run.
|
126 |
+
data={**c_train.train.data, 'split': split,
|
127 |
+
'first_k_shards': first_k_shards,
|
128 |
+
'deterministic_fs': True},
|
129 |
+
pp_fn=c_train.train.pp,
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def sweep_best(add, arg=None):
|
134 |
+
"""Train with best hyper-params."""
|
135 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
136 |
+
add(lr=1e-5, wd=1e-6, total_epochs=1, **bvcc.arg(num_frames=16, stride=70, res=224, **c))
|
137 |
+
|
138 |
+
|
139 |
+
sweep = sweep_best
|
140 |
+
|
141 |
+
|
142 |
+
def get_config(arg=None):
|
143 |
+
"""Config for training."""
|
144 |
+
c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=70, res=224,
|
145 |
+
freeze_vit=False, freeze_llm=False, final_split=False)
|
146 |
+
|
147 |
+
c.input = training_data(
|
148 |
+
c.res, final_split=c.final_split,
|
149 |
+
num_frames=c.num_frames, stride=c.stride)
|
150 |
+
|
151 |
+
c.total_epochs = 3
|
152 |
+
c.input.batch_size = 128
|
153 |
+
c.optax_name = 'scale_by_adam'
|
154 |
+
c.optax = dict(b2=0.999)
|
155 |
+
c.lr = 1e-5
|
156 |
+
c.wd = 1e-6
|
157 |
+
c.grad_clip_norm = 1.0
|
158 |
+
c.label_smoothing = 0.0
|
159 |
+
|
160 |
+
# Learning-rate schedule.
|
161 |
+
sched = dict(decay_type='cosine', warmup_percent=0.05)
|
162 |
+
c.schedule = [
|
163 |
+
('img/.*', None if c.freeze_vit else sched),
|
164 |
+
('llm/.*', None if c.freeze_llm else sched),
|
165 |
+
]
|
166 |
+
|
167 |
+
# Add evaluators.
|
168 |
+
c.evals = {}
|
169 |
+
add_eval(c, c.res, c.num_frames, c.stride)
|
170 |
+
add_eval_pplx(c, c.res, c.num_frames, c.stride)
|
171 |
+
|
172 |
+
# Model section.
|
173 |
+
c.model_name = 'proj.paligemma.paligemma'
|
174 |
+
c.model = {}
|
175 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
176 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
|
177 |
+
c.model_init = f'pt_{c.res}'
|
178 |
+
|
179 |
+
# FSDP strategy.
|
180 |
+
c.mesh = [('data', -1)]
|
181 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
182 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
183 |
+
|
184 |
+
for split in c.input.data.keys():
|
185 |
+
c.input[split].shuffle_buffer_size = 10_000
|
186 |
+
c.log_training_steps = 50
|
187 |
+
c.ckpt_steps = 1_000
|
188 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops',
|
189 |
+
'proj.paligemma.video']
|
190 |
+
|
191 |
+
# Update configs for quicker local runs and avoid swapping.
|
192 |
+
if c.mode in ('runlocal', 'mock'):
|
193 |
+
for split in c.input.data.keys():
|
194 |
+
c.input[split].shuffle_buffer_size = None
|
195 |
+
for ev in c.evals.values():
|
196 |
+
ev.data.first_k_shards = 1
|
197 |
+
|
198 |
+
if c.mode == 'runlocal':
|
199 |
+
c.log_training_steps = 1
|
200 |
+
c.input.batch_size = 2
|
201 |
+
|
202 |
+
c.seed = 0
|
203 |
+
return c
|
204 |
+
|
205 |
+
|
206 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
207 |
+
m = ['training_loss']
|
208 |
+
for split in ('minitrain', 'minival', 'val', 'eval'):
|
209 |
+
m.append(('epoch', f'{DATASET_NAME}/{split}/acc'))
|
210 |
+
for split in ('minitrain', 'minival'):
|
211 |
+
m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg'))
|
212 |
+
return m
|
213 |
+
|
big_vision/configs/proj/paligemma/transfers/ai2d.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to AI2D.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
21 |
+
|
22 |
+
PREFIX = 'answer en '
|
23 |
+
PROMPT = 'choose from:'
|
24 |
+
PROMPT_SEP = ' \t '
|
25 |
+
|
26 |
+
|
27 |
+
def training_data(res, final_split, text_len=128):
|
28 |
+
"""Creates training data config.
|
29 |
+
|
30 |
+
See (internal link)
|
31 |
+
You can add more arguments beside `res`, but give them good defaults.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
res: The requested image resolution (eg 224).
|
35 |
+
final_split: whether to use all train data.
|
36 |
+
text_len: sequence length
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
The ConfigDict for the input section.
|
40 |
+
"""
|
41 |
+
c = bvcc.parse_arg('') # Just make a configdict without extra import.
|
42 |
+
c.data = dict(
|
43 |
+
name='ai2d',
|
44 |
+
# 12k training examples.
|
45 |
+
split='train' if final_split else 'train[:-1024]',
|
46 |
+
)
|
47 |
+
c.pp = '|'.join([
|
48 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
49 |
+
f'strjoin("{PROMPT_SEP}", inkey="possible_answers", outkey="ansstr")',
|
50 |
+
f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
|
51 |
+
'copy(inkey="answer", outkey="suffix")',
|
52 |
+
combine_and_keep_train(text_len),
|
53 |
+
])
|
54 |
+
return c
|
55 |
+
|
56 |
+
|
57 |
+
def add_eval(c, res, text_len=128, **kw):
|
58 |
+
"""AI2D evaluators."""
|
59 |
+
pp = '|'.join([
|
60 |
+
f'decode|resize({res})|value_range(-1, 1)',
|
61 |
+
f'strjoin("{PROMPT_SEP}", inkey="possible_answers", outkey="ansstr")',
|
62 |
+
f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
|
63 |
+
'copy(inkey="id",outkey="question_id")',
|
64 |
+
combine_and_keep_eval(text_len, keep=('answer', 'question_id')),
|
65 |
+
])
|
66 |
+
|
67 |
+
for name, split in [
|
68 |
+
('minitrain', 'train[:1024]'), # To gauge memorization.
|
69 |
+
('minival', 'train[-1024:]'), # To tune hparams.
|
70 |
+
('eval', 'test'), # To compute final publishable scores.
|
71 |
+
]:
|
72 |
+
c.evals[f'ai2d/{name}'] = dict(
|
73 |
+
type='proj.paligemma.transfers.vqa',
|
74 |
+
pred='decode', pred_kw={'max_decode_len': text_len},
|
75 |
+
outfile=f'{{workdir}}/ai2d_{name}.json',
|
76 |
+
to_lower=False, # Model sees options in prompt and can match the case.
|
77 |
+
data={**training_data(res, True, text_len).data, 'split': split},
|
78 |
+
log_percent=1/8, tokenizer=TOKENIZER, pp_fn=pp)
|
79 |
+
c.evals[f'ai2d/{name}'].update(kw)
|
80 |
+
|
81 |
+
|
82 |
+
def add_eval_pplx(c, res, text_len=128):
|
83 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
84 |
+
c_train = training_data(res, True, text_len) # Use mostly same settings as training.
|
85 |
+
|
86 |
+
for name, split in [
|
87 |
+
('minitrain', 'train[:1024]'), # To gauge memorization.
|
88 |
+
('minival', 'train[-1024:]'), # To tune hparams.
|
89 |
+
('eval', 'test'), # To compute final publishable scores.
|
90 |
+
]:
|
91 |
+
c.evals[f'ai2d/{name}/pplx'] = dict(
|
92 |
+
type='proj.paligemma.perplexity', pred='logits',
|
93 |
+
key='text', shift_labels=True,
|
94 |
+
log_percent=1/8,
|
95 |
+
data={**c_train.data, 'split': split},
|
96 |
+
pp_fn=c_train.pp,
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
def sweep_best(add, arg=None):
|
101 |
+
"""Train with best hyper-params."""
|
102 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
103 |
+
add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=224, **c))
|
104 |
+
add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=448, **c))
|
105 |
+
# 896 was not better than 448 ((internal link)).
|
106 |
+
|
107 |
+
|
108 |
+
sweep = sweep_best # Choose which sweep to run.
|
109 |
+
|
110 |
+
|
111 |
+
def get_config(arg=None):
|
112 |
+
"""Config for training."""
|
113 |
+
c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)
|
114 |
+
|
115 |
+
c.input = training_data(c.res, final_split=c.final_split)
|
116 |
+
|
117 |
+
# Instead of epochs, you can also use `total_examples` or `total_steps`.
|
118 |
+
c.total_epochs = 10
|
119 |
+
c.input.batch_size = 256
|
120 |
+
c.optax_name = 'scale_by_adam'
|
121 |
+
c.optax = dict(b2=0.999)
|
122 |
+
c.lr = 1e-5
|
123 |
+
c.wd = 1e-5 * 0.1
|
124 |
+
c.grad_clip_norm = 1.0
|
125 |
+
c.label_smoothing = 0.0
|
126 |
+
c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
|
127 |
+
|
128 |
+
# Add evaluators.
|
129 |
+
c.evals = {}
|
130 |
+
add_eval(c, c.res, batch_size=1024)
|
131 |
+
add_eval_pplx(c, c.res)
|
132 |
+
|
133 |
+
# Model section.
|
134 |
+
c.model_name = 'proj.paligemma.paligemma'
|
135 |
+
c.model = {}
|
136 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
137 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
|
138 |
+
c.model_init = f'pt_{c.res}'
|
139 |
+
|
140 |
+
# FSDP strategy.
|
141 |
+
c.mesh = [('data', -1)]
|
142 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
143 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
144 |
+
|
145 |
+
# These probably do not need any change/tuning
|
146 |
+
c.input.shuffle_buffer_size = 50_000
|
147 |
+
c.log_training_steps = 50
|
148 |
+
c.ckpt_steps = 1_000
|
149 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
|
150 |
+
|
151 |
+
# Update configs for quicker local runs and avoid swapping.
|
152 |
+
if c.mode in ('runlocal', 'mock'):
|
153 |
+
c.input.shuffle_buffer_size = None
|
154 |
+
for ev in c.evals.values():
|
155 |
+
ev.data.split = ev.data.split.split('[')[0] + '[:16]'
|
156 |
+
|
157 |
+
if c.mode == 'runlocal':
|
158 |
+
c.log_training_steps = 1
|
159 |
+
c.input.batch_size = 2
|
160 |
+
|
161 |
+
c.seed = 0
|
162 |
+
return c
|
163 |
+
|
164 |
+
|
165 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
166 |
+
m = ['training_loss']
|
167 |
+
for split in ('eval', 'minival', 'minitrain'):
|
168 |
+
m.append(f'ai2d/{split}/pplx/avg')
|
169 |
+
m.append(f'ai2d/{split}/acc')
|
170 |
+
return m
|
big_vision/configs/proj/paligemma/transfers/aokvqa_da.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to A-OK-VQA using Direct Answer mode.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
21 |
+
|
22 |
+
|
23 |
+
def training_data(res, final_split, text_len=32):
|
24 |
+
"""Creates training data config.
|
25 |
+
|
26 |
+
See (internal link)
|
27 |
+
You can add more arguments beside `res`, but give them good defaults.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
res: The requested image resolution (eg 224).
|
31 |
+
final_split: Whether to use train and validation data.
|
32 |
+
text_len: sequence length
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
The ConfigDict for the input section.
|
36 |
+
"""
|
37 |
+
c = bvcc.parse_arg('') # Just make a configdict without extra import.
|
38 |
+
c.data = dict(
|
39 |
+
name='aokvqa',
|
40 |
+
split='train + val' if final_split else 'train',
|
41 |
+
)
|
42 |
+
c.pp = '|'.join([
|
43 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
44 |
+
'strfmt("answer en {question}", outkey="prefix")',
|
45 |
+
'choice_no_replacement(inkey="direct_answers", outkey="suffix")',
|
46 |
+
combine_and_keep_train(text_len),
|
47 |
+
])
|
48 |
+
return c
|
49 |
+
|
50 |
+
|
51 |
+
def add_eval(c, res, text_len=32, **kw):
|
52 |
+
"""We can use the normal VQA evaluators."""
|
53 |
+
pp = '|'.join([
|
54 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
55 |
+
'strfmt("answer en {question}", outkey="prefix")',
|
56 |
+
'copy(inkey="direct_answers", outkey="answers")',
|
57 |
+
combine_and_keep_eval(text_len, keep=('answers', 'question_id')),
|
58 |
+
])
|
59 |
+
|
60 |
+
for freq, name, split in [
|
61 |
+
(1/4, 'minitrain', 'train[:5%]'), # To gauge memorization.
|
62 |
+
(1/4, 'eval', 'val'), # To tune hparams.
|
63 |
+
(1.0, 'test', 'test'), # To compute final predictions.
|
64 |
+
]:
|
65 |
+
c.evals[f'aokvqa_da/{name}'] = dict(
|
66 |
+
type='proj.paligemma.transfers.vqa',
|
67 |
+
pred='decode', pred_kw={'max_decode_len': text_len},
|
68 |
+
outfile=f'{{workdir}}/aokvqa_da_{name}.json',
|
69 |
+
data={**training_data(res, True, text_len).data, 'split': split},
|
70 |
+
log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
|
71 |
+
c.evals[f'aokvqa/{name}'].update(kw)
|
72 |
+
|
73 |
+
|
74 |
+
def add_eval_pplx(c, res, text_len=32):
|
75 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
76 |
+
c_train = training_data(res, True, text_len) # Use mostly same settings as training.
|
77 |
+
|
78 |
+
for name, split in [
|
79 |
+
('minitrain', 'train[:5%]'), # To gauge memorization.
|
80 |
+
('eval', 'val'), # To tune hparams.
|
81 |
+
]:
|
82 |
+
c.evals[f'aokvqa_da/{name}/pplx'] = dict(
|
83 |
+
type='proj.paligemma.perplexity', pred='logits',
|
84 |
+
key='text', shift_labels=True,
|
85 |
+
log_percent=0.05, # Eval ~20x per run; it's cheap.
|
86 |
+
data={**c_train.data, 'split': split},
|
87 |
+
pp_fn=c_train.pp,
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
def sweep_best(add, arg=None):
|
92 |
+
"""Train with best hyper-params."""
|
93 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
94 |
+
add(lr=5e-6, wd=0.0, **bvcc.arg(res=224, **c))
|
95 |
+
add(lr=5e-6, wd=0.0, **bvcc.arg(res=448, **c))
|
96 |
+
# not better: add(lr=5e-6, wd=0.0, **bvcc.arg(res=896, **c))
|
97 |
+
|
98 |
+
|
99 |
+
sweep = sweep_best # Choose which sweep to run.
|
100 |
+
|
101 |
+
|
102 |
+
def get_config(arg=None):
|
103 |
+
"""Config for training."""
|
104 |
+
c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)
|
105 |
+
|
106 |
+
c.input = training_data(c.res, c.final_split)
|
107 |
+
|
108 |
+
# Instead of epochs, you can also use `total_examples` or `total_steps`.
|
109 |
+
c.total_epochs = 10
|
110 |
+
c.input.batch_size = 128
|
111 |
+
c.optax_name = 'scale_by_adam'
|
112 |
+
c.optax = dict(b2=0.999)
|
113 |
+
c.lr = 5e-6
|
114 |
+
c.wd = 0.0
|
115 |
+
c.grad_clip_norm = 1.0
|
116 |
+
c.label_smoothing = 0.0
|
117 |
+
c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
|
118 |
+
|
119 |
+
# Add evaluators.
|
120 |
+
c.evals = {}
|
121 |
+
add_eval(c, c.res, batch_size=256)
|
122 |
+
add_eval_pplx(c, c.res)
|
123 |
+
|
124 |
+
# Model section.
|
125 |
+
c.model_name = 'proj.paligemma.paligemma'
|
126 |
+
c.model = {}
|
127 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
128 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
|
129 |
+
c.model_init = f'pt_{c.res}'
|
130 |
+
|
131 |
+
# FSDP strategy.
|
132 |
+
c.mesh = [('data', -1)]
|
133 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
134 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
135 |
+
|
136 |
+
# These probably do not need any change/tuning
|
137 |
+
c.input.shuffle_buffer_size = 50_000
|
138 |
+
c.log_training_steps = 50
|
139 |
+
c.ckpt_steps = 1_000
|
140 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
|
141 |
+
|
142 |
+
# Update configs for quicker local runs and avoid swapping.
|
143 |
+
if c.mode in ('runlocal', 'mock'):
|
144 |
+
c.input.shuffle_buffer_size = None
|
145 |
+
for ev in c.evals.values():
|
146 |
+
ev.data.split = ev.data.split.split('[')[0] + '[:16]'
|
147 |
+
|
148 |
+
if c.mode == 'runlocal':
|
149 |
+
c.log_training_steps = 1
|
150 |
+
c.input.batch_size = 2
|
151 |
+
|
152 |
+
c.seed = 0
|
153 |
+
return c
|
154 |
+
|
155 |
+
|
156 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
157 |
+
m = ['training_loss']
|
158 |
+
for split in ('eval', 'minival', 'minitrain'):
|
159 |
+
m.append(f'aokvqa/{split}/pplx/avg')
|
160 |
+
m.append(f'aokvqa/{split}/acc')
|
161 |
+
return m
|
big_vision/configs/proj/paligemma/transfers/aokvqa_mc.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to A-OK-VQA using multiple choice answers.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
21 |
+
|
22 |
+
|
23 |
+
PREFIX = 'answer en '
|
24 |
+
PROMPT = 'choose from:'
|
25 |
+
PROMPT_SEP = ' \t '
|
26 |
+
|
27 |
+
|
28 |
+
def training_data(res, final_split, text_len=128):
|
29 |
+
"""Creates training data config.
|
30 |
+
|
31 |
+
See (internal link)
|
32 |
+
You can add more arguments beside `res`, but give them good defaults.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
res: The requested image resolution (eg 224).
|
36 |
+
final_split: Whether to use train and validation data.
|
37 |
+
text_len: sequence length
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
The ConfigDict for the input section.
|
41 |
+
"""
|
42 |
+
c = bvcc.parse_arg('') # Just make a configdict without extra import.
|
43 |
+
c.data = dict(
|
44 |
+
name='aokvqa',
|
45 |
+
split='train + val' if final_split else 'train',
|
46 |
+
)
|
47 |
+
c.pp = '|'.join([
|
48 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
49 |
+
f'strjoin("{PROMPT_SEP}", inkey="multiple_choice_possible_answers", outkey="ansstr")',
|
50 |
+
f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
|
51 |
+
'getidx(inkey="multiple_choice_possible_answers", index_key="multiple_choice_correct_idx", outkey="suffix")',
|
52 |
+
combine_and_keep_train(text_len),
|
53 |
+
])
|
54 |
+
return c
|
55 |
+
|
56 |
+
|
57 |
+
def add_eval(c, res, text_len=128, **kw):
|
58 |
+
"""VQAv2 evaluators."""
|
59 |
+
pp = '|'.join([
|
60 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
61 |
+
f'strjoin("{PROMPT_SEP}", inkey="multiple_choice_possible_answers", outkey="ansstr")',
|
62 |
+
f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
|
63 |
+
'getidx(inkey="multiple_choice_possible_answers", index_key="multiple_choice_correct_idx", outkey="answer")',
|
64 |
+
combine_and_keep_eval(text_len, keep=('answer', 'question_id')),
|
65 |
+
])
|
66 |
+
|
67 |
+
for freq, name, split in [
|
68 |
+
(1/4, 'minitrain', 'train[:5%]'), # To gauge memorization.
|
69 |
+
(1/4, 'eval', 'val'), # To tune hparams.
|
70 |
+
(1.0, 'test', 'test'), # To compute final predictions.
|
71 |
+
]:
|
72 |
+
c.evals[f'aokvqa_mc/{name}'] = dict(
|
73 |
+
type='proj.paligemma.transfers.vqa',
|
74 |
+
pred='decode', pred_kw={'max_decode_len': text_len},
|
75 |
+
outfile=f'{{workdir}}/aokvqa_mc_{name}.json',
|
76 |
+
data={**training_data(res, True, text_len).data, 'split': split},
|
77 |
+
log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
|
78 |
+
c.evals[f'aokvqa/{name}'].update(kw)
|
79 |
+
|
80 |
+
|
81 |
+
def add_eval_pplx(c, res, text_len=128):
|
82 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
83 |
+
c_train = training_data(res, True, text_len) # Use mostly same settings as training.
|
84 |
+
|
85 |
+
for name, split in [
|
86 |
+
('minitrain', 'train[:5%]'), # To gauge memorization.
|
87 |
+
('eval', 'val'), # To tune hparams.
|
88 |
+
('test', 'test'), # To compute final predictions.
|
89 |
+
]:
|
90 |
+
c.evals[f'aokvqa_mc/{name}/pplx'] = dict(
|
91 |
+
type='proj.paligemma.perplexity', pred='logits',
|
92 |
+
key='text', shift_labels=True,
|
93 |
+
log_percent=0.05, # Eval ~20x per run; it's cheap.
|
94 |
+
data={**c_train.data, 'split': split},
|
95 |
+
pp_fn=c_train.pp,
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def sweep_best(add, arg=None):
|
100 |
+
"""Train with best hyper-params."""
|
101 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
102 |
+
add(lr=5e-6, wd=0.0, **bvcc.arg(res=224, **c))
|
103 |
+
add(lr=5e-6, wd=0.0, **bvcc.arg(res=448, **c))
|
104 |
+
# add(lr=5e-6, wd=0.0, **bvcc.arg(res=896, **c))
|
105 |
+
|
106 |
+
|
107 |
+
sweep = sweep_best # Choose which sweep to run.
|
108 |
+
|
109 |
+
|
110 |
+
def get_config(arg=None):
|
111 |
+
"""Config for training."""
|
112 |
+
c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)
|
113 |
+
|
114 |
+
c.input = training_data(c.res, c.final_split)
|
115 |
+
|
116 |
+
# Instead of epochs, you can also use `total_examples` or `total_steps`.
|
117 |
+
c.total_epochs = 15
|
118 |
+
c.input.batch_size = 128
|
119 |
+
c.optax_name = 'scale_by_adam'
|
120 |
+
c.optax = dict(b2=0.999)
|
121 |
+
c.lr = 5e-6
|
122 |
+
c.wd = 0.0
|
123 |
+
c.grad_clip_norm = 1.0
|
124 |
+
c.label_smoothing = 0.0
|
125 |
+
c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
|
126 |
+
|
127 |
+
# Add evaluators.
|
128 |
+
c.evals = {}
|
129 |
+
add_eval(c, c.res, batch_size=256)
|
130 |
+
add_eval_pplx(c, c.res)
|
131 |
+
|
132 |
+
# Model section.
|
133 |
+
c.model_name = 'proj.paligemma.paligemma'
|
134 |
+
c.model = {}
|
135 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
136 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
|
137 |
+
c.model_init = f'pt_{c.res}'
|
138 |
+
|
139 |
+
# FSDP strategy.
|
140 |
+
c.mesh = [('data', -1)]
|
141 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
142 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
143 |
+
|
144 |
+
# These probably do not need any change/tuning
|
145 |
+
c.input.shuffle_buffer_size = 50_000
|
146 |
+
c.log_training_steps = 50
|
147 |
+
c.ckpt_steps = 1_000
|
148 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
|
149 |
+
|
150 |
+
# Update configs for quicker local runs and avoid swapping.
|
151 |
+
if c.mode in ('runlocal', 'mock'):
|
152 |
+
c.input.shuffle_buffer_size = None
|
153 |
+
for ev in c.evals.values():
|
154 |
+
ev.data.split = ev.data.split.split('[')[0] + '[:16]'
|
155 |
+
|
156 |
+
if c.mode == 'runlocal':
|
157 |
+
c.log_training_steps = 1
|
158 |
+
c.input.batch_size = 2
|
159 |
+
|
160 |
+
c.seed = 0
|
161 |
+
return c
|
162 |
+
|
163 |
+
|
164 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
165 |
+
m = ['training_loss']
|
166 |
+
for split in ('eval', 'minival', 'minitrain'):
|
167 |
+
m.append(f'aokvqa/{split}/pplx/avg')
|
168 |
+
m.append(f'aokvqa/{split}/acc')
|
169 |
+
return m
|
big_vision/configs/proj/paligemma/transfers/chartqa.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Big Vision Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint: disable=line-too-long
|
16 |
+
r"""PaliGemma transfer to chartqa.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import big_vision.configs.common as bvcc
|
20 |
+
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
|
21 |
+
|
22 |
+
_DATASETS = ('chartqa/human', 'chartqa/augmented')
|
23 |
+
# We use the true dataset sizes from https://arxiv.org/pdf/2203.10244.pdf.
|
24 |
+
_WEIGHTS = (7_398, 20_901)
|
25 |
+
|
26 |
+
|
27 |
+
def training_data(res, *, final_split=False, text_len=48):
|
28 |
+
"""Creates training data config.
|
29 |
+
|
30 |
+
See (internal link)
|
31 |
+
You can add more arguments beside `res`, but give them good defaults.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
res: The requested image resolution (eg 224).
|
35 |
+
final_split: Train on all train+val data.
|
36 |
+
text_len: sequence length.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
The ConfigDict for the input section.
|
40 |
+
"""
|
41 |
+
c = bvcc.parse_arg('') # Just make a configdict without extra import.
|
42 |
+
pp = '|'.join([
|
43 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
44 |
+
'copy(inkey="question", outkey="prefix")',
|
45 |
+
'copy(inkey="answer", outkey="suffix")',
|
46 |
+
combine_and_keep_train(text_len),
|
47 |
+
])
|
48 |
+
c.data = {ds: weight for ds, weight in zip(_DATASETS, _WEIGHTS)}
|
49 |
+
for ds in c.data:
|
50 |
+
c[ds] = dict(
|
51 |
+
shuffle_buffer_size=50_000,
|
52 |
+
pp=pp,
|
53 |
+
data=dict(
|
54 |
+
name=ds,
|
55 |
+
split='train+val' if final_split else 'train',
|
56 |
+
),
|
57 |
+
)
|
58 |
+
return c
|
59 |
+
|
60 |
+
|
61 |
+
def add_eval(c, res, text_len=48, **kw):
|
62 |
+
"""Add eval configs."""
|
63 |
+
c_train = training_data(res, final_split=True, text_len=text_len)
|
64 |
+
|
65 |
+
pp_eval = '|'.join([
|
66 |
+
f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
|
67 |
+
'copy(inkey="question", outkey="prefix")',
|
68 |
+
combine_and_keep_eval(text_len, keep=('answer', 'question_id')),
|
69 |
+
])
|
70 |
+
|
71 |
+
for name, split in [
|
72 |
+
('minitrain', 'train[:5%]'),
|
73 |
+
('minival', 'val'),
|
74 |
+
('eval', 'test'),
|
75 |
+
]:
|
76 |
+
for ds in _DATASETS:
|
77 |
+
c.evals[f'{ds}/{name}'] = dict(
|
78 |
+
type='proj.paligemma.transfers.chartqa',
|
79 |
+
pred='decode', pred_kw={'max_decode_len': text_len},
|
80 |
+
to_lower=True,
|
81 |
+
outfile=f'{{workdir}}/{ds.replace("/", "_")}_{name}.json',
|
82 |
+
data={**c_train[ds].data, 'split': split},
|
83 |
+
log_percent=0.1, tokenizer=TOKENIZER, pp_fn=pp_eval)
|
84 |
+
c.evals[f'{ds}/{name}'].update(kw)
|
85 |
+
|
86 |
+
|
87 |
+
def add_eval_pplx(c, res, text_len=48):
|
88 |
+
"""Perplexity evaluator to test runs before implementing the real deal."""
|
89 |
+
c_train = training_data(res, final_split=True, text_len=text_len) # Use mostly same settings as training.
|
90 |
+
for name, split in [
|
91 |
+
('minitrain', 'train[:5%]'), # To gauge memorization.
|
92 |
+
('minival', 'val'), # To tune hparams.
|
93 |
+
('eval', 'test'), # To compute final publishable scores.
|
94 |
+
]:
|
95 |
+
for ds in _DATASETS:
|
96 |
+
c.evals[f'{ds}/{name}/pplx'] = dict(
|
97 |
+
type='proj.paligemma.perplexity', pred='logits',
|
98 |
+
key='text', shift_labels=True,
|
99 |
+
log_percent=0.05, # Eval ~20x per run; it's cheap.
|
100 |
+
data={**c_train[ds].data, 'split': split},
|
101 |
+
pp_fn=c_train[ds].pp,
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def sweep_best(add, arg=None):
|
106 |
+
"""Train with best hyper-params."""
|
107 |
+
c = bvcc.parse_arg(arg, final_split=False)
|
108 |
+
# TODO: Update once latest numbers are in and have only 1 setup.
|
109 |
+
# Based on (internal link) (relaxed_accuracy).
|
110 |
+
add(lr=1e-5, wd=1e-6, total_epochs=30, **bvcc.arg(res=224, **c))
|
111 |
+
# Based on sweep (internal link) and on (internal link) (relaxed_accuracy).
|
112 |
+
add(lr=1e-5, wd=1e-6, total_epochs=30, **bvcc.arg(res=448, **c))
|
113 |
+
# Based on (internal link) (relaxed_accuracy).
|
114 |
+
# Not better: add(lr=1e-5, wd=1e-6, total_epochs=30, **bvcc.arg(res=896, **c))
|
115 |
+
|
116 |
+
|
117 |
+
sweep = sweep_best # Choose which sweep to run.
|
118 |
+
|
119 |
+
|
120 |
+
def get_config(arg=None):
|
121 |
+
"""Config for training."""
|
122 |
+
c = bvcc.parse_arg(arg, mode='xm', res=896, final_split=False)
|
123 |
+
|
124 |
+
c.input = training_data(c.res, final_split=c.final_split)
|
125 |
+
|
126 |
+
# Instead of epochs, you can also use `total_examples` or `total_steps`.
|
127 |
+
c.total_epochs = 30
|
128 |
+
c.input.batch_size = 256
|
129 |
+
c.optax_name = 'scale_by_adam'
|
130 |
+
c.optax = dict(b2=0.999)
|
131 |
+
c.lr = 1e-5
|
132 |
+
c.wd = 1e-6
|
133 |
+
c.grad_clip_norm = 1.0
|
134 |
+
c.label_smoothing = 0.2
|
135 |
+
c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
|
136 |
+
|
137 |
+
# Add evaluators.
|
138 |
+
c.evals = {}
|
139 |
+
add_eval(c, c.res, batch_size=1024)
|
140 |
+
add_eval_pplx(c, c.res)
|
141 |
+
|
142 |
+
# Model section.
|
143 |
+
c.model_name = 'proj.paligemma.paligemma'
|
144 |
+
c.model = {}
|
145 |
+
c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
|
146 |
+
c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.1)
|
147 |
+
c.model_init = f'pt_{c.res}'
|
148 |
+
|
149 |
+
# FSDP strategy.
|
150 |
+
c.mesh = [('data', -1)]
|
151 |
+
c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
152 |
+
c.sharding_rules = [('act_batch', ('data',))]
|
153 |
+
|
154 |
+
# These probably do not need any change/tuning
|
155 |
+
c.log_training_steps = 50
|
156 |
+
c.ckpt_steps = 1_000
|
157 |
+
c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
|
158 |
+
|
159 |
+
# Update configs for quicker local runs and avoid swapping.
|
160 |
+
if c.mode in ('runlocal', 'mock'):
|
161 |
+
for ds in _DATASETS:
|
162 |
+
c.input[ds].shuffle_buffer_size = None
|
163 |
+
for ev in c.evals.values():
|
164 |
+
ev.data.split = ev.data.split.split('[')[0] + '[:16]'
|
165 |
+
|
166 |
+
if c.mode == 'runlocal':
|
167 |
+
c.log_training_steps = 1
|
168 |
+
c.input.batch_size = 2
|
169 |
+
|
170 |
+
c.seed = 0
|
171 |
+
return c
|
172 |
+
|
173 |
+
|
174 |
+
def metrics(arg=None): # pylint: disable=unused-argument
|
175 |
+
m = ['training_loss']
|
176 |
+
for split in ('eval', 'minival', 'minitrain'):
|
177 |
+
for ds in _DATASETS:
|
178 |
+
m.append(f'{ds}/{split}/relaxed_acc')
|
179 |
+
m.append(f'{ds}/{split}/pplx/avg')
|
180 |
+
return m
|