{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c1e5a20",
   "metadata": {},
   "source": [
    "# Clear-to-snowy image translation\n",
    "\n",
    "Runs a pretrained generator (`define_G` from `lib.gan_networks`, `resnet_9blocks` architecture) on every image in `./data/images` and writes the translated images to `./data/gan/snow_images`.\n",
    "\n",
    "Images that already have an output file are skipped, so the last cell can be re-run to resume an interrupted batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ee1bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "from lib.gan_networks import define_G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8b6d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input / output locations\n",
    "CHECKPOINT_PATH = \"./checkpoints/clear2snowy.pth\"\n",
    "INPUT_DIR = \"./data/images\"\n",
    "OUTPUT_DIR = \"./data/gan/snow_images\"\n",
    "\n",
    "# Preprocessing: each image is resized to LOAD_SIZE px wide (aspect ratio kept),\n",
    "# and its height is never allowed below CROP_SIZE px.\n",
    "LOAD_SIZE = 1280\n",
    "CROP_SIZE = 224\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a4d3c57",
   "metadata": {},
   "source": [
    "## Pre- and post-processing helpers\n",
    "\n",
    "The generator consumes RGB tensors normalised to $[-1, 1]$; `tensor2im` maps its output back from $[-1, 1]$ to $[0, 255]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59797ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def __transforms2pil_resize(method):\n",
    "    \"\"\"Map a torchvision ``InterpolationMode`` to the equivalent PIL resampling filter.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if ``method`` is not BILINEAR, BICUBIC, NEAREST or LANCZOS.\n",
    "    \"\"\"\n",
    "    mapper = {\n",
    "        transforms.InterpolationMode.BILINEAR: Image.BILINEAR,\n",
    "        transforms.InterpolationMode.BICUBIC: Image.BICUBIC,\n",
    "        transforms.InterpolationMode.NEAREST: Image.NEAREST,\n",
    "        transforms.InterpolationMode.LANCZOS: Image.LANCZOS,\n",
    "    }\n",
    "    return mapper[method]\n",
    "\n",
    "\n",
    "def __scale_width(\n",
    "    img, target_size, crop_size, method=transforms.InterpolationMode.BICUBIC\n",
    "):\n",
    "    \"\"\"Resize a PIL image to ``target_size`` px wide, preserving its aspect ratio.\n",
    "\n",
    "    The new height is clamped so it is never smaller than ``crop_size``. The image\n",
    "    is returned as-is when it already has the target width and a tall-enough height.\n",
    "    \"\"\"\n",
    "    method = __transforms2pil_resize(method)\n",
    "    ow, oh = img.size\n",
    "    if ow == target_size and oh >= crop_size:\n",
    "        return img\n",
    "    w = target_size\n",
    "    h = int(max(target_size * oh / ow, crop_size))\n",
    "    return img.resize((w, h), method)\n",
    "\n",
    "\n",
    "def get_transform(load_size, crop_size, method=transforms.InterpolationMode.BICUBIC):\n",
    "    \"\"\"Build the generator input pipeline: scale width, convert to tensor, map to [-1, 1].\n",
    "\n",
    "    ``Normalize`` is given three channel statistics, so inputs must be RGB images.\n",
    "    \"\"\"\n",
    "    transform_list = [\n",
    "        transforms.Lambda(lambda img: __scale_width(img, load_size, crop_size, method)),\n",
    "        transforms.ToTensor(),\n",
    "        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1]\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "    ]\n",
    "    return transforms.Compose(transform_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e07f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tensor2im(input_image, imtype=np.uint8):\n",
    "    \"\"\"Convert a generator output tensor into a numpy image array.\n",
    "\n",
    "    Only the first image of the batch is converted. Pixel values are expected in\n",
    "    [-1, 1]; they are rescaled to [0, 255] and clipped to that range so that\n",
    "    out-of-range values cannot wrap around when cast to an unsigned integer type.\n",
    "\n",
    "    Parameters:\n",
    "        input_image (tensor) -- the input image tensor array, indexed as (N, C, H, W)\n",
    "        imtype (type)        -- the desired type of the converted numpy array\n",
    "\n",
    "    Returns:\n",
    "        An (H, W, C) array of ``imtype`` (single-channel input is tiled to 3 channels).\n",
    "        A numpy array input is only cast to ``imtype``; any other type is returned unchanged.\n",
    "    \"\"\"\n",
    "    if not isinstance(input_image, np.ndarray):\n",
    "        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n",
    "            image_tensor = input_image.data\n",
    "        else:\n",
    "            return input_image\n",
    "        # first image of the batch as a float (C, H, W) numpy array\n",
    "        image_numpy = image_tensor[0].cpu().float().numpy()\n",
    "        if image_numpy.shape[0] == 1:  # grayscale to RGB\n",
    "            image_numpy = np.tile(image_numpy, (3, 1, 1))\n",
    "        # post-processing: transpose CHW -> HWC and rescale [-1, 1] -> [0, 255]\n",
    "        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n",
    "        image_numpy = np.clip(image_numpy, 0.0, 255.0)\n",
    "    else:  # if it is a numpy array, do nothing\n",
    "        image_numpy = input_image\n",
    "    return image_numpy.astype(imtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a91c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_model_and_transform(\n",
    "    pretrained: Optional[str] = None,\n",
    "    device=None,\n",
    "    load_size: int = 1280,\n",
    "    crop_size: int = 224,\n",
    "):\n",
    "    \"\"\"Build the ResNet-9-block generator and the matching input transform.\n",
    "\n",
    "    Parameters:\n",
    "        pretrained -- optional path to a generator ``state_dict`` checkpoint\n",
    "        device     -- torch device for the model; defaults to CUDA if available, else CPU\n",
    "        load_size  -- width every input image is resized to\n",
    "        crop_size  -- minimum image height after resizing\n",
    "\n",
    "    Returns:\n",
    "        (netG_A, image_transforms): the generator on ``device`` in eval mode and the\n",
    "        preprocessing pipeline from ``get_transform``.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the checkpoint keys do not match the generator architecture.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # Generator hyper-parameters; they must match the architecture the checkpoint was saved from.\n",
    "    input_nc = 3\n",
    "    output_nc = 3\n",
    "    ngf = 64\n",
    "    netG = \"resnet_9blocks\"\n",
    "    norm = \"instance\"\n",
    "    no_dropout = True\n",
    "    init_type = \"normal\"\n",
    "    init_gain = 0.02\n",
    "    gpu_ids = []  # device placement is done below with .to(device)\n",
    "\n",
    "    netG_A = define_G(\n",
    "        input_nc,\n",
    "        output_nc,\n",
    "        ngf,\n",
    "        netG,\n",
    "        norm,\n",
    "        not no_dropout,\n",
    "        init_type,\n",
    "        init_gain,\n",
    "        gpu_ids,\n",
    "    )\n",
    "    if pretrained:\n",
    "        # map_location lets a checkpoint saved from a GPU load on a CPU-only machine;\n",
    "        # weights_only=True refuses arbitrary pickled objects (a plain state_dict still loads).\n",
    "        state_dict = torch.load(pretrained, map_location=device, weights_only=True)\n",
    "        netG_A.load_state_dict(state_dict)\n",
    "\n",
    "    # Inference only: switch to eval mode whether or not weights were loaded.\n",
    "    netG_A = netG_A.to(device).eval()\n",
    "\n",
    "    image_transforms = get_transform(load_size=load_size, crop_size=crop_size)\n",
    "    return netG_A, image_transforms\n",
    "\n",
    "\n",
    "def run_inference(img_path, model, transform):\n",
    "    \"\"\"Translate one image file with ``model`` and return the result as a PIL image.\n",
    "\n",
    "    Parameters:\n",
    "        img_path  -- path to the input image; any PIL-readable mode is converted to RGB\n",
    "        model     -- generator, e.g. from ``create_model_and_transform``\n",
    "        transform -- preprocessing pipeline, e.g. from ``create_model_and_transform``\n",
    "    \"\"\"\n",
    "    # convert(\"RGB\") handles grayscale / RGBA / palette files, which would otherwise fail the\n",
    "    # 3-channel Normalize; the with-block closes the file handle once the pixels are copied.\n",
    "    with Image.open(img_path) as raw_image:\n",
    "        image = raw_image.convert(\"RGB\")\n",
    "\n",
    "    # Send the input to whichever device the model's weights live on.\n",
    "    device = next(model.parameters()).device\n",
    "    inputs = transform(image).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out = model(inputs)\n",
    "    return Image.fromarray(tensor2im(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fc20d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator, image_transform = create_model_and_transform(\n",
    "    CHECKPOINT_PATH, device=DEVICE, load_size=LOAD_SIZE, crop_size=CROP_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c6f2b88",
   "metadata": {},
   "source": [
    "## Translate the image folder\n",
    "\n",
    "Each output is saved as `<input stem>.jpg` in `OUTPUT_DIR`. Files whose extension PIL does not recognise are ignored, and inputs that already have an output are skipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44ebf97",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "# Keep only files PIL can decode (skips e.g. .DS_Store or .ipynb_checkpoints).\n",
    "pil_extensions = set(Image.registered_extensions())\n",
    "image_names = sorted(\n",
    "    name\n",
    "    for name in os.listdir(INPUT_DIR)\n",
    "    if os.path.splitext(name)[1].lower() in pil_extensions\n",
    ")\n",
    "\n",
    "for name in tqdm(image_names, desc=\"translating\"):\n",
    "    input_path = os.path.join(INPUT_DIR, name)\n",
    "    # splitext keeps inner dots (\"a.b.png\" -> \"a.b.jpg\"), unlike split(\".\")[0]\n",
    "    output_path = os.path.join(OUTPUT_DIR, os.path.splitext(name)[0] + \".jpg\")\n",
    "    if os.path.exists(output_path):\n",
    "        continue  # already translated by a previous run\n",
    "    translated = run_inference(img_path=input_path, model=generator, transform=image_transform)\n",
    "    translated.save(output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}