|
import os |
|
import numpy as np |
|
import cv2 |
|
from sklearn.model_selection import train_test_split |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
from transformers import Trainer, TrainingArguments |
|
from datasets import load_dataset, Dataset |
|
|
|
|
|
def load_images_from_folders(folders, label, image_size=(224, 224)):
    """Load every PNG/JPEG found directly inside ``folders`` as grayscale.

    Args:
        folders: Iterable of directory paths; each is listed non-recursively.
        label: Value appended to the label list once per successfully
            loaded image.
        image_size: ``(width, height)`` handed to ``cv2.resize``. Defaults
            to ``(224, 224)``.

    Returns:
        An ``(images, labels)`` tuple of two equal-length lists. Each image
        is a 2-D ``float32`` array with values scaled into ``[0, 1]``.

    Raises:
        OSError: Propagated from ``os.listdir`` when a folder cannot be read.
    """
    images = []
    labels = []

    for folder in folders:
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # resulting lists (and any seeded shuffle applied to them later)
        # reproducible from run to run.
        for filename in sorted(os.listdir(folder)):
            if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            img = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print(f"Failed to load image: {filename}")
                continue

            img = cv2.resize(img, image_size)
            # Scale 8-bit intensities into [0, 1] as float32.
            images.append(img.astype(np.float32) / 255.0)
            labels.append(label)

    return images, labels
|
|
|
|
|
# Directory layout: <root>/<split>/<class>. All three splits are listed for
# each class, in the order test, train, val.
_DATASET_ROOT = os.path.join('chest-xray-pneumonia', 'chest_xray')
_SPLIT_NAMES = ('test', 'train', 'val')

normal_folders = [
    os.path.join(_DATASET_ROOT, split_name, 'NORMAL')
    for split_name in _SPLIT_NAMES
]
pneumonia_folders = [
    os.path.join(_DATASET_ROOT, split_name, 'PNEUMONIA')
    for split_name in _SPLIT_NAMES
]
|
|
|
# Label convention used throughout the script: 0 = NORMAL, 1 = PNEUMONIA.
(normal_images, normal_labels) = load_images_from_folders(normal_folders, 0)
(pneumonia_images, pneumonia_labels) = load_images_from_folders(pneumonia_folders, 1)

# Pool both classes into parallel lists (normal samples first), so that
# images[i] always corresponds to labels[i].
images = [*normal_images, *pneumonia_images]
labels = [*normal_labels, *pneumonia_labels]
|
|
|
|
|
# Hold out 20% of the pooled samples for evaluation, with a fixed seed.
# The split is stratified on the labels so the train and test partitions
# keep the same class ratio as the pooled data instead of whatever ratio the
# shuffle happens to produce.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

# Wrap each partition in a Dataset with an "image" and a "label" column.
train_dataset = Dataset.from_dict({"image": X_train, "label": y_train})
test_dataset = Dataset.from_dict({"image": X_test, "label": y_test})
|
|
|
|
|
# One pretrained checkpoint supplies both the preprocessing configuration
# and the ViT backbone weights.
_CHECKPOINT = 'google/vit-base-patch16-224-in21k'

# Names for the two integer labels (0 is assigned to the NORMAL folders and
# 1 to the PNEUMONIA folders). Storing them in the model config means a
# saved checkpoint reports class names rather than generic LABEL_0/LABEL_1.
_ID2LABEL = {0: 'NORMAL', 1: 'PNEUMONIA'}

feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)

# num_labels=2 sizes the classification head for the two classes.
model = ViTForImageClassification.from_pretrained(
    _CHECKPOINT,
    num_labels=2,
    id2label=_ID2LABEL,
    label2id={name: index for index, name in _ID2LABEL.items()},
)
|
|
|
|
|
def _to_rgb_uint8(image):
    """Convert one stored image with values in [0, 1] to an HxWx3 uint8 array."""
    pixels = np.asarray(image, dtype=np.float32)
    # Undo the [0, 1] scaling so the extractor receives ordinary 8-bit pixels.
    pixels = np.clip(np.rint(pixels * 255.0), 0, 255).astype(np.uint8)
    if pixels.ndim == 2:
        # Grayscale input: replicate the single channel three times.
        pixels = np.repeat(pixels[:, :, np.newaxis], 3, axis=2)
    return pixels


def preprocess_function(examples):
    """Turn a batch of dataset rows into model-ready ``pixel_values``.

    The "image" column holds 2-D grayscale images whose values were already
    scaled into ``[0, 1]`` at load time, and a ``Dataset`` hands them back
    as plain nested lists. The feature extractor expects image arrays with
    colour channels and performs its own rescaling/normalisation, so each
    image is first rebuilt as a 3-channel ``uint8`` array; passing the
    pre-scaled grayscale data straight through would either be rejected or
    be scaled twice.

    Args:
        examples: Batch mapping supplied by ``Dataset.map(..., batched=True)``;
            only ``examples['image']`` is read.

    Returns:
        The feature extractor's output for the batch (PyTorch tensors).
    """
    rgb_images = [_to_rgb_uint8(image) for image in examples['image']]
    return feature_extractor(images=rgb_images, return_tensors="pt")
|
|
|
# Apply the preprocessing to both partitions, batch by batch; the train
# partition is processed first, then the test partition.
train_dataset, test_dataset = (
    partition.map(preprocess_function, batched=True)
    for partition in (train_dataset, test_dataset)
)
|
|
|
|
|
# Optimiser settings.
_optimizer_settings = {
    'learning_rate': 2e-5,
    'weight_decay': 0.01,
}

# Loop settings: 10 epochs, batches of 8 per device for both training and
# evaluation, with an evaluation pass scheduled per epoch.
_loop_settings = {
    'num_train_epochs': 10,
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'evaluation_strategy': 'epoch',
}

# Checkpoints and other trainer outputs are written under ./results.
training_args = TrainingArguments(
    output_dir='./results',
    **_loop_settings,
    **_optimizer_settings,
)
|
|
|
|
|
def _compute_accuracy(eval_pred):
    """Return top-1 accuracy for one evaluation pass.

    Args:
        eval_pred: Object passed by ``Trainer`` exposing ``predictions``
            (the model outputs) and ``label_ids`` (the reference labels).

    Returns:
        ``{"accuracy": fraction_correct}`` as a plain ``float``.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):
        # When the model returns several outputs the logits come first.
        logits = logits[0]
    predicted = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predicted == eval_pred.label_ids))}


# Without compute_metrics the per-epoch evaluation reports only the loss;
# supplying an accuracy metric makes each evaluation pass directly
# interpretable for this two-class problem.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=_compute_accuracy,
)
|
|
|
|
|
# Fine-tune the model with the configured trainer.
trainer.train()

# Save the fine-tuned weights together with the preprocessing configuration.
# A directory holding only the model cannot reproduce the input pipeline at
# inference time, so the feature extractor is written alongside it.
_FINAL_MODEL_DIR = './pneumonia_model_final'
model.save_pretrained(_FINAL_MODEL_DIR)
feature_extractor.save_pretrained(_FINAL_MODEL_DIR)

print("Model saved as './pneumonia_model_final'")
|
|