|
"""train_model.py |
|
|
|
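This module trains a simple convolutional image classifier with TensorFlow/Keras on a directory tree containing one sub-directory of images per class. When run as a script, it trains on ``../../data`` and saves the model to ``../../models/football_logo_model.h5`` (both paths relative to this file).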
|
|
|
""" |
|
|
|
import os |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense |
|
from tensorflow.keras.optimizers import Adam |
|
from tensorflow.keras.preprocessing import image |
|
import numpy as np |
|
from scipy import misc |
|
|
|
|
|
# (height, width) that every image is resized to before reaching the network.
IMG_SIZE = (224, 224)

# Training hyper-parameters.
BATCH_SIZE = 32  # images per batch, used for the training and validation iterators
EPOCHS = 10  # number of full passes over the training subset

# Number of class sub-directories expected under the dataset root; it also
# sets the width of the model's softmax output layer.
NUM_CLASSES = 5
|
|
|
def preprocess_image(img): |
|
""" |
|
Preprocess the input image for model prediction. |
|
|
|
Parameters: |
|
- img: Input image. |
|
|
|
Returns: |
|
- img: Preprocessed image. |
|
""" |
|
img = image.img_to_array(img) |
|
img = np.expand_dims(img, axis=0) |
|
img /= 255.0 |
|
return img |
|
|
|
def train_model(dataset_path): |
|
""" |
|
Train the image classification model. |
|
|
|
Parameters: |
|
- dataset_path (str): Path to the dataset. |
|
|
|
Returns: |
|
- model: Trained Keras model. |
|
""" |
|
|
|
|
|
dataset_path = os.path.abspath(dataset_path) |
|
|
|
|
|
datagen = ImageDataGenerator( |
|
rescale=1./255, |
|
validation_split=0.2 |
|
) |
|
|
|
train_generator = datagen.flow_from_directory( |
|
dataset_path, |
|
target_size=IMG_SIZE, |
|
batch_size=BATCH_SIZE, |
|
class_mode='categorical', |
|
subset='training' |
|
) |
|
|
|
validation_generator = datagen.flow_from_directory( |
|
dataset_path, |
|
target_size=IMG_SIZE, |
|
batch_size=BATCH_SIZE, |
|
class_mode='categorical', |
|
subset='validation' |
|
) |
|
|
|
|
|
model = Sequential() |
|
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3))) |
|
model.add(MaxPooling2D((2, 2))) |
|
model.add(Flatten()) |
|
model.add(Dense(64, activation='relu')) |
|
model.add(Dense(NUM_CLASSES, activation='softmax')) |
|
|
|
|
|
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy']) |
|
|
|
|
|
history = model.fit( |
|
train_generator, |
|
epochs=EPOCHS, |
|
validation_data=validation_generator |
|
) |
|
|
|
return model |
|
|
|
if __name__ == "__main__":
    # Resolve data/model locations relative to this file, not the current working directory.
    script_directory = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.normpath(os.path.join(script_directory, '..', '..'))
    dataset_path = os.path.join(project_root, 'data')
    models_directory = os.path.join(project_root, 'models')

    # Saving into a missing directory fails. Create it up front so the
    # error (e.g. permissions) surfaces before the long training run, not after.
    os.makedirs(models_directory, exist_ok=True)

    trained_model = train_model(dataset_path)
    trained_model.save(os.path.join(models_directory, 'football_logo_model.h5'))
|
|