---
license: mit
tags:
- image-classification
- pneumonia-detection
- healthcare
- medical-imaging
- pytorch
- resnet18
library_name: pytorch
model_name: resnet18_pneumonia_classifier
---

# ResNet18 Pneumonia Detection Model

This model is a fine-tuned version of the ResNet18 architecture for pneumonia detection. It was trained on the [Kaggle Chest X-ray Pneumonia dataset](https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia), which contains chest X-rays of normal lungs and of lungs with pneumonia. The model classifies a chest X-ray as either **Pneumonia** or **Normal**.

## Model Details

- **Model Architecture**: ResNet18
- **Input Size**: 224 × 224 pixels, 3 channels (grayscale X-rays are replicated across the channels during preprocessing)
- **Number of Classes**: 2 (Pneumonia, Normal)
- **Framework**: PyTorch
- **Training Dataset**: [Kaggle Chest X-ray Pneumonia Dataset](https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia)

## Model Performance

- **Accuracy**: 83.3%
- **Loss (cross-entropy)**: 0.2459

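The evaluation split and script behind these figures are not specified. As an illustration only, the sketch below shows how accuracy and mean cross-entropy loss are commonly computed for a two-class PyTorch classifier. `evaluate` is a hypothetical helper, and `loader` is assumed to yield `(images, labels)` batches preprocessed the same way as in the usage example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu") -> tuple[float, float]:
    """Hypothetical helper: return (accuracy, mean cross-entropy loss) over `loader`."""
    model.to(device).eval()
    # Sum per-sample losses so the mean is exact even when the last batch is smaller
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, count = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)  # shape (batch, 2)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        count += labels.size(0)
    return correct / count, total_loss / count
```

Given a `DataLoader` over held-out images (for example a hypothetical `test_loader`), `accuracy, loss = evaluate(model, test_loader)` returns both figures in the same form as the list above.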
## Intended Use

This model is designed to assist healthcare professionals in identifying pneumonia from chest X-ray images. It is not a standalone diagnostic tool and should be used only as a supplement to medical expertise.

## Training Details

The model was trained with the following setup:

- **Architecture**: ResNet18 (pre-trained on ImageNet)
- **Optimizer**: SGD (stochastic gradient descent)
- **Learning Rate**: 0.001
- **Momentum**: 0.9
- **Loss Function**: CrossEntropyLoss
- **Batch Size**: 32
- **Data Augmentation**: random rotation, zoom, horizontal and vertical shifts, and horizontal flips (see [Augmentation Details](#augmentation-details))
- **Training Epochs**: 1
- **Evaluation Metric**: Cross-entropy loss

### Augmentation Details

The training images were augmented during training with the following random transformations:

- Rotation within ±30 degrees
- Zoom by up to 20%
- Horizontal shift by up to ±10% of the image width
- Vertical shift by up to ±10% of the image height
- Horizontal flip

## How to Use the Model

You can run the model with `torch` and `torchvision`, using `huggingface_hub` to download the weights and `Pillow` to load images.

```python
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="izeeek/resnet18_pneumonia_classifier",
    filename="resnet18_pneumonia_classifier.pth",
)

# Build the ResNet18 architecture without ImageNet weights;
# the fine-tuned weights are loaded below
model = models.resnet18(weights=None)

# Replace the 1000-class ImageNet head with a 2-class head
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Load the downloaded weights (map_location="cpu" also works on machines without a GPU)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Set the model to evaluation mode (BatchNorm layers use their running statistics)
model.eval()

# Image preprocessing: replicate the grayscale X-ray to 3 channels,
# resize to 224 x 224, and scale pixel values to [-1, 1]
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load a chest X-ray (replace with the path to your own image)
img = Image.open("path/to/chest_xray.jpeg").convert("RGB")

# Preprocess the image and add a batch dimension: shape (1, 3, 224, 224)
input_img = transform(img).unsqueeze(0)

# Inference
with torch.no_grad():
    output = model(input_img)
    _, predicted = torch.max(output, 1)

# Labels for classification
labels = {0: 'Pneumonia', 1: 'Normal'}
print(f'Predicted label: {labels[predicted.item()]}')
```
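The example above reports only the top class. To also inspect the model's confidence, the logits can be converted to probabilities with a softmax. `predict_proba` below is a hypothetical helper, not part of any library, and expects a batch preprocessed with `transform`.

```python
import torch
from torch import nn


@torch.no_grad()
def predict_proba(model: nn.Module, batch: torch.Tensor, labels: dict[int, str]) -> list[dict[str, float]]:
    """Map each image in a preprocessed batch to {label: probability}."""
    model.eval()
    probs = torch.softmax(model(batch), dim=1)  # each row sums to 1
    return [{labels[i]: row[i].item() for i in range(row.numel())} for row in probs]
```

With the variables from the example above, `predict_proba(model, input_img, labels)` returns a one-element list holding the probability of each class for that image.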