{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing FocusDataSet\n",
    "\n",
    "Load `FocusDataSet` straight from the project sources, walk every sample of `focus150`, and inspect the last one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2450\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'image': array([[[211, 185, 62],\n",
       " [216, 192, 68],\n",
       " [223, 198, 79],\n",
       " ...,\n",
       " [214, 190, 64],\n",
       " [222, 199, 71],\n",
       " [224, 201, 73]],\n",
       " \n",
       " [[218, 192, 69],\n",
       " [223, 197, 74],\n",
       " [229, 205, 83],\n",
       " ...,\n",
       " [216, 193, 65],\n",
       " [225, 202, 74],\n",
       " [226, 203, 75]],\n",
       " \n",
       " [[223, 198, 72],\n",
       " [228, 202, 79],\n",
       " [234, 210, 88],\n",
       " ...,\n",
       " [220, 197, 69],\n",
       " [228, 205, 77],\n",
       " [226, 203, 73]],\n",
       " \n",
       " ...,\n",
       " \n",
       " [[157, 138, 17],\n",
       " [163, 145, 21],\n",
       " [178, 157, 32],\n",
       " ...,\n",
       " [166, 169, 40],\n",
       " [170, 173, 42],\n",
       " [176, 179, 46]],\n",
       " \n",
       " [[145, 126, 5],\n",
       " [155, 137, 13],\n",
       " [177, 156, 31],\n",
       " ...,\n",
       " [156, 158, 31],\n",
       " [166, 169, 40],\n",
       " [175, 178, 47]],\n",
       " \n",
       " [[147, 128, 7],\n",
       " [159, 141, 17],\n",
       " [181, 160, 35],\n",
       " ...,\n",
       " [149, 151, 24],\n",
       " [162, 164, 37],\n",
       " [172, 175, 46]]], dtype=uint8),\n",
       " 'focus_value': tensor(0.5450)}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib.util\n",
    "import sys\n",
    "\n",
    "# Load the datamodule straight from its source file and register it in sys.modules,\n",
    "# so `from focus_datamodule import ...` here and in later cells resolves to it.\n",
    "# (spec + exec_module replaces the deprecated SourceFileLoader.load_module().)\n",
    "spec = importlib.util.spec_from_file_location(\"focus_datamodule\", \"../src/datamodules/focus_datamodule.py\")\n",
    "focus_datamodule = importlib.util.module_from_spec(spec)\n",
    "sys.modules[\"focus_datamodule\"] = focus_datamodule\n",
    "spec.loader.exec_module(focus_datamodule)\n",
    "from focus_datamodule import FocusDataSet\n",
    "\n",
    "ds = FocusDataSet(\"../data/focus150/metadata.csv\", \"../data/focus150/\")\n",
    "\n",
    "# Walk every sample instead of only taking the dataset's length: this counts\n",
    "# them and also loads each entry once.\n",
    "counter = 0\n",
    "for d in ds:\n",
    "    counter += 1\n",
    "\n",
    "print(counter)\n",
    "\n",
    "# Show the last sample loaded by the loop above.\n",
    "d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Smoke-testing a small CNN on one batch\n",
    "\n",
    "Take one batch from `FocusDataModule.test_dataloader()` and run it through a CIFAR-10-tutorial-style network to check that the layer sizes line up; the network should return one prediction per image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from focus_datamodule import FocusDataModule\n",
    "\n",
    "datamodule = FocusDataModule(data_dir=\"../data/focus150\", csv_file=\"../data/focus150/metadata.csv\")\n",
    "datamodule.setup()\n",
    "\n",
    "# Take a single batch from the test split to smoke-test the network below.\n",
    "data = next(iter(datamodule.test_dataloader()))\n",
    "\n",
    "# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN (after the PyTorch CIFAR-10 tutorial) producing one output per image.\n",
    "\n",
    "    Two conv -> ReLU -> max-pool stages, then three fully connected layers with\n",
    "    ReLUs in between, ending in a single output unit.\n",
    "\n",
    "    Args:\n",
    "        img_size: Height/width of the square input images (150 for focus150).\n",
    "        in_channels: Number of input image channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, img_size=150, in_channels=3):\n",
    "        super().__init__()\n",
    "        pool_size = 3\n",
    "\n",
    "        conv1_size = 5\n",
    "        conv1_out = 6\n",
    "        conv2_size = 5\n",
    "        conv2_out = 16\n",
    "\n",
    "        features = nn.Sequential(\n",
    "            nn.Conv2d(in_channels, conv1_out, conv1_size),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(pool_size, pool_size),\n",
    "            nn.Conv2d(conv1_out, conv2_out, conv2_size),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(pool_size, pool_size),\n",
    "            nn.Flatten(),\n",
    "        )\n",
    "\n",
    "        # Infer the flattened size with a dummy forward pass instead of by hand\n",
    "        # (for img_size=150: 150 -> 146 -> 48 -> 44 -> 14, i.e. 16 * 14 * 14).\n",
    "        with torch.no_grad():\n",
    "            n_flat = features(torch.zeros(1, in_channels, img_size, img_size)).shape[1]\n",
    "\n",
    "        # The ReLUs between the Linear layers matter: without them the three\n",
    "        # Linear layers compose into a single affine map.\n",
    "        self.model = nn.Sequential(\n",
    "            *features,\n",
    "            nn.Linear(n_flat, 120),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(84, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return one prediction per image, shape (batch, 1).\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "\n",
    "net = Net()\n",
    "\n",
    "net(data[\"image\"]).shape"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}