# Scraped page header (Hugging Face Spaces, status: Sleeping); not part of the program.
# Train a binary image classifier (network built by model.create_model) on
# images stored under data/chest_xray/train and data/chest_xray/val.
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from model import create_model

base_dir = 'data/chest_xray'
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')

# Loader settings shared by the training and validation pipelines. Both must
# resize, batch and scale images identically, so each value is defined once
# here instead of being repeated in every call below.
IMAGE_SIZE = (150, 150)  # (height, width) every image is resized to
BATCH_SIZE = 32
RESCALE = 1. / 255  # scales 0-255 pixel values down to [0, 1]

# Training images get random geometric augmentation each time they are drawn.
# Validation images are only rescaled, so validation metrics are measured on
# unmodified data.
train_datagen = ImageDataGenerator(
    rescale=RESCALE,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
val_datagen = ImageDataGenerator(rescale=RESCALE)

# flow_from_directory expects one sub-folder per class inside each directory
# and derives the labels from those folder names. class_mode='binary' yields
# a single 0/1 label per image.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
# Show a few augmented training images as a quick visual check of the input
# pipeline. next() pulls one (images, labels) batch from the generator; the
# images were already rescaled by 1/255 when the generator was configured.
num_preview = 5
sample_images, sample_labels = next(train_generator)
# A batch can hold fewer than num_preview images when the training set is
# that small; indexing past the batch would raise IndexError.
num_preview = min(num_preview, len(sample_images))
# class_indices maps class-folder name -> label index. Invert it so each
# preview can be titled with the class it was labelled as.
index_to_class = {index: name for name, index in train_generator.class_indices.items()}
for i in range(num_preview):
    plt.subplot(1, num_preview, i + 1)
    plt.imshow(sample_images[i])
    plt.title(index_to_class[int(sample_labels[i])])
    plt.axis('off')
plt.show()
# Build the network, train it on the generators, and save it in the native
# Keras (.keras) format.
model = create_model()
# Step counts come from the generators rather than being hard-coded. len()
# of a flow_from_directory iterator is the number of batches in one full pass
# over its directory. Fixed counts go stale whenever the dataset or batch
# size changes:
#   - too many steps can exhaust the data and cut an epoch short with a
#     "ran out of data" warning;
#   - too few silently skip images, and validation is then scored on only
#     part of the set.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=val_generator,
    validation_steps=len(val_generator),
    # Stop after 3 epochs without val_loss improvement, and restore the
    # weights from the best epoch seen.
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True,
        ),
    ],
)
model.save('xray_image_classifier_model.keras')