# Alif Al Hasan
# [Task] Model Training
# e58ca2a
"""train_model.py
This module trains a simple image classification model using TensorFlow/Keras.
"""
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import numpy as np
from scipy import misc
# ---------------------------------------------------------------------------
# Training configuration shared by the functions in this module
# ---------------------------------------------------------------------------
NUM_CLASSES = 5  # number of image classes the dataset is expected to contain
EPOCHS = 10  # passes over the training split in model.fit
BATCH_SIZE = 32  # images per batch drawn from each directory iterator
IMG_SIZE = 224, 224  # target_size every image is resized to when loaded
def preprocess_image(img):
    """
    Preprocess an input image for model prediction.

    The image is converted to a float array, given a leading batch axis of
    size 1, and scaled by 1/255 (the same ``rescale`` the training data
    generator applies). No resizing is performed, so the caller must supply
    an image that already has the model's input size (IMG_SIZE).

    Parameters:
    - img: Input image (anything ``keras`` ``img_to_array`` accepts, e.g. a
      PIL image or an ndarray). It is never modified.

    Returns:
    - A new array holding a batch of one image, with values scaled by 1/255.
    """
    array = image.img_to_array(img)
    batch = np.expand_dims(array, axis=0)
    # Divide out of place: img_to_array uses np.asarray, which returns the
    # caller's own array when it is already float32. expand_dims is a view of
    # it, so an in-place ``/=`` would silently rescale the caller's data too.
    return batch / 255.0
def train_model(dataset_path):
    """
    Train the image classification model.

    Images are read with ``ImageDataGenerator.flow_from_directory`` (one
    sub-directory per class), rescaled by 1/255 and resized to IMG_SIZE.
    20% of the images are held out as the validation split.

    Parameters:
    - dataset_path (str): Path to the dataset directory. A relative path is
      resolved against the current working directory.

    Returns:
    - model: Trained Keras model.

    Raises:
    - ValueError: If no training images are found under ``dataset_path``.
    """
    dataset_path = os.path.abspath(dataset_path)
    train_generator, validation_generator = _make_generators(dataset_path)
    if train_generator.samples == 0:
        raise ValueError(f"No training images found under {dataset_path!r}")

    # Size the softmax head from the classes actually found on disk. A fixed
    # constant that disagrees with the dataset makes categorical_crossentropy
    # fail with a label/prediction shape mismatch at the first training step.
    model = _build_model(len(train_generator.class_indices))
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    # Skip validation when the 20% split yields no images rather than handing
    # model.fit an empty validation iterator.
    model.fit(
        train_generator,
        epochs=EPOCHS,
        validation_data=validation_generator if validation_generator.samples else None,
    )
    return model


def _make_generators(dataset_path):
    """Return (training, validation) directory iterators over an 80/20 split of dataset_path."""
    datagen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=0.2,
    )

    def flow(subset):
        return datagen.flow_from_directory(
            dataset_path,
            target_size=IMG_SIZE,
            batch_size=BATCH_SIZE,
            class_mode='categorical',
            subset=subset,
        )

    # Evaluated left to right: the training iterator is created first.
    return flow('training'), flow('validation')


def _build_model(num_classes):
    """Return an uncompiled CNN mapping IMG_SIZE RGB images to num_classes softmax outputs."""
    return Sequential([
        # Three channels: flow_from_directory loads images in its default 'rgb' color_mode.
        Conv2D(32, (3, 3), activation='relu', input_shape=(*IMG_SIZE, 3)),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])
if __name__ == "__main__":
    # Paths are resolved relative to this script, not the working directory.
    script_directory = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.normpath(os.path.join(script_directory, '..', '..', 'data'))
    models_directory = os.path.normpath(os.path.join(script_directory, '..', '..', 'models'))
    # Create the output folder before training. Otherwise a missing directory
    # only makes model.save() fail after every epoch has run, losing the model.
    os.makedirs(models_directory, exist_ok=True)
    trained_model = train_model(dataset_path)
    trained_model.save(os.path.join(models_directory, 'football_logo_model.h5'))