# Source page metadata (retained from the hosting site's file view):
# ScanSmartAI / augment_images.py
# devfire's picture
# Upload 6 files
# f054618 verified
# raw
# history blame
# 2.16 kB
import os
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Dataset locations. Both class folders live under <base_dir>/val/.
base_dir = 'data/chest_xray'
val_dir = os.path.join(base_dir, 'val')
normal_class_dir, pneumonia_class_dir = (
    os.path.join(val_dir, label) for label in ('NORMAL', 'PNEUMONIA')
)
def augment_images(class_directory: str, num_augmented_images: int) -> None:
    """Generate randomly augmented copies of one class's images on disk.

    Images are read from ``class_directory`` through a Keras
    ``ImageDataGenerator`` (random rotation, shifts, shear, zoom and
    horizontal flip), resized to 150x150, and each augmented result is
    saved as ``augmented_<i>.png`` inside ``class_directory`` itself,
    with ``i`` counting up from 0.

    NOTE(review): because outputs land in the same folder the generator
    reads from, a later run will presumably pick up earlier
    ``augmented_*.png`` files as inputs and overwrite files with the same
    index -- confirm whether that is intended.

    Args:
        class_directory: Path to the folder holding one class's images.
            A trailing path separator is tolerated.
        num_augmented_images: How many augmented images to write. Values
            of zero or less are treated as "nothing to do".

    Returns:
        None. Progress is reported via ``print``.
    """
    # Nothing requested: bail out before building the generator, which
    # would otherwise scan the directory for no benefit.
    if num_augmented_images <= 0:
        print(f"No augmented images requested for {class_directory}; skipping.")
        return

    # Normalise first so a trailing separator does not make basename()
    # return '' (which would select no class at all), and fall back to the
    # current directory when the path has no parent component.
    class_path = os.path.normpath(class_directory)
    parent_dir = os.path.dirname(class_path) or os.curdir
    class_name = os.path.basename(class_path)

    datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    # flow_from_directory expects a parent folder with one sub-folder per
    # class, so point it at the parent and restrict it to this one class.
    generator = datagen.flow_from_directory(
        directory=parent_dir,
        target_size=(150, 150),
        batch_size=1,
        class_mode=None,
        shuffle=False,
        classes=[class_name],
    )

    print(f"Found {generator.samples} images in {class_directory}")
    if generator.samples == 0:
        print("No images found in the directory.")
        return

    count = 0
    while count < num_augmented_images:
        # Keep the try body to the single call that can raise.
        # NOTE(review): Keras directory iterators appear to cycle
        # indefinitely, so this handler is likely only a safeguard.
        try:
            img_batch = next(generator)
        except StopIteration:
            print("No more images to generate.")
            break

        # The generator rescaled pixels by 1/255; undo that for saving.
        # Round and clip before casting: a bare astype('uint8') truncates,
        # so float error (e.g. 254.99998) would darken pixels by one level.
        img = (img_batch[0] * 255).round().clip(0, 255).astype('uint8')
        img_path = os.path.join(class_directory, f"augmented_{count}.png")
        Image.fromarray(img).save(img_path)
        count += 1

    print(f"Total augmented images created: {count}")
# Per-class augmentation budgets, each written as (2944 - <count>).
# NOTE(review): presumably 2944 is the target class size and 3875 / 1171 are
# the current image counts of NORMAL / PNEUMONIA -- confirm against the data.
num_augmented_images_normal = 2944 - 3875  # evaluates to -931: nothing to add
num_augmented_images_pneumonia = 2944 - 1171  # evaluates to 1773

# NORMAL: the budget is negative, so clamp it to zero before the call.
augment_images(
    class_directory=normal_class_dir,
    num_augmented_images=max(0, num_augmented_images_normal),
)

# PNEUMONIA: generate the full budget computed above.
augment_images(
    class_directory=pneumonia_class_dir,
    num_augmented_images=num_augmented_images_pneumonia,
)