Model card for csatv2
CSATv2 is a lightweight high-resolution vision backbone designed to maximize throughput at 512×512 resolution. By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.
Highlights
- 🚀 2,800 images/s at 512×512 resolution (single NVIDIA A6000 GPU)
- ⚡ Frequency-domain compression for lightweight and efficient modeling (a generic sketch follows this list)
- 🎯 80.02% ImageNet-1K Top-1 Accuracy
- 🪶 Only 11M parameters
- 🧩 Suitable for image classification or as a high-throughput detection backbone
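The exact frequency transform used by CSATv2 is not detailed in this card, but block-wise DCT truncation is one common way to realize frequency-domain input compression. The sketch below is a generic illustration of that idea only, not the CSATv2 implementation; all function names are hypothetical. It DCT-transforms 8×8 blocks of a 512×512 input and keeps only the 4×4 low-frequency corner of each block, shrinking the spatial grid fourfold while folding the kept coefficients into channels.

import math
import torch

def dct_matrix(n: int) -> torch.Tensor:
    # Orthonormal DCT-II basis: rows index frequencies, columns index positions.
    pos = torch.arange(n, dtype=torch.float32)
    basis = torch.cos(math.pi * (2 * pos[None, :] + 1) * pos[:, None] / (2 * n))
    basis[0] /= math.sqrt(2)
    return basis * math.sqrt(2 / n)

def block_dct_compress(x: torch.Tensor, block: int = 8, keep: int = 4) -> torch.Tensor:
    # x: (B, C, H, W); H and W must be divisible by `block`.
    B, C, H, W = x.shape
    D = dct_matrix(block).to(x.dtype)
    # Split into non-overlapping blocks: (B, C, H//block, block, W//block, block)
    xb = x.reshape(B, C, H // block, block, W // block, block)
    # 2D DCT per block: D @ patch @ D.T
    freq = torch.einsum('ij,bchjwk,lk->bchiwl', D, xb, D)
    # Discard high frequencies, keeping the low-frequency keep x keep corner
    return freq[..., :keep, :, :keep]

x = torch.randn(1, 3, 512, 512)
z = block_dct_compress(x)                                   # (1, 3, 64, 4, 64, 4)
z = z.permute(0, 1, 3, 5, 2, 4).reshape(1, 3 * 16, 64, 64)  # coeffs -> channels
print(z.shape)                                              # torch.Size([1, 48, 64, 64])

Dropping the high-frequency coefficients removes much of the spatially redundant detail up front, which is what makes the downstream backbone cheap at high resolution.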
This model is an improved version of the architecture used in the paper (citation TBD; see the Citation section below).
Special thanks to Juno for contributing ideas and feedback that greatly helped make the model lighter and faster.
Model Details
- Model Type: Image Classification / Feature Encoder
- Model Stats:
- Params (M): 11.1
- GMACs: 1.4
- Activations (M): 9.2
- Image size: 512 x 512
- Original: https://huggingface.co/Hyunil/CSATv2
- License: unknown
- Dataset: ImageNet-1k
- Papers: TBD (see Citation below)
Background and Motivation
In computational pathology, a single whole-slide image (WSI) is typically partitioned into thousands to tens of thousands of high-resolution image patches (e.g., 512×512 pixels) for analysis.
This setting places strong constraints on both throughput and latency: even small inefficiencies in patch-level inference can lead to prohibitively long end-to-end processing times at the slide level.
CSATv2 was originally designed to address this constraint by enabling high-throughput, high-resolution inference while preserving classification accuracy. At the reported throughput of 2,800 images/s, a slide yielding 20,000 patches takes roughly 20,000 / 2,800 ≈ 7 seconds of pure forward-pass time. In practical deployments, this design reduced slide-level processing time from tens of minutes to approximately one minute, making near–real-time pathological analysis feasible at scale.
Training Details
The model was trained on ImageNet-1K using a high-resolution training pipeline adapted from common ImageNet training practices.
- Dataset: ImageNet-1K
- Input resolution: 512×512
- Model: CSATv2
- Optimizer: AdamW
- Learning rate: 2e-3
- Learning rate schedule: Cosine
- Epochs: 300
- Warmup epochs: 5
- Weight decay: 2e-2
- Batch size: 128 (per GPU)
- Mixed precision training: Enabled (AMP)
Data Augmentation
- Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3)
- Horizontal flip (p = 0.5)
- RandAugment (rand-m7-mstd0.5-inc1)
- Mixup (α = 0.8)
- CutMix (α = 1.0; a Mixup/CutMix usage sketch follows this list)
- Bicubic interpolation
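The Mixup/CutMix settings above map naturally onto timm's built-in Mixup helper. The following is an illustrative sketch under that assumption; the actual training code may construct this differently.

import torch
from timm.data import Mixup

# Mixup/CutMix with the alphas listed above; label smoothing disabled
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0,
                 label_smoothing=0.0, num_classes=1000)

inputs = torch.randn(8, 3, 512, 512)            # dummy batch for illustration
targets = torch.randint(0, 1000, (8,))
mixed_inputs, mixed_targets = mixup_fn(inputs, targets)
print(mixed_inputs.shape, mixed_targets.shape)  # (8, 3, 512, 512), (8, 1000)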
Regularization
- Label smoothing: Disabled (handled implicitly via Mixup / CutMix)
- Dropout / DropPath: Disabled
- Random erase: Disabled
Optimization Details
- Exponential Moving Average (EMA): Enabled (update rule sketched after this list)
- EMA decay: 0.99996
- Gradient clipping: Disabled
- Channels-last memory format: Optional
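For clarity, the standard EMA weight update with the decay listed above is sketched below (a minimal illustration; timm's ModelEmaV2 utility implements the same rule).

import copy
import torch
import torch.nn as nn

model = nn.Linear(16, 16)          # stand-in module for illustration
ema_model = copy.deepcopy(model)   # shadow copy holding the averaged weights

@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.99996) -> None:
    # ema_w <- decay * ema_w + (1 - decay) * w, applied after each optimizer step
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

ema_update(ema_model, model)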
The training pipeline was adapted from publicly available ImageNet training repositories (Solving ImageNet), with task-specific modifications for high-resolution and high-throughput training.
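As a concrete reference, the optimizer and schedule settings above could be instantiated with timm's standard utilities as below. This is a hypothetical sketch, not the project's actual training script; the API names are standard timm, but the mapping is an assumption.

import timm
import torch
from timm.optim import create_optimizer_v2
from timm.scheduler import CosineLRScheduler

model = timm.create_model('csatv2', pretrained=False, num_classes=1000)

# AdamW with lr 2e-3 and weight decay 2e-2, as listed above
optimizer = create_optimizer_v2(model, opt='adamw', lr=2e-3, weight_decay=2e-2)

# Cosine decay over 300 epochs with 5 warmup epochs (warmup_lr_init is assumed)
scheduler = CosineLRScheduler(optimizer, t_initial=300, warmup_t=5, warmup_lr_init=1e-5)

# Mixed-precision (AMP) training uses a gradient scaler
scaler = torch.cuda.amp.GradScaler()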
Model Usage
Image Classification
from urllib.request import urlopen
from PIL import Image
import timm
import torch

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('csatv2', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
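To turn the indices into class names, one option (an addition for illustration, not part of the original card) is the ImageNet-1k label list published in the PyTorch Hub repository, continuing the example above:

# Map top-5 indices to human-readable ImageNet-1k labels
labels = urlopen(
    'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
).read().decode().splitlines()

for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f'{labels[idx.item()]}: {prob.item():.2f}%')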
Feature Map Extraction
from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'csatv2',
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 32, 64, 64])
    #  torch.Size([1, 72, 32, 32])
    #  torch.Size([1, 168, 16, 16])
    #  torch.Size([1, 386, 8, 8])
    print(o.shape)
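Continuing the example above, the per-stage channel counts and reduction factors behind these shapes are available through timm's standard feature_info API (the values shown are inferred from the shapes above):

print(model.feature_info.channels())   # e.g. [32, 72, 168, 386]
print(model.feature_info.reduction())  # e.g. [8, 16, 32, 64]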
Image Embeddings
from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'csatv2',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # output is (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)
output = model.forward_features(transforms(img).unsqueeze(0))
# output is unpooled, a (1, 386, 8, 8) shaped tensor

output = model.forward_head(output, pre_logits=True)
# output is a (1, num_features) shaped tensor
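As a usage note (a generic pattern, not specific to CSATv2), the pooled embedding can feed similarity search directly; continuing the example above:

import torch.nn.functional as F

# Cosine similarity between two pooled embeddings; identical inputs give 1.0
emb_a = model(transforms(img).unsqueeze(0))
emb_b = model(transforms(img).unsqueeze(0))
print(F.cosine_similarity(emb_a, emb_b))  # tensor([1.0000])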
Contact
This project was conducted by members of MLPA Lab. Feedback, suggestions, and questions are welcome.
Citation
TBD
