byrkbrk committed (verified)
Commit 1a548a3 · 1 Parent(s): 3c006e8

Upload 9 files

app.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ import gradio as gr
+ from synthesizer import SRSynthesizer
+ from gradio_imageslider import ImageSlider
+
+
+
+ if __name__ == "__main__":
+     sr_synthesizer = SRSynthesizer(create_dirs=False)
+     gr_interface = gr.Interface(
+         fn=lambda image: sr_synthesizer.synthesize(image,
+                                                    show=False,
+                                                    save=False,
+                                                    return_input=True),
+         inputs=[gr.Image(type="pil", label="Input")],
+         outputs=ImageSlider(type="pil", label="Output", show_download_button=True),
+         title="Super Resolution Image Synthesizer",
+         examples=[
+             [os.path.join(os.path.dirname(__file__), "low-res-images", "building.png")],
+             [os.path.join(os.path.dirname(__file__), "low-res-images", "plant.png")],
+             [os.path.join(os.path.dirname(__file__), "low-res-images", "penguin.png")],
+             [os.path.join(os.path.dirname(__file__), "low-res-images", "vietnam_park.jpg")],
+         ],
+         description="Synthesize (4x-upscaled) super-resolved images"
+     )
+     gr_interface.launch()
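
For context, the `fn` lambda above just forwards the uploaded PIL image to `SRSynthesizer.synthesize`. A minimal standalone sketch of the same call outside Gradio (assuming some local file `input.png`, a hypothetical name):

    from PIL import Image
    from synthesizer import SRSynthesizer

    synthesizer = SRSynthesizer(create_dirs=False)
    low_res = Image.open("input.png")  # hypothetical input file
    # return_input=True yields an (input, output) pair, which ImageSlider
    # renders as a draggable before/after comparison
    original, upscaled = synthesizer.synthesize(low_res, show=False,
                                                save=False, return_input=True)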
low-res-images/building.png ADDED
low-res-images/penguin.png ADDED
low-res-images/plant.png ADDED
low-res-images/vietnam_park.jpg ADDED
model/eval_seemore_t_x4.yaml ADDED
@@ -0,0 +1,12 @@
+ # Modified from https://huggingface.co/spaces/eduardzamfir/SeeMoreDetails/blob/main/configs/eval_seemore_t_x4.yml
+ scale: 4
+ in_chans: 3
+ num_experts: 3
+ img_range: 1.0
+ num_layers: 6
+ embedding_dim: 36
+ use_shuffle: True
+ lr_space: exp
+ topk: 1
+ recursive: 2
+ global_kernel_size: 11
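
Each key maps one-to-one onto a keyword argument of `SeemoRe.__init__` in `model/seemore.py` below: `SRSynthesizer.read_model_config_file` parses the file with `yaml.safe_load`, and the resulting dict is splatted into the constructor. A minimal sketch of that round trip:

    import yaml
    from model.seemore import SeemoRe

    with open("model/eval_seemore_t_x4.yaml") as file:
        config = yaml.safe_load(file)  # {'scale': 4, 'in_chans': 3, ...}
    model = SeemoRe(**config)          # SeemoRe-T 4x variant, randomly initialized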
model/seemore.py ADDED
@@ -0,0 +1,417 @@
+ # Adapted from https://github.com/eduardzamfir/seemoredetails/blob/main/basicsr/archs/seemore_arch.py
+ from typing import Tuple, List
+ from torch import Tensor
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+
+
+ ######################
+ # Meta Architecture
+ ######################
+ class SeemoRe(nn.Module):
+     def __init__(self,
+                  scale: int = 4,
+                  in_chans: int = 3,
+                  num_experts: int = 6,
+                  num_layers: int = 6,
+                  embedding_dim: int = 64,
+                  img_range: float = 1.0,
+                  use_shuffle: bool = False,
+                  global_kernel_size: int = 11,
+                  recursive: int = 2,
+                  lr_space: int = 1,
+                  topk: int = 2):
+         super().__init__()
+         self.scale = scale
+         self.num_in_channels = in_chans
+         self.num_out_channels = in_chans
+         self.img_range = img_range
+
+         rgb_mean = (0.4488, 0.4371, 0.4040)
+         self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+
+         # -- SHALLOW FEATURES --
+         self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)
+
+         # -- DEEP FEATURES --
+         self.body = nn.ModuleList(
+             [ResGroup(in_ch=embedding_dim,
+                       num_experts=num_experts,
+                       use_shuffle=use_shuffle,
+                       topk=topk,
+                       lr_space=lr_space,
+                       recursive=recursive,
+                       global_kernel_size=global_kernel_size) for i in range(num_layers)]
+         )
+
+         # -- UPSCALE --
+         self.norm = LayerNorm(embedding_dim, data_format='channels_first')
+         self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
+         self.upsampler = nn.Sequential(
+             nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
+             nn.PixelShuffle(scale)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         self.mean = self.mean.type_as(x)
+         x = (x - self.mean) * self.img_range
+
+         # -- SHALLOW FEATURES --
+         x = self.conv_1(x)
+         res = x
+
+         # -- DEEP FEATURES --
+         for idx, layer in enumerate(self.body):
+             x = layer(x)
+
+         x = self.norm(x)
+
+         # -- HR IMAGE RECONSTRUCTION --
+         x = self.conv_2(x) + res
+         x = self.upsampler(x)
+
+         x = x / self.img_range + self.mean
+         return x
+
+
+
+ #############################
+ # Components
+ #############################
+ class ResGroup(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  global_kernel_size: int = 11,
+                  lr_space: int = 1,
+                  topk: int = 2,
+                  recursive: int = 2,
+                  use_shuffle: bool = False):
+         super().__init__()
+
+         self.local_block = RME(in_ch=in_ch,
+                                num_experts=num_experts,
+                                use_shuffle=use_shuffle,
+                                lr_space=lr_space,
+                                topk=topk,
+                                recursive=recursive)
+         self.global_block = SME(in_ch=in_ch,
+                                 kernel_size=global_kernel_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.local_block(x)
+         x = self.global_block(x)
+         return x
+
+
+
+ #############################
+ # Global Block
+ #############################
+ class SME(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int = 11):
+         super().__init__()
+
+         self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
+         self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)
+
+         self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
+         self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.block(self.norm_1(x)) + x
+         x = self.ffn(self.norm_2(x)) + x
+         return x
+
+
+
+
+ class StripedConvFormer(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int):
+         super().__init__()
+         self.in_ch = in_ch
+         self.kernel_size = kernel_size
+         self.padding = kernel_size // 2
+
+         self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+         self.to_qv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
+             nn.GELU(),
+         )
+
+         self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         q, v = self.to_qv(x).chunk(2, dim=1)
+         q = self.attn(q)
+         x = self.proj(q * v)
+         return x
+
+
+
+ #############################
+ # Local Blocks
+ #############################
+ class RME(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  topk: int,
+                  lr_space: int = 1,
+                  recursive: int = 2,
+                  use_shuffle: bool = False):
+         super().__init__()
+
+         self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
+         self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space)
+
+         self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
+         self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.block(self.norm_1(x)) + x
+         x = self.ffn(self.norm_2(x)) + x
+         return x
+
+
+
+ #################
+ # MoE Layer
+ #################
+ class MoEBlock(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  topk: int,
+                  use_shuffle: bool = False,
+                  lr_space: str = "linear",
+                  recursive: int = 2):
+         super().__init__()
+         self.use_shuffle = use_shuffle
+         self.recursive = recursive
+
+         self.conv_1 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
+             nn.GELU(),
+             nn.Conv2d(in_ch, 2*in_ch, kernel_size=1, padding=0)
+         )
+
+         self.agg_conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
+             nn.GELU())
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
+             nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+         )
+
+         self.conv_2 = nn.Sequential(
+             StripedConv2d(in_ch, kernel_size=3, depthwise=True),
+             nn.GELU())
+
+         if lr_space == "linear":
+             grow_func = lambda i: i + 2
+         elif lr_space == "exp":
+             grow_func = lambda i: 2**(i + 1)
+         elif lr_space == "double":
+             grow_func = lambda i: 2*i + 2
+         else:
+             raise NotImplementedError(f"lr_space {lr_space} not implemented")
+
+         self.moe_layer = MoELayer(
+             experts=[Expert(in_ch=in_ch, low_dim=grow_func(i)) for i in range(num_experts)],  # low_dim grows with expert index via grow_func
+             gate=Router(in_ch=in_ch, num_experts=num_experts),
+             num_expert=topk,
+         )
+
+         self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+
+     def calibrate(self, x: torch.Tensor) -> torch.Tensor:
+         b, c, h, w = x.shape
+         res = x
+
+         for _ in range(self.recursive):
+             x = self.agg_conv(x)
+         x = self.conv(x)
+         x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
+         return res + x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv_1(x)
+
+         if self.use_shuffle:
+             x = channel_shuffle(x, groups=2)
+         x, k = torch.chunk(x, chunks=2, dim=1)
+
+         x = self.conv_2(x)
+         k = self.calibrate(k)
+
+         x = self.moe_layer(x, k)
+         x = self.proj(x)
+         return x
+
+
+ class MoELayer(nn.Module):
+     def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
+         super().__init__()
+         assert len(experts) > 0
+         self.experts = nn.ModuleList(experts)
+         self.gate = gate
+         self.num_expert = num_expert
+
+     def forward(self, inputs: torch.Tensor, k: torch.Tensor):
+         out = self.gate(inputs)
+         weights = F.softmax(out, dim=1, dtype=torch.float).to(inputs.dtype)
+         topk_weights, topk_experts = torch.topk(weights, self.num_expert)
+         out = inputs.clone()
+
+         if self.training:
+             exp_weights = torch.zeros_like(weights)
+             exp_weights.scatter_(1, topk_experts, weights.gather(1, topk_experts))
+             for i, expert in enumerate(self.experts):
+                 out += expert(inputs, k) * exp_weights[:, i:i+1, None, None]
+         else:
+             selected_experts = [self.experts[i] for i in topk_experts.squeeze(dim=0)]
+             for i, expert in enumerate(selected_experts):
+                 out += expert(inputs, k) * topk_weights[:, i:i+1, None, None]
+
+         return out
+
+
+
+ class Expert(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  low_dim: int):
+         super().__init__()
+         self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
+         self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
+         self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)
+
+     def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+         x = self.conv_1(x)
+         x = self.conv_2(k) * x  # plain multiplicative gating, no sigmoid
+         x = self.conv_3(x)
+         return x
+
+
+ class Router(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int):
+         super().__init__()
+
+         self.body = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             Rearrange('b c 1 1 -> b c'),
+             nn.Linear(in_ch, num_experts, bias=False),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.body(x)
+
+
+
+ #################
+ # Utilities
+ #################
+ class StripedConv2d(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int,
+                  depthwise: bool = False):
+         super().__init__()
+         self.in_ch = in_ch
+         self.kernel_size = kernel_size
+         self.padding = kernel_size // 2
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=(1, self.kernel_size), padding=(0, self.padding), groups=in_ch if depthwise else 1),
+             nn.Conv2d(in_ch, in_ch, kernel_size=(self.kernel_size, 1), padding=(self.padding, 0), groups=in_ch if depthwise else 1),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.conv(x)
+
+
+
+ def channel_shuffle(x, groups=2):
+     bat_size, channels, w, h = x.shape
+     group_c = channels // groups
+     x = x.view(bat_size, groups, group_c, w, h)
+     x = torch.transpose(x, 1, 2).contiguous()
+     x = x.view(bat_size, -1, w, h)
+     return x
+
+
+ class GatedFFN(nn.Module):
+     def __init__(self,
+                  in_ch,
+                  mlp_ratio,
+                  kernel_size,
+                  act_layer):
+         super().__init__()
+         mlp_ch = in_ch * mlp_ratio
+
+         self.fn_1 = nn.Sequential(
+             nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
+             act_layer,
+         )
+         self.fn_2 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0),
+             act_layer,
+         )
+
+         self.gate = nn.Conv2d(mlp_ch // 2, mlp_ch // 2,
+                               kernel_size=kernel_size, padding=kernel_size // 2, groups=mlp_ch // 2)
+
+     def feat_decompose(self, x):  # NOTE: unused in forward(); self.sigma is never defined
+         s = x - self.gate(x)
+         x = x + self.sigma * s
+         return x
+
+     def forward(self, x: torch.Tensor):
+         x = self.fn_1(x)
+         x, gate = torch.chunk(x, 2, dim=1)
+
+         gate = self.gate(gate)
+         x = x * gate
+
+         x = self.fn_2(x)
+         return x
+
+
+
+ class LayerNorm(nn.Module):
+     r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
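
As a quick sanity check of the architecture (a sketch, not part of the commit), a randomly initialized SeemoRe-T built with the config above should enlarge each spatial dimension by `scale`:

    import torch
    from model.seemore import SeemoRe

    model = SeemoRe(scale=4, in_chans=3, num_experts=3, num_layers=6,
                    embedding_dim=36, img_range=1.0, use_shuffle=True,
                    global_kernel_size=11, recursive=2, lr_space="exp", topk=1)
    model.eval()  # eval mode routes through the top-k expert branch of MoELayer
    with torch.inference_mode():
        out = model(torch.rand(1, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])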
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ numpy
+ torchvision
+ einops
+ gradio
+ huggingface-hub
+ pillow
+ PyYAML
+ gradio_imageslider
synthesizer.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import yaml
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ from model.seemore import SeemoRe
+ from huggingface_hub import hf_hub_download
+
+
+
+ class SRSynthesizer(object):
+     repo_id = "eduardzamfir/SeemoRe-T"
+     checkpoint_name = "SeemoRe_T_X4.pth"
+     model_config_name = "eval_seemore_t_x4.yaml"
+
+     def __init__(self,
+                  device: str = None,
+                  create_dirs: bool = True):
+         self.module_dir = os.path.dirname(__file__)
+         self.device = self.initialize_device(device)
+         self.download_model_checkpoint(self.__class__.repo_id,
+                                        self.__class__.checkpoint_name)
+         self.model = self.instantiate_model(self.__class__.checkpoint_name,
+                                             self.__class__.model_config_name,
+                                             self.device)
+         if create_dirs: self.create_dirs(self.module_dir)
+
+     @torch.inference_mode()
+     def synthesize(self, image, show=True, save=True, return_input=False):
+         """Returns synthesized image for given image"""
+         if isinstance(image, str):
+             synthesized_image_name = image
+             image = self.read_image(self.module_dir, "low-res-images", image)
+         else:
+             synthesized_image_name = "synthesized_image.png"
+
+         synthesized_image = self.model(transforms.ToTensor()(image).to(self.device))
+         synthesized_image = transforms.Compose([lambda x: torch.clamp(x, 0, 1),
+                                                 transforms.ToPILImage()])(synthesized_image.squeeze().cpu())
+         if show:
+             image.show()
+             synthesized_image.show()
+         if save:
+             synthesized_image.save(os.path.join(self.module_dir,
+                                                 "synthesized-images",
+                                                 synthesized_image_name))
+         if return_input:
+             return image, synthesized_image
+         return synthesized_image
+
+     def instantiate_model(self, checkpoint_name, model_config_name, device):
+         """Returns instantiated model for given arguments"""
+         model = SeemoRe(**self.read_model_config_file(model_config_name)).to(device)
+         model.load_state_dict(self.load_checkpoint(checkpoint_name, device))
+         return model
+
+     def read_model_config_file(self, config_name):
+         """Returns parsed yaml file for given config name"""
+         root = self.module_dir
+         base_folder = "model"
+         with open(os.path.join(root, base_folder, config_name), "r") as file:
+             return yaml.safe_load(file)
+
+     def load_checkpoint(self, checkpoint_name, device):
+         """Loads the checkpoint from disk for given checkpoint name"""
+         root = self.module_dir
+         base_folder = "model"
+         checkpoint = torch.load(os.path.join(root, base_folder, checkpoint_name),
+                                 weights_only=True,
+                                 map_location=device)
+         return checkpoint["params"]
+
+     def download_model_checkpoint(self, repo_id, checkpoint_name, location=None):
+         """Downloads the model checkpoint from huggingface to given location"""
+         if location is None:
+             location = os.path.join(self.module_dir, "model")
+         hf_hub_download(repo_id=repo_id,
+                         filename=checkpoint_name,
+                         local_dir=location)
+
+     def initialize_device(self, device: str):
+         """Returns device based on GPU availability"""
+         if device is None:
+             if torch.cuda.is_available():
+                 device = "cuda"
+             elif torch.backends.mps.is_available():
+                 device = "mps"
+             else:
+                 device = "cpu"
+         return torch.device(device)
+
+     def read_image(self, root, base_folder, image_name):
+         """Returns opened image file for given image name"""
+         return Image.open(os.path.join(root, base_folder, image_name))
+
+     def create_dirs(self, root: str) -> None:
+         """Creates required directories during inference under root"""
+         dir_names = ["low-res-images", "synthesized-images"]
+         for dir_name in dir_names:
+             os.makedirs(os.path.join(root, dir_name), exist_ok=True)
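
Putting the pieces together, a hedged end-to-end usage sketch (run from the repository root; the constructor resolves the device, downloads the SeemoRe_T_X4.pth checkpoint from the eduardzamfir/SeemoRe-T Hub repo into model/, and creates the I/O directories):

    from synthesizer import SRSynthesizer

    synthesizer = SRSynthesizer()
    # a string argument is read from low-res-images/; with the default
    # save=True the result is written to synthesized-images/building.png
    upscaled = synthesizer.synthesize("building.png", show=False)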