metadata
license: cc-by-nc-4.0
pipeline_tag: image-classification

Hiera Base: MAE pretrained on IN1K, fine-tuned on IN1K

Hiera is a hierarchical vision transformer that is fast, powerful, and, above all, simple. It outperforms the state-of-the-art across a wide array of image and video tasks while being much faster.

How does it work?

A diagram of Hiera's architecture.

Vision transformers like ViT use the same spatial resolution and number of features throughout the whole network. But this is inefficient: the early layers don't need that many features, and the later layers don't need that much spatial resolution. Prior hierarchical models like ResNet accounted for this by using fewer features at the start and less spatial resolution at the end.
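As a rough illustration of this trade-off, the sketch below tabulates the token grid and channel width per stage for a plain ViT-B/16-style model and for a four-stage hierarchical model at 224px input. The hierarchical configuration (stride-4 patch embedding, 96 starting channels, 2x2 pooling and channel doubling between stages) is an assumption chosen to resemble common hierarchical designs, and the function names are hypothetical. This is illustrative arithmetic, not Hiera's actual config code.

```python
def vit_shapes(img_size=224, patch=16, dim=768):
    """Plain ViT: one token grid and one width for every layer (assumed ViT-B/16 numbers)."""
    side = img_size // patch
    return (side, side, dim)


def hierarchical_shapes(img_size=224, patch_stride=4, embed_dim=96,
                        num_stages=4, pool=2):
    """Hypothetical hierarchical model: each new stage pools the token grid
    by `pool` in each spatial direction and doubles the channel width."""
    side, dim = img_size // patch_stride, embed_dim
    shapes = []
    for _ in range(num_stages):
        shapes.append((side, side, dim))
        side, dim = side // pool, dim * 2
    return shapes


def activation_size(shape):
    """Number of values in one layer's activation: tokens x channels."""
    h, w, c = shape
    return h * w * c


if __name__ == "__main__":
    print("ViT (every layer):", vit_shapes(), activation_size(vit_shapes()))
    for i, s in enumerate(hierarchical_shapes(), start=1):
        print(f"stage {i}:", s, activation_size(s))
```

Under these assumptions the plain ViT keeps a 14x14x768 activation in every layer, while the hierarchical model starts with many tokens and few channels (56x56x96) and ends with few tokens and many channels (7x7x768), so its per-layer activation size shrinks stage by stage.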

Several domain-specific vision transformers that employ this hierarchical design have been introduced, such as Swin or MViT. But in the pursuit of state-of-the-art results using fully supervised training on ImageNet-1K, these models have become more and more complicated as they add specialized modules to make up for spatial biases that ViTs lack. While these changes produce effective models with attractive FLOP counts, under the hood the added complexity makes these models slower overall.

We show that a lot of this bulk is actually unnecessary. Instead of manually adding spatial biases through architectural changes, we opt to teach the model these biases. By training with MAE, we can simplify or remove all of these bulky modules in existing transformers and increase accuracy in the process. The result is Hiera, an extremely efficient and simple architecture that outperforms the state-of-the-art in several image and video recognition tasks.

Installation

Hiera requires a reasonably recent version of torch. Once torch is installed, you can install Hiera through pip:

pip install hiera-transformer

This repo should support the latest version of timm, but timm changes frequently, so please create an issue if you run into problems with a newer version. With the hiera-transformer and huggingface-hub packages installed, you can simply run, e.g.,

from hiera import Hiera
model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k_ft_in1k")  # MAE pretrained, then fine-tuned on IN1K
model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k")  # MAE pretrained only, no fine-tuning

to load a model. Use <model_name>.<checkpoint_name> from the model zoo.
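The Hub name and the hiera constructor used under Inference encode the same two pieces of information. The hypothetical helper below (not part of the hiera package) splits a Hub ID such as facebook/hiera_base_224.mae_in1k_ft_in1k into its model name and checkpoint name, assuming the first dot separates the two and that the organization prefix is optional.

```python
def split_hub_name(repo_id):
    """Split '[org/]<model_name>.<checkpoint_name>' into its two parts.

    Hypothetical helper: assumes the first dot separates the model name
    from the checkpoint name.
    """
    name = repo_id.split("/")[-1]  # drop an optional 'facebook/' prefix
    model_name, sep, checkpoint = name.partition(".")
    if not sep or not model_name or not checkpoint:
        raise ValueError(f"expected <model_name>.<checkpoint_name>, got {repo_id!r}")
    return model_name, checkpoint


model_name, checkpoint = split_hub_name("facebook/hiera_base_224.mae_in1k_ft_in1k")
# model_name == "hiera_base_224", checkpoint == "mae_in1k_ft_in1k": the same pair
# used in hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
```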

If you want to save a model, use model.config as the config, e.g.,

model.save_pretrained("hiera-base-224", config=model.config)

Inference

import hiera
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")

Then you can run inference like any other model:

output = model(x)

Video inference works the same way; just use a 16x224 model instead.

Note: for efficiency, Hiera re-orders its tokens at the start of the network (see the Roll and Unroll modules in hiera_utils.py). Thus, tokens aren't in spatial order by default. If you'd like to use intermediate feature maps for a downstream task, pass the return_intermediates flag when running the model:

output, intermediates = model(x, return_intermediates=True)
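Why reordering helps can be shown on a single level. The sketch below is a simplified, hypothetical version of the idea, not the Roll and Unroll code in hiera_utils.py: it reorders a row-major h x w token list so the four tokens of every 2x2 window land at the same index in four contiguous chunks. A 2x2 pool then becomes an elementwise max across chunks, with no spatial indexing. Max pooling is assumed here, plain numbers stand in for token feature vectors, and the real modules repeat the reordering for every pooling stage rather than just one.

```python
def unroll_once(tokens, h, w, s=2):
    """Reorder a row-major h*w token list into s*s contiguous chunks.

    Chunk k holds the tokens at one offset (dy, dx) inside every s*s window,
    listed window by window in row-major order.
    """
    assert h % s == 0 and w % s == 0, "grid must divide evenly by the stride"
    out = []
    for dy in range(s):
        for dx in range(s):
            for y in range(dy, h, s):
                for x in range(dx, w, s):
                    out.append(tokens[y * w + x])
    return out


def pool_unrolled(seq, s=2):
    """s*s max pooling on an unrolled sequence: elementwise max across chunks."""
    n = len(seq) // (s * s)
    chunks = [seq[i * n:(i + 1) * n] for i in range(s * s)]
    return [max(vals) for vals in zip(*chunks)]


def maxpool_spatial(tokens, h, w, s=2):
    """Reference s*s max pooling computed directly on the row-major grid."""
    return [
        max(tokens[(y + dy) * w + (x + dx)] for dy in range(s) for dx in range(s))
        for y in range(0, h, s)
        for x in range(0, w, s)
    ]


grid = [3, 1, 4, 1,
        5, 9, 2, 6,
        5, 3, 5, 8,
        9, 7, 9, 3]
assert pool_unrolled(unroll_once(grid, 4, 4)) == maxpool_spatial(grid, 4, 4)
print(unroll_once(list(range(16)), 4, 4))
# [0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15]
```

The printed order shows where each original position ends up after one unroll: neighbors in the sequence are no longer spatial neighbors, which is why intermediate feature maps have to be put back into spatial order before a downstream task can use them.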

MAE Inference

By default, the models do not include the MAE decoder. If you would like to use the decoder or compute MAE loss, you can instantiate an MAE version by running:

import hiera
model = hiera.mae_hiera_base_224(pretrained=True, checkpoint="mae_in1k")

Then when you run inference on the model, it will return a 4-tuple of (loss, predictions, labels, mask) where predictions and labels are for the deleted tokens only. The returned mask will be True if the token is visible and False if it's deleted. You can change the masking ratio by passing it during inference:

loss, preds, labels, mask = model(x, mask_ratio=0.6)

The default mask ratio is 0.6 for images, but you should pass in 0.9 for video. See the paper for details.
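A minimal sketch of the masking and loss described above, with hypothetical function names; it is not the hiera_mae.py implementation. It masks individual positions for simplicity (Hiera itself masks coarser units than single tokens) and uses a plain mean squared error. random_mask follows the convention above (True = visible, False = deleted) and keeps round(n * (1 - mask_ratio)) positions visible; masked_mse averages the error over deleted positions only.

```python
import random


def random_mask(num_tokens, mask_ratio, rng=random):
    """Return a list of booleans: True = visible, False = deleted (masked)."""
    num_visible = round(num_tokens * (1 - mask_ratio))
    visible = set(rng.sample(range(num_tokens), num_visible))
    return [i in visible for i in range(num_tokens)]


def masked_mse(preds, labels, mask):
    """Mean squared error averaged over deleted tokens only.

    preds/labels: one list of values per token (e.g. a flattened patch).
    mask: True = visible (ignored by the loss), False = deleted.
    """
    per_token = [
        sum((p - t) ** 2 for p, t in zip(pred, label)) / len(label)
        for pred, label, visible in zip(preds, labels, mask)
        if not visible
    ]
    if not per_token:
        raise ValueError("no deleted tokens: mask_ratio must be > 0")
    return sum(per_token) / len(per_token)


rng = random.Random(0)
# 196 positions (a 14x14 grid) is an assumption chosen for illustration.
print(sum(random_mask(196, 0.6, rng)), sum(random_mask(196, 0.9, rng)))  # 78 20
```

A higher mask ratio leaves fewer visible positions for the encoder and more deleted ones for the loss, which is the knob the mask_ratio argument controls.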

Note: We use normalized pixel targets for MAE pretraining, meaning each patch is individually normalized before the model has to predict it. Thus, you have to unnormalize the predictions using the ground truth before visualizing them. See get_pixel_label_2d in hiera_mae.py for details.
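A minimal sketch of per-patch normalization and its inverse, assuming the standard MAE recipe of subtracting each patch's mean and dividing by sqrt(variance + eps). The epsilon of 1e-6, the use of sample variance, and the function names are assumptions, so check get_pixel_label_2d for the exact behavior. Unnormalizing a prediction uses the statistics of the ground-truth patch, since the model never predicts them.

```python
import statistics


def patch_stats(patch, eps=1e-6):
    """Mean and std of one flattened patch (eps keeps flat patches finite)."""
    mean = statistics.fmean(patch)
    std = (statistics.variance(patch) + eps) ** 0.5  # sample variance, assumed
    return mean, std


def normalize_patch(patch, eps=1e-6):
    """Turn a ground-truth patch into a normalized pixel target."""
    mean, std = patch_stats(patch, eps)
    return [(v - mean) / std for v in patch]


def unnormalize_patch(pred, gt_patch, eps=1e-6):
    """Map a normalized prediction back to pixel space using the GT patch stats."""
    mean, std = patch_stats(gt_patch, eps)
    return [v * std + mean for v in pred]


gt = [0.2, 0.4, 0.6, 0.8]
target = normalize_patch(gt)  # zero mean, roughly unit variance
restored = unnormalize_patch(target, gt)
assert all(abs(a - b) < 1e-9 for a, b in zip(restored, gt))
```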

Citation

If you use Hiera or this code in your work, please cite:

@article{ryali2023hiera,
  title={Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles},
  author={Ryali, Chaitanya and Hu, Yuan-Ting and Bolya, Daniel and Wei, Chen and Fan, Haoqi and Huang, Po-Yao and Aggarwal, Vaibhav and Chowdhury, Arkabandhu and Poursaeed, Omid and Hoffman, Judy and Malik, Jitendra and Li, Yanghao and Feichtenhofer, Christoph},
  journal={ICML},
  year={2023}
}