Spaces:
Sleeping
Sleeping
import os | |
from PIL import Image | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
# Dataset layout: data/chest_xray/val/<CLASS>. The two per-class folders
# below are the directories the augmentation step is run on.
base_dir = 'data/chest_xray'
val_dir = os.path.join(base_dir, 'val')
normal_class_dir, pneumonia_class_dir = (
    os.path.join(val_dir, class_name) for class_name in ('NORMAL', 'PNEUMONIA')
)
def augment_images(class_directory, num_augmented_images, *, target_size=(150, 150)):
    """Write randomly augmented copies of a class folder's images into that folder.

    Source images are read from ``class_directory`` with Keras
    ``flow_from_directory`` (resized to ``target_size``). Each one is randomly
    rotated, shifted, sheared, zoomed and horizontally flipped, then saved as
    ``augmented_<i>.png`` inside ``class_directory``. Source images are cycled
    in a fixed order (``shuffle=False``) until enough outputs exist.

    Args:
        class_directory: Folder holding the source images. The generator is
            pointed at its parent directory and restricted to this folder's
            name, so a trailing path separator is tolerated.
        num_augmented_images: Number of augmented images to write. Values
            <= 0 write nothing.
        target_size: (height, width) the images are resized to before
            augmentation. Defaults to the previously hard-coded (150, 150).

    Returns:
        None. Progress is reported with ``print``.
    """
    # Without normalisation, "…/NORMAL/" gives basename() == '' and
    # dirname() == the class folder itself, so the generator finds nothing.
    class_directory = os.path.normpath(class_directory)
    if num_augmented_images <= 0:
        print(f"Nothing to generate for {class_directory}")
        return

    # No ``rescale`` here. Scaling to [0, 1] and multiplying back by 255 is a
    # float32 round trip that, combined with a truncating uint8 cast, can drop
    # pixel values by one. Keeping the data in [0, 255] avoids it entirely.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    generator = datagen.flow_from_directory(
        directory=os.path.dirname(class_directory),  # parent of the class folder
        target_size=target_size,
        batch_size=1,
        class_mode=None,  # yield image batches only, no labels
        shuffle=False,
        classes=[os.path.basename(class_directory)],  # restrict to this folder
    )
    print(f"Found {generator.samples} images in {class_directory}")
    if generator.samples == 0:
        print("No images found in the directory.")
        return

    # NOTE(review): output is written into the same folder the generator reads
    # from. Keras collects the file list when the generator is created, so this
    # run does not re-augment its own output. However, a second run will treat
    # earlier augmented_*.png files as sources and overwrite them by name.
    count = 0
    while count < num_augmented_images:
        try:
            img_batch = next(generator)
        except StopIteration:
            print("No more images to generate.")
            break
        # Round (not truncate) to the nearest integer level, and clip
        # defensively before casting to uint8.
        img = img_batch[0].round().clip(0, 255).astype('uint8')
        img_path = os.path.join(class_directory, f"augmented_{count}.png")
        Image.fromarray(img).save(img_path)
        count += 1
    print(f"Total augmented images created: {count}")
# Class balancing: each class is topped up to TARGET_IMAGES_PER_CLASS with
# augmented images. A class already at or above the target gets none.
# NOTE(review): the per-class counts are hard-coded rather than counted from
# disk, and the folders being augmented are under ``val``. Confirm that these
# numbers belong to the class/split actually being augmented.
TARGET_IMAGES_PER_CLASS = 2944
NORMAL_IMAGE_COUNT = 3875
PNEUMONIA_IMAGE_COUNT = 1171

# Shortfall per class. A negative value means the class already exceeds the
# target; both values are clamped to zero before use below.
num_augmented_images_normal = TARGET_IMAGES_PER_CLASS - NORMAL_IMAGE_COUNT
num_augmented_images_pneumonia = TARGET_IMAGES_PER_CLASS - PNEUMONIA_IMAGE_COUNT

# Generate augmented images for the NORMAL class
augment_images(normal_class_dir, max(num_augmented_images_normal, 0))
# Generate augmented images for the PNEUMONIA class
augment_images(pneumonia_class_dir, max(num_augmented_images_pneumonia, 0))