Model card for csatv2

CSATv2 is a lightweight high-resolution vision backbone designed to maximize throughput at 512×512 resolution. By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.
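
As a rough illustration of the idea (not the exact CSATv2 stem), frequency-domain compression can be sketched as a block-wise DCT that keeps only the low-frequency coefficients of each block; the block size and number of retained coefficients below are illustrative assumptions.

import numpy as np
from scipy.fft import dctn

def dct_compress(img, block=8, keep=4):
    # img: (H, W) array with H and W divisible by `block`
    h, w = img.shape
    out = np.zeros((h // block, w // block, keep, keep), dtype=np.float32)
    for i in range(0, h, block):
        for j in range(0, w, block):
            coeffs = dctn(img[i:i + block, j:j + block], norm='ortho')
            # keep only the top-left (low-frequency) corner of each block
            out[i // block, j // block] = coeffs[:keep, :keep]
    return out  # spatial size reduced by `block`, keep*keep coefficients per block

Discarding the high-frequency tail shrinks the spatial resolution that the rest of the network has to process, which is where the throughput gain comes from.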

Highlights

  • 🚀 2,800 images/s at 512×512 resolution on a single NVIDIA A6000 GPU
  • Frequency-domain compression for lightweight and efficient modeling
  • 🎯 80.02% ImageNet-1K Top-1 Accuracy
  • 🪶 Only 11M parameters
  • 🧩 Suitable for image classification or as a high-throughput detection backbone
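
The throughput figure can be sanity-checked with a simple timing loop. This is a hedged sketch (batch size, warmup count, and AMP usage are assumptions); results will vary with hardware, batch size, and software versions.

import time
import torch
import timm

model = timm.create_model('csatv2', pretrained=True).cuda().eval()
x = torch.randn(64, 3, 512, 512, device='cuda')  # batch size 64 is an assumption

with torch.inference_mode(), torch.autocast('cuda'):
    for _ in range(10):  # warmup iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    iters = 50
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    print(f"{iters * x.shape[0] / (time.time() - start):.0f} images/s")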

This model is an improved version of the architecture used in the paper.

Special thanks to Juno for contributing ideas and feedback that greatly helped in making the model lighter and more efficient.

Model Details

  • Model Type: Image Classification / Feature Encoder
  • Model Stats:
    • Params (M): 11.1
    • GMACs: 1.4
    • Activations (M): 9.2
    • Image size: 512 x 512
  • Original: https://huggingface.co/Hyunil/CSATv2
  • License: unknown
  • Dataset: ImageNet-1K
  • Papers:

Background and Motivation

In computational pathology, a single whole-slide image (WSI) is typically partitioned into thousands to tens of thousands of high-resolution image patches (e.g., 512×512 pixels) for analysis.

This setting places strong constraints on both throughput and latency: even small inefficiencies in patch-level inference can lead to prohibitively long end-to-end processing times at the slide level.

CSATv2 was originally designed to address this constraint by enabling high-throughput, high-resolution inference while preserving classification accuracy. In practical deployments, this design reduced slide-level processing time from tens of minutes to approximately one minute, making near–real-time pathological analysis feasible at scale.

Model description

[Figure: CSATv2 architecture overview]

Training Details

The model was trained on ImageNet-1K using a high-resolution training pipeline adapted from common ImageNet training practices.

  • Dataset: ImageNet-1K
  • Input resolution: 512×512
  • Model: CSATv2
  • Optimizer: AdamW
  • Learning rate: 2e-3
  • Learning rate schedule: Cosine
  • Epochs: 300
  • Warmup epochs: 5
  • Weight decay: 2e-2
  • Batch size: 128 (per GPU)
  • Mixed precision training: Enabled (AMP)
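
A minimal sketch of how these settings map onto timm's optimizer and scheduler factories; this is an approximation, and the warmup initial learning rate is an assumption since the card does not state it.

import torch
import timm
from timm.optim import create_optimizer_v2
from timm.scheduler import CosineLRScheduler

model = timm.create_model('csatv2')
optimizer = create_optimizer_v2(model, opt='adamw', lr=2e-3, weight_decay=2e-2)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=300,        # total epochs
    warmup_t=5,           # warmup epochs
    warmup_lr_init=1e-5,  # assumption: not stated in the card
)
scaler = torch.cuda.amp.GradScaler()  # mixed-precision (AMP) training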

Data Augmentation

  • Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3)
  • Horizontal flip (p = 0.5)
  • RandAugment (rand-m7-mstd0.5-inc1)
  • Mixup (α = 0.8)
  • CutMix (α = 1.0)
  • Bicubic interpolation
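
These augmentations correspond closely to what timm's transform factory and Mixup helper can produce; a sketch under that assumption:

from timm.data import create_transform, Mixup

train_transform = create_transform(
    input_size=512,
    is_training=True,
    scale=(0.08, 1.0),   # random resized crop scale
    ratio=(3/4, 4/3),    # random resized crop aspect ratio
    hflip=0.5,           # horizontal flip probability
    auto_augment='rand-m7-mstd0.5-inc1',
    interpolation='bicubic',
)

# Batch-level Mixup / CutMix; label smoothing stays at 0 per the card.
# Applied per batch as: x, target = mixup_fn(x, target)
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.0, num_classes=1000)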

Regularization

  • Label smoothing: Disabled (handled implicitly via Mixup / CutMix)
  • Dropout / DropPath: Disabled
  • Random erase: Disabled

Optimization Details

  • Exponential Moving Average (EMA): Enabled
    • EMA decay: 0.99996
  • Gradient clipping: Disabled
  • Channels-last memory format: Optional
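
EMA and the optional channels-last memory format can be set up as follows; a minimal sketch using timm's ModelEmaV2 helper, assuming it matches the EMA used in training.

import torch
import timm
from timm.utils import ModelEmaV2

model = timm.create_model('csatv2')
model = model.to(memory_format=torch.channels_last)  # optional channels-last
model_ema = ModelEmaV2(model, decay=0.99996)

# inside the training loop, after each optimizer step:
# model_ema.update(model)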

The training pipeline was adapted from publicly available ImageNet training repositories (Solving ImageNet), with task-specific modifications for high-resolution and high-throughput training.

Model Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import timm
import torch

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('csatv2', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
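
To map the top-5 indices to human-readable labels, recent timm releases bundle ImageNet metadata; a sketch assuming timm.data.ImageNetInfo is available in your timm version:

from timm.data import ImageNetInfo

info = ImageNetInfo()
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"{info.index_to_description(idx.item())}: {prob.item():.2f}%")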

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'csatv2',
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 32, 64, 64])
    #  torch.Size([1, 72, 32, 32])
    #  torch.Size([1, 168, 16, 16])
    #  torch.Size([1, 386, 8, 8])

    print(o.shape)

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'csatv2',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # output is (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0))
# output is unpooled, a (1, 386, 8, 8) shaped tensor

output = model.forward_head(output, pre_logits=True)
# output is a (1, num_features) shaped tensor

Contact

This project was conducted by members of MLPA Lab. Feedback, suggestions, and questions are welcome.

Citation

TBD
