|
--- |
|
license: mit |
|
base_model: |
|
- torchvision/convnext_tiny |
|
- pytorch/resnet50 |
|
metrics: |
|
- accuracy |
|
tags: |
|
- Interpretability |
|
- Explainable AI |
|
- XAI |
|
- Classification |
|
- CNN |
|
- Convolutional Neural Networks |
|
--- |
|
|
|
# A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations |
|
|
|
This repository contains the deep Classification-by-Components (CBC) models for
prototype-based interpretability benchmarks on classification tasks, as described in the paper
"A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations".
|
|
|
## Model Description |
|
|
|
The CBC approach learns components (or prototypes) that provide interpretable insights into the
model's decisions. It uses positive and negative reasoning about the class predictions,
i.e. both the presence and the absence of components contribute evidence for a given class
to be predicted as that class.
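
A minimal sketch of this reasoning scheme is shown below. The class name `CBCHead`, its
parameterization, and the normalization are illustrative assumptions rather than the `deep_cbc`
API; see the paper and the repository for the exact formulation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CBCHead(nn.Module):
    """Toy CBC reasoning head: class scores from positive and negative reasoning."""

    def __init__(self, num_components: int, num_classes: int):
        super().__init__()
        # Unconstrained parameters; a softmax over the first dimension yields
        # positive / negative / indefinite reasoning probabilities per (component, class).
        self.reasoning_logits = nn.Parameter(torch.randn(3, num_components, num_classes))

    def forward(self, detection: torch.Tensor) -> torch.Tensor:
        # detection: (batch, num_components) component detection probabilities in [0, 1]
        reasoning = F.softmax(self.reasoning_logits, dim=0)
        r_pos, r_neg = reasoning[0], reasoning[1]  # each (num_components, num_classes)
        # Detected components add positive evidence, absent components add negative evidence.
        evidence = detection @ r_pos + (1.0 - detection) @ r_neg
        # Normalize by the total reasoning mass per class so the scores stay in [0, 1].
        return evidence / (r_pos + r_neg).sum(dim=0)


# Example: 8 components, 5 classes, a batch of 2 images.
head = CBCHead(num_components=8, num_classes=5)
probs = head(torch.rand(2, 8))
print(probs.shape)  # torch.Size([2, 5])
```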
|
|
|
The [`deep_cbc`](https://github.com/si-cim/cbc-aaai-2025) package provides training, evaluation,
and visualization scripts for the CBC models in deep settings, with CNN architectures as
feature-extractor backbones. Further, CBC with only positive reasoning is equivalent to having
an RBF classification head. Additionally, we provide compatibility with the PIPNet
classification head as well.
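
To illustrate the RBF connection (the notation below is ours and is only a sketch, not the
paper's exact formulation): if the detection probability of component $k$ is an RBF kernel over
the feature distance, e.g. $d_k(x) = \exp(-\lVert f(x) - p_k \rVert^2 / (2\sigma^2))$, and all
negative reasoning weights are set to zero, the class score reduces to

$$
p(c \mid x) = \frac{\sum_k d_k(x)\, r^{+}_{k,c}}{\sum_k r^{+}_{k,c}},
$$

i.e. a normalized RBF classifier with non-negative weights $r^{+}_{k,c}$.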
|
|
|
### Available and Supported Architectures |
|
|
|
We provide two CNN variants for each of the CUB-200-2011, Stanford Cars, and
Oxford-IIIT Pets datasets:
|
|
|
- **ResNet50 w/ CBC Classification Head**: Built on both partially trained and fully trained
  `resnet50` backbones from the PyTorch `model_zoo` module.
- **ConvNeXt w/ CBC Classification Head**: Built on a partially trained `convnext_tiny`
  backbone from `torchvision`.
|
|
|
Further, both architectures can also be trained with an RBF or a PIPNet classification head.
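
The snippet below sketches how the two feature-extractor backbones can be obtained from
`torchvision`. The specific pre-trained weights shown and the stripping of the ImageNet
classification layers are assumptions for illustration; the released checkpoints and the CBC
head come from the `deep_cbc` repository.

```python
import torch
import torchvision

# Load the two backbones with ImageNet pre-trained weights (illustrative choice of weights).
convnext = torchvision.models.convnext_tiny(weights="IMAGENET1K_V1")
resnet = torchvision.models.resnet50(weights="IMAGENET1K_V2")

# Drop the ImageNet classification layers to expose pooled feature vectors
# (768-d for convnext_tiny, 2048-d for resnet50).
convnext.classifier[2] = torch.nn.Identity()
resnet.fc = torch.nn.Identity()

with torch.no_grad():
    feats = convnext(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 768])
```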
|
|
|
## Performance |
|
|
|
All models were trained and evaluated on the CUB-200-2011 (CUB), Stanford Cars (CARS), and
Oxford-IIIT Pets (PETS) datasets. We report the top-1 classification accuracy on these datasets
below.
|
|
|
| Model Version | Backbone        | CUB          | CARS         | PETS         |
|---------------|-----------------|--------------|--------------|--------------|
| CBC-C         | `convnext_tiny` | 87.8 ± 0.1 % | 93.0 ± 0.1 % | 93.9 ± 0.1 % |
| CBC-R         | `resnet50`      | 83.3 ± 0.3 % | 92.7 ± 0.1 % | 90.1 ± 0.1 % |
| CBC-R Full    | `resnet50`      | 82.8 ± 0.3 % | 92.8 ± 0.1 % | 89.5 ± 0.2 % |
|
|
|
## Model Features |
|
|
|
- 🔍 **Interpretable Decision Assistance:** The model performs classification using positive and
  negative reasoning over learnt components (or prototypes), providing interpretable insights
  into its decisions.
- 🎯 **SotA Accuracy:** Achieves state-of-the-art performance on classification tasks across the
  interpretability benchmarks.
- 🚀 **Multiple Feature Extractor CNN Backbones:** Supports ConvNeXt and ResNet50
  feature-extractor backbones with CBC heads for interpretable image classification tasks.
- 📊 **Visualization and Analysis Tools:** Equipped with visualization tools to plot learnt
  prototype patches and the corresponding activation maps, alongside similarity scores and
  detection probabilities.
|
|
|
## Requirements |
|
|
|
- python = "^3.9" |
|
- numpy = "1.26.4" |
|
- matplotlib = "3.8.4" |
|
- scikit-learn = "1.4.2" |
|
- scipy = "1.13.0" |
|
- pillow = "10.3.0" |
|
- omegaconf = "2.3.0" |
|
- hydra-core = "1.3.2" |
|
- torch = "2.2.2" |
|
- torchvision = "0.17.2" |
|
- setuptools = "68.2.0" |
|
|
|
The basic dependencies for using the models are listed above. Please refer to the
[GitHub repository](https://github.com/si-cim/cbc-aaai-2025) for the full dependency list
and project setup instructions for running experiments with these models.
|
|
|
## Limitations and Bias |
|
|
|
- ❗ **Partial Interpretability Issue:** The CNN feature-extractor backbone introduces an
  uninterpretable component into the model. Although we achieve SotA accuracy and demonstrate
  that the models provide high-quality positive and negative reasoning explanations, these
  methods can only be called partially interpretable, because not all learnt prototypes are
  human-interpretable.
- ❗ **Data Bias Issue:** These models are trained on the CUB-200-2011, Stanford Cars, and
  Oxford-IIIT Pets datasets, and the stated performance should not be expected to generalize to
  other domains.
- ❗ **Resolution Constraints Issue:** The model backbones are pre-trained at a resolution of
  224×224. Although the current data loaders accept images of other resolutions, performance
  will be suboptimal owing to the fixed receptive fields the networks learn for a given
  resolution. A possible improvement on the Stanford Cars dataset is to standardize image sizes
  as a pre-processing step (see the sketch after this list).
- ❗ **Location Misalignment Issue:** CNN-based models are not perfectly immune to location
  misalignment under adversarial attacks; hence, with a black-box feature extractor, the learnt
  prototype-based networks are also prone to such issues.
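
As a pre-processing sketch for the resolution point above, a pipeline along these lines would
standardize inputs to the 224×224 resolution the backbones were pre-trained on. It is only an
illustration; the repository's data loaders define the exact transforms used for the reported
results.

```python
from torchvision import transforms

# Resize the shorter side, then center-crop to the backbones' native 224x224 input size,
# using the standard ImageNet normalization statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```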
|
|
|
## Citation |
|
|
|
If you use this model in your research, please consider citing:
|
|
|
```bibtex |
|
@article{saralajew2024robust,
  title={A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations},
  author={Saralajew, Sascha and Rana, Ashish and Villmann, Thomas and Shaker, Ammar},
  journal={arXiv preprint arXiv:2412.15499},
  year={2024}
}
|
``` |
|
|
|
## Acknowledgements |
|
|
|
This implementation builds upon the following excellent repositories: |
|
|
|
- [PIPNet](https://github.com/M-Nauta/PIPNet) |
|
- [ProtoPNet](https://github.com/cfchen-duke/ProtoPNet) |
|
|
|
These repositories can also be consulted for additional documentation on data pre-processing,
data loaders, model architectures, and visualizations.
|
|
|
## License |
|
|
|
This project is released under the MIT license.
|
|
|
## Contact |
|
|
|
For any questions or feedback, please: |
|
1. Open an issue in the project [GitHub repository](https://github.com/si-cim/cbc-aaai-2025) |
|
2. Contact the corresponding author
|
|