vit-xray-pneumonia-classification
This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the chest-xray-classification dataset. It achieves the following results on the evaluation set:
- Loss: 0.0868
- Accuracy: 0.9742
Inference example
from transformers import pipeline
classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
# image taken from https://www.news-medical.net/health/What-is-Viral-Pneumonia.aspx
classifier("https://d2jx2rerrg6sh3.cloudfront.net/image-handler/ts/20200618040600/ri/650/picture/2020/6/shutterstock_786937069.jpg")
# Output:
# [{'score': 0.990334689617157, 'label': 'PNEUMONIA'},
#  {'score': 0.009665317833423615, 'label': 'NORMAL'}]
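The pipeline returns a list of dicts with `score` and `label` keys, as in the output shown for the example image. As a minimal sketch, a small helper could pick the top label and abstain when the model is not confident enough. The helper name `top_prediction` and the default `threshold` of 0.5 are assumptions made for illustration; neither is part of `transformers` or of this model.

```python
def top_prediction(predictions, threshold=0.5):
    """Return the highest-scoring prediction, or None if it is below threshold.

    `predictions` is a list of {'score': float, 'label': str} dicts,
    the shape returned by the image-classification pipeline.
    """
    best = max(predictions, key=lambda p: p["score"])
    if best["score"] < threshold:
        return None
    return best


# Scores copied from the inference example in this card.
sample = [
    {"score": 0.990334689617157, "label": "PNEUMONIA"},
    {"score": 0.009665317833423615, "label": "NORMAL"},
]

result = top_prediction(sample)
print(result["label"])
```

In practice the input would be the return value of `classifier(...)` rather than a hard-coded list.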
Training procedure
Notebook link: here
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 64
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 15
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

# model, processor, data_collator, train_ds, val_ds and compute_metrics are
# assumed to be defined earlier (see the training notebook).
training_args = TrainingArguments(
    output_dir="vit-xray-pneumonia-classification",
    remove_unused_columns=False,  # do not drop dataset columns unused by model.forward()
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # 16 x 4 = 64 effective train batch size
    per_device_eval_batch_size=16,
    num_train_epochs=15,
    save_total_limit=2,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
    fp16=True,  # mixed-precision training
    push_to_hub=True,
    report_to="tensorboard",
)

# Stop training after 3 consecutive evaluations without eval_loss improvement.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
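The listed `total_train_batch_size` of 64 and the linear schedule with a warmup ratio of 0.1 can be illustrated with a short sketch. This is an approximation for intuition, not the `transformers` implementation: the function names are hypothetical, a single device is assumed, and `total_steps=1000` is a made-up value because the planned number of optimizer steps is not stated in this card.

```python
def effective_batch_size(per_device_batch_size, accumulation_steps, num_devices=1):
    """Samples contributing to one optimizer step (num_devices=1 is an assumption)."""
    return per_device_batch_size * accumulation_steps * num_devices


def linear_lr_with_warmup(step, total_steps, base_lr=5e-5, warmup_ratio=0.1):
    """Linear warmup from 0 to base_lr, then linear decay back to 0."""
    warmup_steps = round(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


print(effective_batch_size(16, 4))

# Hypothetical run length, chosen only to make the shape of the schedule visible.
for step in (0, 50, 100, 550, 1000):
    print(step, linear_lr_with_warmup(step, total_steps=1000))
```

The learning rate peaks at 5e-05 once warmup ends and then decreases linearly for the rest of the run.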
Training results
| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|---|---|---|---|---|
| 0.5152 | 0.99 | 63 | 0.2507 | 0.9245 |
| 0.2334 | 1.99 | 127 | 0.1766 | 0.9382 |
| 0.1647 | 3.0 | 191 | 0.1218 | 0.9588 |
| 0.144 | 4.0 | 255 | 0.1222 | 0.9502 |
| 0.1348 | 4.99 | 318 | 0.1293 | 0.9571 |
| 0.1276 | 5.99 | 382 | 0.1000 | 0.9665 |
| 0.1175 | 7.0 | 446 | 0.1177 | 0.9502 |
| 0.109 | 8.0 | 510 | 0.1079 | 0.9665 |
| 0.0914 | 8.99 | 573 | 0.0804 | 0.9717 |
| 0.0872 | 9.99 | 637 | 0.0800 | 0.9717 |
| 0.0804 | 11.0 | 701 | 0.0862 | 0.9682 |
| 0.0935 | 12.0 | 765 | 0.0883 | 0.9657 |
| 0.0686 | 12.99 | 828 | 0.0868 | 0.9742 |
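Training was configured for 15 epochs, but the table ends after 13 evaluations. A simplified simulation of patience-based early stopping on the validation losses from the table shows why. This is a sketch of the general technique under the assumption that any strict decrease in validation loss counts as an improvement; it is not the code of `EarlyStoppingCallback`, and the function name is hypothetical.

```python
def simulate_early_stopping(val_losses, patience=3):
    """Return (stop_eval, best_eval) as 1-based evaluation indices.

    stop_eval is None if the patience budget is never exhausted.
    """
    best_loss = float("inf")
    best_eval = None
    evals_without_improvement = 0
    for eval_index, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            best_eval = eval_index
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return eval_index, best_eval
    return None, best_eval


# Validation losses copied from the training results table, one per epoch.
val_losses = [
    0.2507, 0.1766, 0.1218, 0.1222, 0.1293, 0.1000, 0.1177,
    0.1079, 0.0804, 0.0800, 0.0862, 0.0883, 0.0868,
]

stop_eval, best_eval = simulate_early_stopping(val_losses, patience=3)
print(stop_eval, best_eval)  # 13 10
```

Under these assumptions the run stops at the 13th evaluation, three evaluations after the lowest validation loss (0.0800) at the 10th.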
Framework versions
- Transformers 4.30.2
- Pytorch 1.9.0+cu102
- Datasets 2.12.0
- Tokenizers 0.13.3