{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Jlq7oGlpguCe" }, "source": [ "# AI Art Style Detector Project - Topics Used\n", "\n", "## Machine Learning and Deep Learning Topics:\n", "\n", "### 1. Image Preprocessing:\n", "- **Image Loading**: Loading images from file input using Keras's `image.load_img()`.\n", "- **Resizing**: Resizing the input image to a fixed size (`224x224`) before feeding it into the model.\n", "- **Normalization**: Scaling pixel values to the range `[0, 1]` for efficient model input.\n", "\n", "### 2. Model Loading and Inference:\n", "- **Loading Pre-trained Models**: Using `tensorflow.keras.models.load_model()` to load a trained deep learning model (like a CNN for image classification).\n", "- **Prediction**: Using the model to make predictions by feeding the preprocessed image data into the model and getting class probabilities.\n", "\n", "### 3. Transfer Learning:\n", "- **Pre-trained Models**: The model is likely built on a pre-trained CNN model (such as VGG16, ResNet, etc.) through **transfer learning**, where the lower layers are frozen, and only the higher layers are fine-tuned for the specific art style classification task.\n", " \n", "### 4. Classification:\n", "- **Categorical Output**: The model predicts which art style category (e.g., Impressionism, Surrealism) an artwork belongs to.\n", "- **Softmax Activation**: The output layer of the model typically uses **softmax** activation to produce probabilities for each art style class.\n", "\n", "---\n", "\n", "## Web Application Development Topics (Using Streamlit):\n", "\n", "### 1. Streamlit Layout:\n", "- **Column Layouts**: Using `st.columns()` to create responsive, side-by-side layouts for displaying images and results.\n", "- **Expander**: Using `st.expander()` to allow users to reveal additional information about the model and its functionality.\n", "\n", "### 2. File Uploading:\n", "- **Image Upload**: Using `st.file_uploader()` to allow users to upload images directly from their local device into the web app.\n", "- **Image Display**: Using `st.image()` to display the uploaded image on the web app.\n", "\n", "### 3. Interactive Widgets:\n", "- **Dropdown/Selectbox**: Using `st.selectbox()` to allow users to interactively select art styles and get more information about them.\n", "- **Buttons and Inputs**: You could add buttons and input fields to extend functionality, like adding manual entry for predicting specific images.\n", "\n", "### 4. Visualization:\n", "- **Plotly Charts**: Using **Plotly** to visualize art style distributions (like bar charts), making the app more interactive and engaging.\n", "- **Matplotlib/Seaborn** (Optional): Visualizing the results or image transformations (though Plotly is integrated here).\n", "\n", "### 5. Styling the UI:\n", "- **Custom CSS**: Using custom CSS injected into the Streamlit app with `st.markdown()` to enhance the look and feel of the app (e.g., custom colors, fonts, and element styling).\n", " \n", "### 6. Streamlit Features:\n", "- **Markdown Rendering**: Using `st.markdown()` to render HTML and CSS for custom styling or display content.\n", "- **File Handling**: Streamlit handles file uploading, downloading, and processing in a straightforward way using `st.file_uploader()`.\n", "\n", "---\n", "\n", "## Deep Learning Topics in Model Development (for Art Style Classification):\n", "\n", "### 1. 
 "- **Convolutional Layers**: CNNs are well-suited for image classification because they automatically learn spatial hierarchies of features.\n",
 "- **Pooling Layers**: Max-pooling layers reduce the spatial dimensions of the image while retaining the most important features.\n",
 "- **Fully Connected Layers**: Dense layers perform the final classification.\n",
 "\n",
 "### 2. Transfer Learning:\n",
 "- Using pre-trained networks like **VGG16**, **ResNet**, or **Inception** as feature extractors, and fine-tuning the final layers for specific art styles.\n",
 "\n",
 "### 3. Activation Functions:\n",
 "- **ReLU (Rectified Linear Unit)**: For non-linear transformations in hidden layers.\n",
 "- **Softmax**: For multi-class classification, used in the final output layer to turn raw scores into a probability for each class.\n",
 "\n",
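 "As a quick numeric anchor for the softmax bullet above (a toy example with three hypothetical classes):\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "logits = np.array([2.0, 1.0, 0.1])             # raw model outputs (scores)\n",
 "probs = np.exp(logits) / np.exp(logits).sum()  # softmax\n",
 "print(probs.round(3))                          # [0.659 0.242 0.099] -- sums to 1\n",
 "```\n",
 "\n",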
 "### 4. Model Training (Optional):\n",
 "- **Data Augmentation**: Techniques to artificially expand the dataset (e.g., rotations, flips).\n",
 "- **Loss Function**: Typically **categorical cross-entropy** for multi-class classification tasks.\n",
 "- **Optimizer**: Such as **Adam**, to adjust weights during training.\n",
 "\n",
 "### 5. Evaluation Metrics:\n",
 "- **Accuracy**: How often the model predicts the correct class.\n",
 "- **Confusion Matrix**: (Optional) To evaluate the model's performance across different art styles (a sketch appears in the evaluation section below).\n",
 "\n",
 "---\n",
 "\n",
 "## Other Relevant Topics:\n",
 "\n",
 "### 1. Data Handling and Preprocessing:\n",
 "- **NumPy**: Used for image array manipulation and preparing input data.\n",
 "- **Pandas**: For organizing and visualizing art style statistics (e.g., counts, distributions).\n",
 "\n",
 "### 2. Model Evaluation and Fine-tuning (Optional):\n",
 "- **Hyperparameter Tuning**: Tweaking the learning rate, batch size, etc., to improve model performance.\n",
 "- **Cross-validation**: Checking that the model performs well on unseen data.\n",
 "\n",
 "---\n",
 "\n",
 "## In Summary:\n",
 "The main topics used in this project are:\n",
 "\n",
 "- **Machine Learning**: CNNs, transfer learning, model prediction, image preprocessing, and classification.\n",
 "- **Deep Learning**: Using pre-trained models, fine-tuning, and evaluating the model's performance.\n",
 "- **Streamlit Web Development**: Interactive web app development, custom UI with CSS, file handling, and visualizations.\n",
 "- **Data Science**: Data manipulation, model deployment, and visualization using Pandas and Plotly.\n"
] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "id": "atG_3xNvU720" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The syntax of the command is incorrect.\n" ] } ], "source": [ "!mkdir -p ~/.kaggle\n" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "s_XA6A_YU7zn", "outputId": "9e66b83c-065f-4b5b-c274-44a57986ebac" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "'cp' is not recognized as an internal or external command,\n", "operable program or batch file.\n" ] } ], "source": [ "!cp kaggle.json ~/.kaggle/\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FW1jquyKU7wu", "outputId": "381ed4f7-26cd-4372-8510-5930a1aa320f" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "'chmod' is not recognized as an internal or external command,\n", "operable program or batch file.\n" ] } ], "source": [ "!chmod 600 ~/.kaggle/kaggle.json\n" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8hv1Lom6Uec_", "outputId": "3a93e47f-896f-4478-84a2-e2d4e29a5e46" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "'chmod' is not recognized as an internal or external command,\n", "operable program or batch file.\n" ] } ], "source": [ "!chmod 600 kaggle.json\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Tm2JYiyWVCGC", "outputId": "0d222bc9-c378-4822-8999-e57859643897" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.17)\n",
 "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.17.0)\n",
 "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.12.14)\n",
 "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n",
 "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.32.3)\n",
 "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.67.1)\n",
 "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n",
 "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.2.3)\n",
 "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.2.0)\n",
 "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n",
 "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n",
 "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.4.0)\n",
 "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.10)\n"
] } ], "source": [ "!pip install kaggle\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OWi7m5uXSobo", "outputId": "1542c1cb-adab-4b3b-db59-612707a19593" }, "outputs": [], "source": [ "!kaggle datasets download steubk/wikiart" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0ado9rLlWD67", "outputId": "32eea455-9996-4e3f-b87c-ee0f40ee5485" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ERROR:root:Internal Python error in the inspect module.\n", "Below is the traceback from this internal error.\n", "\n", "\n", "KeyboardInterrupt\n", "\n" ] } ], "source": [ "import zipfile\n", "\n", "with zipfile.ZipFile(\"/content/wikiart.zip\", \"r\") as zip_ref:\n", "    zip_ref.extractall(\"wikiart_data\")\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "QKleAJNUWFED" }, "outputs": [], "source": [ "!ls wikiart_data\n" ] },
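{ "cell_type": "markdown", "metadata": {}, "source": [ "The `mkdir`/`cp`/`chmod` cells above are Unix commands; the recorded errors show they fail in a Windows shell. A cross-platform sketch of the same setup (assumes `kaggle.json` sits next to the notebook):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Cross-platform equivalent of the shell cells above\n",
 "import shutil\n",
 "from pathlib import Path\n",
 "\n",
 "kaggle_dir = Path.home() / \".kaggle\"\n",
 "kaggle_dir.mkdir(parents=True, exist_ok=True)             # mkdir -p ~/.kaggle\n",
 "shutil.copy(\"kaggle.json\", kaggle_dir / \"kaggle.json\")    # cp kaggle.json ~/.kaggle/\n",
 "(kaggle_dir / \"kaggle.json\").chmod(0o600)                 # chmod 600 (mostly a no-op on Windows)\n"
] },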
{ "cell_type": "markdown", "metadata": { "id": "WPI-bMGJWG-w" }, "source": [ "# **1. Data Preprocessing**" ] },
{ "cell_type": "markdown", "metadata": { "id": "qAUPwMBYWN8p" }, "source": [ "**(A) Import Libraries**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KwB9KW7vWLFy" }, "outputs": [], "source": [
 "import os  # file and directory handling\n",
 "import numpy as np\n",
 "import matplotlib.pyplot as plt  # plotting training curves\n",
 "import tensorflow as tf\n",
 "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
 "from tensorflow.keras.applications import VGG16\n",
 "from tensorflow.keras import layers, models\n",
 "from sklearn.model_selection import train_test_split\n"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "jAfmnJEIb02C" }, "outputs": [], "source": [
 "import tensorflow as tf\n",
 "from tensorflow.keras.applications import MobileNetV2\n",
 "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
 "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
 "from tensorflow.keras.optimizers import AdamW\n",
 "from tensorflow.keras.mixed_precision import set_global_policy  # optional; see note in the compile cell\n"
] },
{ "cell_type": "markdown", "metadata": { "id": "XlwZJ5DSXgGF" }, "source": [ "**(B) Load and Explore the Data**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2wIkoq7AXRpC", "outputId": "317f840d-2a2a-4021-e8c4-24a6485e6b2c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Contemporary_Realism', 'Northern_Renaissance', 'Action_painting', 'wclasses.csv', 'Cubism', 'Color_Field_Painting', 'Realism', 'Rococo', 'Fauvism', 'Romanticism', 'High_Renaissance', 'New_Realism', 'Naive_Art_Primitivism', 'Synthetic_Cubism', 'Art_Nouveau_Modern', 'Baroque', 'Minimalism', 'Impressionism', 'Symbolism', 'Mannerism_Late_Renaissance', 'Abstract_Expressionism', 'Early_Renaissance', 'Analytical_Cubism', 'Post_Impressionism', 'Ukiyo_e', 'classes.csv', 'Pointillism', 'Pop_Art', 'Expressionism']\n" ] } ], "source": [
 "# Set the dataset directory path\n",
 "dataset_dir = '/content/wikiart_data'\n",
 "# Check the classes available in the dataset\n",
 "# Note: the listing also contains classes.csv / wclasses.csv, which are files, not style folders\n",
 "classes = os.listdir(dataset_dir)\n",
 "print(classes)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 356 }, "id": "upLuzm4hcWtO", "outputId": "b71042b1-8292-4ba5-eff1-30183c52574d" }, "outputs": [ { "ename": "OSError", "evalue": "[Errno 28] No space left on device", "output_type": "error", "traceback": [ "OSError: [Errno 28] No space left on device (raised from zipfile extractall -> shutil.copyfileobj while unpacking wikiart.zip; ANSI traceback trimmed for readability)" ] } ], "source": [
 "import os\n",
 "import shutil\n",
 "import numpy as np\n",
 "from sklearn.model_selection import train_test_split\n",
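 "\n",
 "# This cell previously died with the OSError above while extracting. A cheap\n",
 "# guard (a sketch; the unpacked WikiArt set needs roughly tens of GB):\n",
 "free_gb = shutil.disk_usage(\"/content\").free / 1e9\n",
 "print(f\"Free disk space: {free_gb:.1f} GB\")\n",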
"import zipfile\n", "\n", "# Define paths\n", "dataset_dir = \"/content/wikiart_data\" # All images in this folder\n", "train_dir = \"/content/train\" # Folder for training images\n", "val_dir = \"/content/val\" # Folder for validation images\n", "\n", "# Extract the zip file\n", "with zipfile.ZipFile(\"/content/wikiart.zip\", \"r\") as zip_ref:\n", " zip_ref.extractall(\"/content/wikiart_data\")\n", "\n", "# Create directories if they don't exist\n", "os.makedirs(train_dir, exist_ok=True)\n", "os.makedirs(val_dir, exist_ok=True)\n", "\n", "# Create subdirectories for classes\n", "classes = [d for d in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, d))]\n", "for cls in classes:\n", " os.makedirs(os.path.join(train_dir, cls), exist_ok=True)\n", " os.makedirs(os.path.join(val_dir, cls), exist_ok=True)\n", "\n", "# Split dataset\n", "for cls in classes:\n", " cls_dir = os.path.join(dataset_dir, cls)\n", " images = os.listdir(cls_dir)\n", " # Check if the images list is empty before using train_test_split\n", " if not images:\n", " print(f\"Warning: No images found in {cls_dir}. Skipping this directory.\")\n", " continue # Skip to the next class\n", " # added to handle if there is only one image in the directory\n", " if len(images) == 1:\n", " print(f\"Warning: Only one image found in {cls_dir}. Skipping this directory.\")\n", " continue\n", " train_images, val_images = train_test_split(images, test_size=0.2, random_state=42) # 80% train, 20% val\n", "\n", " # Move files to respective folders\n", " for img in train_images:\n", " try:\n", " shutil.move(os.path.join(cls_dir, img), os.path.join(train_dir, cls, img))\n", " except shutil.Error as e:\n", " print(f\"Error moving file {img} from {cls_dir} to {train_dir}/{cls}: {e}\")\n", " for img in val_images:\n", " try:\n", " shutil.move(os.path.join(cls_dir, img), os.path.join(val_dir, cls, img))\n", " except shutil.Error as e:\n", " print(f\"Error moving file {img} from {cls_dir} to {val_dir}/{cls}: {e}\")\n", "\n", "print(\"Dataset split completed.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KRDb2vLAX1m-" }, "source": [ "**(c) Image Resizing and Normalization**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CPa6EY8bXxMN", "outputId": "6a2ac532-d5ec-4e80-e8e3-902ac557fdcc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 65166 images belonging to 27 classes.\n", "Found 16278 images belonging to 27 classes.\n" ] } ], "source": [ "# Set parameters\n", "image_size = (128, 128) # Smaller image size for memory efficiency\n", "batch_size = 16 # Reduced batch size\n", "num_classes = 10 # Adjust based on your dataset\n", "\n", "# Data augmentation and rescaling\n", "train_datagen = ImageDataGenerator(\n", " rescale=1.0 / 255,\n", " rotation_range=20,\n", " width_shift_range=0.2,\n", " height_shift_range=0.2,\n", " shear_range=0.2,\n", " zoom_range=0.2,\n", " horizontal_flip=True\n", ")\n", "\n", "val_datagen = ImageDataGenerator(rescale=1.0 / 255)\n", "\n", "# Data generators\n", "train_gen = train_datagen.flow_from_directory(\n", " train_dir,\n", " target_size=image_size,\n", " batch_size=batch_size,\n", " class_mode='categorical'\n", ")\n", "\n", "val_gen = val_datagen.flow_from_directory(\n", " val_dir,\n", " target_size=image_size,\n", " batch_size=batch_size,\n", " class_mode='categorical'\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "i9U4kDsnZ4rW" }, "source": [ "# **2. 
Model Architecture**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Lr-JPXQcBJo" }, "outputs": [], "source": [ "# Load pre-trained MobileNetV2 with frozen layers\n", "base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(128, 128, 3))\n", "base_model.trainable = False # Freeze base layers to reduce computation\n", "\n", "# Build the model\n", "model = tf.keras.Sequential([\n", " base_model,\n", " tf.keras.layers.GlobalAveragePooling2D(),\n", " tf.keras.layers.Dense(256, activation='relu'),\n", " tf.keras.layers.Dropout(0.5),\n", " tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32') # Ensure outputs are float32\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "3iO9Hv-na53V" }, "source": [ "**(b) compile the model**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M-IwFmTBZ9P5" }, "outputs": [], "source": [ "# Compile the model\n", "optimizer = AdamW(learning_rate=0.001)\n", "model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "7dw_MJpYbHze" }, "source": [ "**(c) Train the model**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ro14UX1HbFx7" }, "outputs": [], "source": [ "# Callbacks\n", "checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss')\n", "early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)\n", "\n", "# Train the model\n", "history = model.fit(\n", " train_gen,\n", " validation_data=val_gen,\n", " epochs=20,\n", " callbacks=[checkpoint, early_stopping]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "B23w_jvmbpVd" }, "source": [ "# **4. Evaluate the model**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RorPfsd_bmB2" }, "outputs": [], "source": [ "# Plot training and validation accuracy\n", "plt.plot(history.history['accuracy'], label='Training Accuracy')\n", "plt.plot(history.history['val_accuracy'], label='Validation Accuracy')\n", "plt.title('Model Accuracy')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()\n", "\n", "# Plot training and validation loss\n", "plt.plot(history.history['loss'], label='Training Loss')\n", "plt.plot(history.history['val_loss'], label='Validation Loss')\n", "plt.title('Model Loss')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.legend()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "fB4FmTpIcEbc" }, "source": [ "# 5. 
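{ "cell_type": "markdown", "metadata": {}, "source": [ "The topics list above flags a confusion matrix as an optional evaluation. A minimal sketch (assumes the trained `model`, plus `val_datagen`, `val_dir`, `image_size`, and `batch_size` from earlier cells; a fresh, unshuffled generator keeps predictions aligned with labels):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Optional: confusion matrix across art styles (sketch)\n",
 "from sklearn.metrics import confusion_matrix\n",
 "\n",
 "# shuffle=False so the i-th prediction corresponds to the i-th label\n",
 "eval_gen = val_datagen.flow_from_directory(\n",
 "    val_dir,\n",
 "    target_size=image_size,\n",
 "    batch_size=batch_size,\n",
 "    class_mode='categorical',\n",
 "    shuffle=False\n",
 ")\n",
 "probs = model.predict(eval_gen)\n",
 "cm = confusion_matrix(eval_gen.classes, np.argmax(probs, axis=1))\n",
 "print(cm)  # rows: true style, columns: predicted style (order = eval_gen.class_indices)\n"
] },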
{ "cell_type": "markdown", "metadata": { "id": "fB4FmTpIcEbc" }, "source": [ "# **5. Model Testing**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "uJyJa4rfcD0Z" }, "outputs": [], "source": [
 "from tensorflow.keras.preprocessing import image\n",
 "\n",
 "# Load a test image (path left as a placeholder)\n",
 "img_path = '/path_to_test_image/test_image.jpg'\n",
 "img = image.load_img(img_path, target_size=image_size)  # same (128, 128) used in training\n",
 "img_array = image.img_to_array(img) / 255.0  # Normalize\n",
 "img_array = np.expand_dims(img_array, axis=0)\n",
 "\n",
 "# Predict the style. Class order must match the training generator's mapping,\n",
 "# not the unsorted os.listdir() result from earlier.\n",
 "class_names = sorted(train_gen.class_indices, key=train_gen.class_indices.get)\n",
 "prediction = model.predict(img_array)\n",
 "predicted_class = class_names[np.argmax(prediction)]\n",
 "print(f\"Predicted Art Style: {predicted_class}\")\n"
] }
], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }