"""
train_model.py

This module trains a simple image classification model using TensorFlow/Keras.
"""
import os
import warnings

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import numpy as np
from scipy import misc  # NOTE(review): unused in this module and scipy.misc is deprecated; consider removing.

# Constants
IMG_SIZE = (224, 224)  # target_size every image is resized to by flow_from_directory
BATCH_SIZE = 32
EPOCHS = 10
# Expected number of classes. The model itself is sized from the class
# sub-directories actually found on disk; a mismatch only triggers a warning.
NUM_CLASSES = 5


def preprocess_image(img):
    """
    Preprocess the input image for model prediction.

    The image is NOT resized here; callers must supply an image that already
    has the IMG_SIZE spatial dimensions the model was trained on.

    Parameters:
    - img: Input image (anything ``image.img_to_array`` accepts, e.g. a PIL image).

    Returns:
    - img: Float array with a leading batch axis of size 1, divided by 255.
    """
    img = image.img_to_array(img)
    # The model predicts on batches, so add a batch axis of size 1.
    img = np.expand_dims(img, axis=0)
    # Must match ImageDataGenerator(rescale=1./255) used in train_model.
    img /= 255.0
    return img


def _build_model(num_classes):
    """Build and compile the small CNN classifier with ``num_classes`` softmax outputs."""
    model = Sequential()
    # Input shape is derived from IMG_SIZE so the two cannot drift apart;
    # 3 channels because flow_from_directory defaults to color_mode='rgb'.
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(*IMG_SIZE, 3)))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))

    model.compile(optimizer=Adam(learning_rate=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def train_model(dataset_path):
    """
    Train the image classification model.

    The dataset directory is read with ``flow_from_directory``, i.e. it is
    expected to contain one sub-directory of images per class.
    ``validation_split=0.2`` reserves part of the images for validation.

    Parameters:
    - dataset_path (str): Path to the dataset root directory.

    Returns:
    - model: Trained Keras model.

    Raises:
    - ValueError: If no training images are found under ``dataset_path``.
    """
    # Ensure the dataset path is absolute.
    dataset_path = os.path.abspath(dataset_path)

    # Data preprocessing. rescale must stay in sync with preprocess_image.
    datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2
    )

    train_generator = datagen.flow_from_directory(
        dataset_path,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='training'
    )

    validation_generator = datagen.flow_from_directory(
        dataset_path,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='validation'
    )

    if train_generator.samples == 0:
        raise ValueError(
            f"No training images found under {dataset_path!r}; "
            "expected one sub-directory of images per class."
        )

    # Size the output layer from the classes actually present. A hard-coded
    # count that disagrees with the data makes categorical_crossentropy fail
    # on mismatched label/prediction shapes.
    num_classes = train_generator.num_classes
    if num_classes != NUM_CLASSES:
        warnings.warn(
            f"Found {num_classes} classes in {dataset_path!r}, "
            f"but NUM_CLASSES is {NUM_CLASSES}; using {num_classes}.",
            stacklevel=2,
        )

    model = _build_model(num_classes)

    # With small class folders the validation subset can be empty; skip
    # validation rather than hand fit() an empty iterator.
    validation_data = validation_generator if validation_generator.samples > 0 else None

    model.fit(
        train_generator,
        epochs=EPOCHS,
        validation_data=validation_data
    )

    return model


if __name__ == "__main__":
    script_directory = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(script_directory, '../../data')
    trained_model = train_model(dataset_path)

    model_path = os.path.join(script_directory, '../../models/football_logo_model.h5')
    # Create the output directory if needed so saving does not depend on it pre-existing.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    trained_model.save(model_path)