Alias-Free Convnets: Fractional Shift Invariance via Polynomial Activations

Official pretrained PyTorch models. Both checkpoints are ConvNeXt-Tiny variants.

convnext_tiny_baseline is ConvNeXt-Tiny with circular-padded convolutions. convnext_tiny_afc is the full ConvNeXt-Tiny-AFC, which is invariant to circular shifts of its input.
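
For reference, the circular padding used in the baseline is available directly in PyTorch. A minimal illustration (not from the original repo):

import torch
import torch.nn as nn

# A 3x3 convolution whose padding wraps around the image borders, so the
# layer commutes with integer circular shifts of its input.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
out = conv(torch.randn(1, 3, 32, 32))  # shape: (1, 8, 32, 32)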

For more details see the paper or implementation.

# Install the reference implementation first (shell):
#   git clone https://github.com/hmichaeli/alias_free_convnets.git
# The cloned directory must be on PYTHONPATH for the import below to work.
from huggingface_hub import hf_hub_download
import torch
from torchvision import datasets, transforms
from alias_free_convnets.models.convnext_afc import convnext_afc_tiny

# baseline
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_basline.pth")  # filename kept verbatim ('basline') from the model card
ckpt = torch.load(path, map_location="cpu")  # load on CPU; move the model to GPU later
base_model = convnext_afc_tiny(pretrained=False, num_classes=1000)
base_model.load_state_dict(ckpt, strict=True)

# AFC
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_afc.pth")
ckpt = torch.load(path, map_location="cpu")
afc_model = convnext_afc_tiny(
        pretrained=False,
        num_classes=1000,
        activation='up_poly_per_channel',
        activation_kwargs={'in_scale': 7, 'out_scale': 7, 'train_scale': True},
        blurpool_kwargs={"filter_type": "ideal", "scale_l2": False},
        normalization_type='CHW2',
        stem_activation_kwargs={"in_scale": 7, "out_scale": 7, "train_scale": True, "cutoff": 0.75},
        normalization_kwargs={},
        stem_mode='activation_residual', stem_activation='lpf_poly_per_channel'
    )
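# strict=False tolerates key mismatches between the checkpoint and the freshly
# built model; if loading misbehaves, inspect the missing/unexpected keys that
# load_state_dict returns.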
afc_model.load_state_dict(ckpt, strict=False)

# Evaluate on the ImageNet-1k validation set
interpolation = transforms.InterpolationMode.BICUBIC
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=interpolation),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
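# Point data_path at a local copy of the ImageNet-1k validation split,
# arranged in torchvision ImageFolder layout (one subdirectory per class).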
data_path = "/path/to/imagenet/val"
dataset_val = datasets.ImageFolder(data_path, transform=transform)
nb_classes = 1000
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=8,
            num_workers=8,
            drop_last=False
        )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def evaluate(data_loader, model, device):
    model.eval()
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    print("Acc@1 {:.3f}".format(acc))
    return acc



print("evaluate baseline")
base_model.to(device)
base_acc = evaluate(data_loader_val, base_model, device)

print("evaluate AFC")
afc_model.to(device)
afc_acc = evaluate(data_loader_val, afc_model, device)
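
A quick way to check the invariance claim: compare predictions before and after a circular shift of the input. This is a minimal sketch, not part of the original repo; torch.roll implements the circular shift. The AFC model's predictions should agree, while the baseline is generally not exactly invariant.

x, _ = dataset_val[0]
x = x.unsqueeze(0).to(device)
shifted = torch.roll(x, shifts=(5, -3), dims=(2, 3))  # shift 5 px along H, -3 px along W
afc_model.eval()
with torch.no_grad():
    pred = afc_model(x).argmax(1)
    pred_shifted = afc_model(shifted).argmax(1)
print("same prediction under circular shift:", bool((pred == pred_shifted).all()))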


Dataset used to train hmichaeli/convnext-afc: ImageNet-1k.