File size: 15,404 Bytes
1a41d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "10ee1bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from lib.gan_networks import define_G\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "59797ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def __transforms2pil_resize(method):\n",
    "    \"\"\"Look up the PIL resampling filter for a torchvision interpolation mode.\n",
    "\n",
    "    Only the four modes listed in the table are known; any other value\n",
    "    raises KeyError.\n",
    "    \"\"\"\n",
    "    modes = transforms.InterpolationMode\n",
    "    pil_filter_for = dict(\n",
    "        [\n",
    "            (modes.NEAREST, Image.NEAREST),\n",
    "            (modes.BILINEAR, Image.BILINEAR),\n",
    "            (modes.BICUBIC, Image.BICUBIC),\n",
    "            (modes.LANCZOS, Image.LANCZOS),\n",
    "        ]\n",
    "    )\n",
    "    return pil_filter_for[method]\n",
    "\n",
    "\n",
    "def __scale_width(\n",
    "    img, target_size, crop_size, method=transforms.InterpolationMode.BICUBIC\n",
    "):\n",
    "    \"\"\"Resize ``img`` to width ``target_size`` with a height of at least ``crop_size``.\n",
    "\n",
    "    The new height follows the original aspect ratio unless that would fall\n",
    "    below ``crop_size``. An image that already has the target width and a\n",
    "    height of at least ``crop_size`` is returned as-is.\n",
    "    \"\"\"\n",
    "    # Resolved before the early return, so an unknown mode always raises KeyError\n",
    "    pil_filter = __transforms2pil_resize(method)\n",
    "    width, height = img.size\n",
    "    needs_resize = width != target_size or height < crop_size\n",
    "    if not needs_resize:\n",
    "        return img\n",
    "    scaled_height = target_size * height / width\n",
    "    new_size = (target_size, int(max(scaled_height, crop_size)))\n",
    "    return img.resize(new_size, pil_filter)\n",
    "\n",
    "\n",
    "def get_transform(load_size, crop_size, method=transforms.InterpolationMode.BICUBIC):\n",
    "    \"\"\"Compose the preprocessing steps applied to a PIL image before the model.\n",
    "\n",
    "    Steps, in order: rescale through ``__scale_width``, convert to a tensor,\n",
    "    then normalize three channels with mean 0.5 and std 0.5.\n",
    "    \"\"\"\n",
    "\n",
    "    def rescale(img):\n",
    "        return __scale_width(img, load_size, crop_size, method)\n",
    "\n",
    "    half = (0.5, 0.5, 0.5)\n",
    "    return transforms.Compose(\n",
    "        [\n",
    "            transforms.Lambda(rescale),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(half, half),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "def tensor2im(input_image, imtype=np.uint8):\n",
    "    \"\"\"Convert an image tensor into a numpy image array.\n",
    "\n",
    "    Parameters:\n",
    "        input_image (tensor) --  the input image tensor array; only the first\n",
    "                                 element along dimension 0 is converted\n",
    "        imtype (type)        --  the desired type of the converted numpy array\n",
    "\n",
    "    Returns:\n",
    "        A numpy array cast to ``imtype``. Tensor input is moved to channel-last\n",
    "        order and mapped linearly so that -1 becomes 0 and 1 becomes 255; a\n",
    "        single-channel tensor is repeated to three channels first. Numpy input\n",
    "        is only cast. Any other input type is returned unchanged.\n",
    "    \"\"\"\n",
    "    if isinstance(input_image, np.ndarray):  # already a numpy array: only cast\n",
    "        return input_image.astype(imtype)\n",
    "    if not isinstance(input_image, torch.Tensor):\n",
    "        return input_image\n",
    "\n",
    "    # first image of the batch, as a float numpy array\n",
    "    image_numpy = input_image.detach()[0].cpu().float().numpy()\n",
    "    if image_numpy.shape[0] == 1:  # grayscale to RGB\n",
    "        image_numpy = np.tile(image_numpy, (3, 1, 1))\n",
    "    # post-processing: transpose to channel-last and rescale\n",
    "    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n",
    "    if np.issubdtype(imtype, np.integer):\n",
    "        # Saturate instead of letting out-of-range values wrap in the integer cast\n",
    "        image_numpy = np.clip(image_numpy, 0, 255)\n",
    "    return image_numpy.astype(imtype)\n",
    "\n",
    "\n",
    "def create_model_and_transform(pretrained: str = None):\n",
    "    \"\"\"Build the generator network and the matching input transform.\n",
    "\n",
    "    Parameters:\n",
    "        pretrained (str) -- optional path to a checkpoint holding the generator\n",
    "                            state dict; when falsy the freshly initialised\n",
    "                            weights are kept\n",
    "\n",
    "    Returns:\n",
    "        (netG_A, image_transforms): the generator in eval mode, placed on the\n",
    "        GPU when CUDA is available and on the CPU otherwise, and the\n",
    "        preprocessing transform produced by ``get_transform``.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: from ``load_state_dict`` when the checkpoint keys do not\n",
    "        match the network built here.\n",
    "    \"\"\"\n",
    "    # Creating model (settings are passed positionally to define_G)\n",
    "    input_nc = 3\n",
    "    output_nc = 3\n",
    "    ngf = 64\n",
    "    netG = \"resnet_9blocks\"\n",
    "    norm = \"instance\"\n",
    "    no_dropout = True\n",
    "    init_type = \"normal\"\n",
    "    init_gain = 0.02\n",
    "    gpu_ids = []\n",
    "\n",
    "    netG_A = define_G(\n",
    "        input_nc,\n",
    "        output_nc,\n",
    "        ngf,\n",
    "        netG,\n",
    "        norm,\n",
    "        not no_dropout,\n",
    "        init_type,\n",
    "        init_gain,\n",
    "        gpu_ids,\n",
    "    )\n",
    "    if pretrained:\n",
    "        # Map to the CPU so the file opens whatever device it was saved from.\n",
    "        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
    "        chkpntA = torch.load(pretrained, map_location=\"cpu\")\n",
    "        netG_A.load_state_dict(chkpntA)\n",
    "    netG_A.eval()\n",
    "\n",
    "    # Fall back to the CPU instead of crashing on machines without CUDA\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    netG_A = netG_A.to(device)\n",
    "\n",
    "    # Creating transform\n",
    "    load_size = 1280\n",
    "    crop_size = 224\n",
    "    image_transforms = get_transform(load_size=load_size, crop_size=crop_size)\n",
    "    return netG_A, image_transforms\n",
    "\n",
    "\n",
    "def run_inference(img_path, model, transform):\n",
    "    \"\"\"Run the generator on one image file and return the result as a PIL image.\n",
    "\n",
    "    Parameters:\n",
    "        img_path  -- path of the image to translate\n",
    "        model     -- generator network, called on a batch holding one image\n",
    "        transform -- callable turning a PIL image into the model input tensor\n",
    "\n",
    "    Returns:\n",
    "        A PIL image built from the ``tensor2im`` conversion of the model output.\n",
    "    \"\"\"\n",
    "    # Follow the device the model lives on rather than assuming CUDA\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else torch.device(\"cpu\")\n",
    "\n",
    "    with Image.open(img_path) as image:\n",
    "        # Use the transform that was passed in (not a notebook global) and force\n",
    "        # three channels so grayscale or RGBA files do not break the pipeline\n",
    "        inputs = transform(image.convert(\"RGB\")).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out = model(inputs)\n",
    "    out = tensor2im(out)\n",
    "    return Image.fromarray(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6fc20d26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initialize network with normal\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Error(s) in loading state_dict for UnetGenerator:\n\tMissing key(s) in state_dict: \"model.model.0.weight\", \"model.model.0.bias\", \"model.model.1.model.1.weight\", \"model.model.1.model.1.bias\", \"model.model.1.model.3.model.1.weight\", \"model.model.1.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.5.weight\", \"model.model.1.model.3.model.5.bias\", \"model.model.1.model.5.weight\", \"model.model.1.model.5.bias\", \"model.model.3.weight\", \"model.model.3.bias\". 
\n\tUnexpected key(s) in state_dict: \"model.1.weight\", \"model.1.bias\", \"model.4.weight\", \"model.4.bias\", \"model.7.weight\", \"model.7.bias\", \"model.10.conv_block.1.weight\", \"model.10.conv_block.1.bias\", \"model.10.conv_block.5.weight\", \"model.10.conv_block.5.bias\", \"model.11.conv_block.1.weight\", \"model.11.conv_block.1.bias\", \"model.11.conv_block.5.weight\", \"model.11.conv_block.5.bias\", \"model.12.conv_block.1.weight\", \"model.12.conv_block.1.bias\", \"model.12.conv_block.5.weight\", \"model.12.conv_block.5.bias\", \"model.13.conv_block.1.weight\", \"model.13.conv_block.1.bias\", \"model.13.conv_block.5.weight\", \"model.13.conv_block.5.bias\", \"model.14.conv_block.1.weight\", \"model.14.conv_block.1.bias\", \"model.14.conv_block.5.weight\", \"model.14.conv_block.5.bias\", \"model.15.conv_block.1.weight\", \"model.15.conv_block.1.bias\", \"model.15.conv_block.5.weight\", \"model.15.conv_block.5.bias\", \"model.16.conv_block.1.weight\", \"model.16.conv_block.1.bias\", \"model.16.conv_block.5.weight\", \"model.16.conv_block.5.bias\", \"model.17.conv_block.1.weight\", \"model.17.conv_block.1.bias\", \"model.17.conv_block.5.weight\", \"model.17.conv_block.5.bias\", \"model.18.conv_block.1.weight\", \"model.18.conv_block.1.bias\", \"model.18.conv_block.5.weight\", \"model.18.conv_block.5.bias\", \"model.19.weight\", \"model.19.bias\", \"model.22.weight\", \"model.22.bias\", \"model.26.weight\", \"model.26.bias\". ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m gan, image_transforms \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_model_and_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./checkpoints/clear2snowy.pth\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[12], line 82\u001b[0m, in \u001b[0;36mcreate_model_and_transform\u001b[0;34m(pretrained)\u001b[0m\n\u001b[1;32m     80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pretrained:\n\u001b[1;32m     81\u001b[0m     chkpntA \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(pretrained)\n\u001b[0;32m---> 82\u001b[0m     \u001b[43mnetG_A\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mchkpntA\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     83\u001b[0m netG_A\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m     85\u001b[0m netG_A \u001b[38;5;241m=\u001b[39m netG_A\u001b[38;5;241m.\u001b[39mcuda()\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:2189\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m   2184\u001b[0m         error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m   2185\u001b[0m             \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m   2186\u001b[0m                 \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m   2188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2189\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m   2190\u001b[0m                        \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m  
 2191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for UnetGenerator:\n\tMissing key(s) in state_dict: \"model.model.0.weight\", \"model.model.0.bias\", \"model.model.1.model.1.weight\", \"model.model.1.model.1.bias\", \"model.model.1.model.3.model.1.weight\", \"model.model.1.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias\", \"model.model.1.model.3.model.3.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.3.model.5.weight\", \"model.model.1.model.3.model.3.model.5.bias\", \"model.model.1.model.3.model.5.weight\", \"model.model.1.model.3.model.5.bias\", \"model.model.1.model.5.weight\", \"model.model.1.model.5.bias\", \"model.model.3.weight\", \"model.model.3.bias\". 
\n\tUnexpected key(s) in state_dict: \"model.1.weight\", \"model.1.bias\", \"model.4.weight\", \"model.4.bias\", \"model.7.weight\", \"model.7.bias\", \"model.10.conv_block.1.weight\", \"model.10.conv_block.1.bias\", \"model.10.conv_block.5.weight\", \"model.10.conv_block.5.bias\", \"model.11.conv_block.1.weight\", \"model.11.conv_block.1.bias\", \"model.11.conv_block.5.weight\", \"model.11.conv_block.5.bias\", \"model.12.conv_block.1.weight\", \"model.12.conv_block.1.bias\", \"model.12.conv_block.5.weight\", \"model.12.conv_block.5.bias\", \"model.13.conv_block.1.weight\", \"model.13.conv_block.1.bias\", \"model.13.conv_block.5.weight\", \"model.13.conv_block.5.bias\", \"model.14.conv_block.1.weight\", \"model.14.conv_block.1.bias\", \"model.14.conv_block.5.weight\", \"model.14.conv_block.5.bias\", \"model.15.conv_block.1.weight\", \"model.15.conv_block.1.bias\", \"model.15.conv_block.5.weight\", \"model.15.conv_block.5.bias\", \"model.16.conv_block.1.weight\", \"model.16.conv_block.1.bias\", \"model.16.conv_block.5.weight\", \"model.16.conv_block.5.bias\", \"model.17.conv_block.1.weight\", \"model.17.conv_block.1.bias\", \"model.17.conv_block.5.weight\", \"model.17.conv_block.5.bias\", \"model.18.conv_block.1.weight\", \"model.18.conv_block.1.bias\", \"model.18.conv_block.5.weight\", \"model.18.conv_block.5.bias\", \"model.19.weight\", \"model.19.bias\", \"model.22.weight\", \"model.22.bias\", \"model.26.weight\", \"model.26.bias\". "
     ]
    }
   ],
   "source": [
    "# Build the generator from the clear-to-snowy checkpoint, plus its input transform\n",
    "gan, image_transforms = create_model_and_transform(\n",
    "    pretrained=\"./checkpoints/clear2snowy.pth\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44ebf97",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:39<00:00,  2.51it/s]\n"
     ]
    }
   ],
   "source": [
    "# Translate every file in the input folder, skipping those already converted\n",
    "image_folder = \"./data/images\"\n",
    "save_folder = \"./data/gan/snow_images\"\n",
    "os.makedirs(save_folder, exist_ok=True)  # saving fails if the folder is missing\n",
    "\n",
    "# Sorted for a reproducible order; sub-directories are not images, so skip them\n",
    "image_path = [\n",
    "    name\n",
    "    for name in sorted(os.listdir(image_folder))\n",
    "    if os.path.isfile(os.path.join(image_folder, name))\n",
    "]\n",
    "\n",
    "for img in tqdm(image_path):\n",
    "    src = os.path.join(image_folder, img)\n",
    "    # splitext keeps dots inside the stem: \"a.b.png\" becomes \"a.b.jpg\", not \"a.jpg\"\n",
    "    trg = os.path.join(save_folder, os.path.splitext(img)[0] + \".jpg\")\n",
    "    if not os.path.exists(trg):\n",
    "        out = run_inference(img_path=src, model=gan, transform=image_transforms)\n",
    "        out.save(trg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}