---
library_name: torch
tags:
- image-classification
- resnet
- diagrams
- pytorch
- computer-vision
license: apache-2.0
metrics:
- accuracy
- f1
- recall
- precision
base_model:
- microsoft/resnet-18
pipeline_tag: image-classification
datasets:
- phiyodr/coco2017
- HuggingFaceM4/ChartQA
- JasmineQiuqiu/diagrams_with_captions_2
---

# Model Card for Diagram Classification Model

## Model Details

### Model Description

This is a fine-tuned ResNet-18 model trained for binary image classification, distinguishing between **diagrams** and **non-diagrams**. The model is designed for use in applications that need automatic filtering or processing of diagram-based content.

- **Developed by:** Aya Mohamed
- **Model type:** ResNet-18 (Fine-tuned for image classification)
- **Language(s) (NLP):** Not applicable (Computer Vision model)
- **License:** Apache 2.0
- **Finetuned from model:** `microsoft/resnet-18`

### Model Sources

- **Repository:** [Ayamohamed/diaclass-model](https://huggingface.co/Ayamohamed/diaclass-model)

## Uses

### Direct Use

This model is intended for classifying images as **diagrams** or **non-diagrams**. It can be used in:

- **Document processing** (extracting diagrams from PDFs or scanned documents)
- **Chart-based visual question generation (VQG)**
- **Content moderation** (filtering diagram images from general image datasets)

### Out-of-Scope Use

- Not suitable for **multi-class classification** beyond diagrams vs. non-diagrams.
- Not designed for **hand-drawn sketches** or **complex figures with mixed elements**.

## Bias, Risks, and Limitations

- The model's accuracy depends on the training dataset, which may not cover all possible diagram styles.
- The model may misclassify **charts, blueprints, or artistic drawings** if they resemble diagrams.

### Recommendations

Users should **evaluate the model** on their own dataset before deployment to confirm that it performs well in their context.
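
One way to run such a check is to collect the model's predictions on a labeled sample and compute the same metrics reported in this card. The helper below is an illustrative sketch: `binary_metrics` and the sample labels are hypothetical and not part of the model repository, and "Diagram" is treated as the positive class.

```python
def binary_metrics(y_true, y_pred, positive="Diagram"):
    """Accuracy, precision, recall and F1 for a binary task."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    correct = sum(t == p for t, p in pairs)

    # Guard against division by zero when a class never appears
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Hypothetical ground truth and predictions for five images
y_true = ["Diagram", "Diagram", "Not Diagram", "Not Diagram", "Diagram"]
y_pred = ["Diagram", "Not Diagram", "Not Diagram", "Diagram", "Diagram"]
metrics = binary_metrics(y_true, y_pred)
print(metrics)
```

In practice, `y_pred` would hold the labels the classifier returns for each image in the evaluation sample.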

## 🚀 How to Use

### **1️⃣ Load the Model from Hugging Face**

Download the checkpoint with `huggingface_hub` and load it with `torch`.

```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="model.pth")

# Load the full pickled model onto the CPU. PyTorch 2.6+ defaults to
# weights_only=True, which rejects pickled modules, so pass weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints you trust.
model_hg = torch.load(model_path, map_location="cpu", weights_only=False)
model_hg.eval()  # Set to evaluation mode
```

### **2️⃣ Preprocess and Classify an Image**

This step reuses `torch` and `model_hg` from step 1.

```python
from PIL import Image
from torchvision import transforms

# Preprocessing: resize to 224x224 and normalize with the ImageNet mean and std
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict(image_path):
    """Return "Diagram" or "Not Diagram" for the image at image_path."""
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        output = model_hg(image)
        class_idx = torch.argmax(output, dim=1).item()

    # Class index 0 is "Diagram"; any other index is "Not Diagram"
    return "Diagram" if class_idx == 0 else "Not Diagram"

# Example usage
print(predict("my-diagram-classifier/31188_1536932698.jpg"))
```
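
If a confidence score is useful alongside the label, the model output can be converted to probabilities with a softmax. The helper below is a hypothetical sketch (`label_with_confidence` and `LABELS` are not part of the model repository). It assumes the model returns raw two-class logits with index 0 meaning "Diagram", as in `predict`.

```python
import torch

LABELS = ("Diagram", "Not Diagram")  # index 0 = Diagram, as in predict()

def label_with_confidence(logits):
    """Map a (1, 2) logits tensor to a (label, probability) pair."""
    probs = torch.softmax(logits, dim=1)[0]
    class_idx = int(torch.argmax(probs))
    return LABELS[class_idx], float(probs[class_idx])

# Stand-in logits for illustration; in practice pass model_hg(image)
label, confidence = label_with_confidence(torch.tensor([[2.0, -1.0]]))
print(label, round(confidence, 3))
```

A threshold on the returned probability can then route low-confidence images to manual review instead of accepting the top label outright.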

## Training Details

### Training Data

The model was trained using:

- **ChartQA dataset** (for diagram samples)
- **JasmineQiuqiu/diagrams_with_captions_2** (for diagram samples)
- **COCO dataset (subset)** (for non-diagram samples)

### Training Procedure

- **Pretrained model:** `microsoft/resnet-18`
- **Optimization:** Adam optimizer
- **Loss function:** Cross-entropy loss
- **Training duration:** Approx. X hours on an NVIDIA GPU

## Evaluation

### Testing Data & Metrics

- **Dataset:** Held-out test set from ChartQA, AI2D-RST, and COCO
- **Metrics:**
  - **Test Loss:** 0.0371
  - **Test Accuracy:** 99.08%
  - **Precision:** 0.9995
  - **Recall:** 0.9820
  - **F1 Score:** 0.9907
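
As a quick consistency check, the F1 score is the harmonic mean of precision $P$ and recall $R$, and the reported values agree with each other:

$$
F_1 = \frac{2PR}{P + R} = \frac{2 \times 0.9995 \times 0.9820}{0.9995 + 0.9820} \approx 0.9907
$$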

## Environmental Impact

- **Hardware Used:** NVIDIA A100 GPU
- **Compute Hours:** Approx. X hours
- **Estimated Carbon Emission:** Not reported; it can be estimated with the [MLCO2 Calculator](https://mlco2.github.io/impact#compute)

## Citation

If you use this model, please cite:

```bibtex
@misc{aya2025diaclass,
  author = {Aya Mohamed},
  title = {Diagram Classification Model},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/Ayamohamed/diaclass-model}
}
```