roubaofeipi committed on
Commit
5231633
·
verified ·
1 Parent(s): 37f4313

Upload 100 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. core/__init__.py +372 -0
  3. core/data/__init__.py +69 -0
  4. core/data/bucketeer.py +88 -0
  5. core/data/bucketeer_deg.py +91 -0
  6. core/data/deg_kair_utils/test.bmp +0 -0
  7. core/data/deg_kair_utils/test.png +3 -0
  8. core/data/deg_kair_utils/utils_alignfaces.py +263 -0
  9. core/data/deg_kair_utils/utils_blindsr.py +631 -0
  10. core/data/deg_kair_utils/utils_bnorm.py +91 -0
  11. core/data/deg_kair_utils/utils_deblur.py +655 -0
  12. core/data/deg_kair_utils/utils_dist.py +201 -0
  13. core/data/deg_kair_utils/utils_googledownload.py +93 -0
  14. core/data/deg_kair_utils/utils_image.py +1016 -0
  15. core/data/deg_kair_utils/utils_lmdb.py +205 -0
  16. core/data/deg_kair_utils/utils_logger.py +66 -0
  17. core/data/deg_kair_utils/utils_mat.py +88 -0
  18. core/data/deg_kair_utils/utils_matconvnet.py +197 -0
  19. core/data/deg_kair_utils/utils_model.py +330 -0
  20. core/data/deg_kair_utils/utils_modelsummary.py +485 -0
  21. core/data/deg_kair_utils/utils_option.py +255 -0
  22. core/data/deg_kair_utils/utils_params.py +135 -0
  23. core/data/deg_kair_utils/utils_receptivefield.py +62 -0
  24. core/data/deg_kair_utils/utils_regularizers.py +104 -0
  25. core/data/deg_kair_utils/utils_sisr.py +848 -0
  26. core/data/deg_kair_utils/utils_video.py +493 -0
  27. core/data/deg_kair_utils/utils_videoio.py +555 -0
  28. core/scripts/__init__.py +0 -0
  29. core/scripts/cli.py +41 -0
  30. core/templates/__init__.py +1 -0
  31. core/templates/diffusion.py +236 -0
  32. core/utils/__init__.py +9 -0
  33. core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  34. core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  35. core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
  36. core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
  37. core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
  38. core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
  39. core/utils/base_dto.py +56 -0
  40. core/utils/save_and_load.py +59 -0
  41. figures/California_000490.jpg +3 -0
  42. figures/example_dataset/000008.jpg +3 -0
  43. figures/example_dataset/000008.json +2 -0
  44. figures/example_dataset/000012.jpg +3 -0
  45. figures/example_dataset/000012.json +1 -0
  46. figures/teaser.jpg +3 -0
  47. gdf/__init__.py +205 -0
  48. gdf/loss_weights.py +101 -0
  49. gdf/noise_conditions.py +102 -0
  50. gdf/readme.md +86 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ core/data/deg_kair_utils/test.png filter=lfs diff=lfs merge=lfs -text
37
+ figures/California_000490.jpg filter=lfs diff=lfs merge=lfs -text
38
+ figures/example_dataset/000008.jpg filter=lfs diff=lfs merge=lfs -text
39
+ figures/example_dataset/000012.jpg filter=lfs diff=lfs merge=lfs -text
40
+ figures/teaser.jpg filter=lfs diff=lfs merge=lfs -text
core/__init__.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from torch import nn
5
+ import wandb
6
+ import json
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from torch.utils.data import Dataset, DataLoader
10
+
11
+ from torch.distributed import init_process_group, destroy_process_group, barrier
12
+ from torch.distributed.fsdp import (
13
+ FullyShardedDataParallel as FSDP,
14
+ FullStateDictConfig,
15
+ MixedPrecision,
16
+ ShardingStrategy,
17
+ StateDictType
18
+ )
19
+
20
+ from .utils import Base, EXPECTED, EXPECTED_TRAIN
21
+ from .utils import create_folder_if_necessary, safe_save, load_or_fail
22
+
23
+ # pylint: disable=unused-argument
24
class WarpCore(ABC):
    """Abstract base class for a training job.

    The constructor loads the configuration and the persisted run info.
    Calling the instance then runs, in order: distributed setup, wandb setup,
    the ``setup_*`` hooks and finally ``train``.  Subclasses must implement
    ``setup_data``, ``setup_models``, ``setup_optimizers`` and ``train``; the
    remaining hooks have default implementations.

    The nested dataclasses are the DTOs exchanged between the hooks.
    """

    @dataclass(frozen=True)
    class Config(Base):
        # NOTE(review): fields defaulting to EXPECTED_TRAIN are presumably
        # mandatory when training - confirm against `Base` in core.utils.
        experiment_id: str = EXPECTED_TRAIN
        checkpoint_path: str = EXPECTED_TRAIN
        output_path: str = EXPECTED_TRAIN
        checkpoint_extension: str = "safetensors"
        dist_file_subfolder: str = ""
        allow_tf32: bool = True

        wandb_project: str = None
        wandb_entity: str = None

    @dataclass()  # not frozen, means that fields are mutable
    class Info():  # not inheriting from Base, because we don't want to enforce the default fields
        wandb_run_id: str = None
        total_steps: int = 0
        iter: int = 0

    @dataclass(frozen=True)
    class Data(Base):
        dataset: Dataset = EXPECTED
        dataloader: DataLoader = EXPECTED
        iterator: any = EXPECTED

    @dataclass(frozen=True)
    class Models(Base):
        pass

    @dataclass(frozen=True)
    class Optimizers(Base):
        pass

    @dataclass(frozen=True)
    class Schedulers(Base):
        pass

    @dataclass(frozen=True)
    class Extras(Base):
        pass
    # ---------------------------------------
    info: Info
    config: Config

    # FSDP stuff
    fsdp_defaults = {
        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
        "cpu_offload": None,
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        "limit_all_gathers": True,
    }
    fsdp_fullstate_save_policy = FullStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    )
    # ------------

    # OVERRIDEABLE METHODS

    # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
    def setup_extras_pre(self) -> Extras:
        return self.Extras()

    # setup dataset & dataloader, return a DTO containing the dataset, dataloader and/or iterator
    @abstractmethod
    def setup_data(self, extras: Extras) -> Data:
        raise NotImplementedError("This method needs to be overriden")

    # return a DTO with all models that are going to be used in the training
    @abstractmethod
    def setup_models(self, extras: Extras) -> Models:
        raise NotImplementedError("This method needs to be overriden")

    # return a DTO with all optimizers that are going to be used in the training
    @abstractmethod
    def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
        raise NotImplementedError("This method needs to be overriden")

    # [optionally] return a DTO with all schedulers that are going to be used in the training
    def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
        return self.Schedulers()

    # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
    def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
        return self.Extras.from_dict(extras.to_dict())

    # perform the training here
    @abstractmethod
    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        raise NotImplementedError("This method needs to be overriden")
    # ------------

    def setup_info(self, full_path=None) -> Info:
        """Load the persisted run info (or start from defaults).

        Args:
            full_path: Location of the info file; defaults to
                ``<checkpoint_path>/<experiment_id>/info.json``.

        Returns:
            An ``Info`` instance built from the loaded mapping, or a default
            one when ``load_or_fail`` returns a falsy value.
        """
        if full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json"
        info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
        info_dto = self.Info(**info_dict)
        if info_dto.total_steps > 0 and self.is_main_node:
            print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
        return info_dto

    def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
        """Build the ``Config`` DTO from a file, a dict, or defaults.

        Args:
            config_file_path: Path to a ``.yml``/``.yaml``/``.json`` file.  Takes
                precedence over ``config_dict`` when both are given.
            config_dict: Already-parsed configuration mapping.
            training: Forwarded to the config as the ``training`` entry.

        Raises:
            ValueError: If the file extension is not one of the supported ones.
        """
        if config_file_path is not None:
            if config_file_path.endswith((".yml", ".yaml")):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    loaded_config = yaml.safe_load(file)
            elif config_file_path.endswith(".json"):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    loaded_config = json.load(file)
            else:
                raise ValueError("Config file must be either a .yml|.yaml or .json file")
            # An empty file parses to None; treat it as an empty mapping
            # instead of failing on `{**None}`.
            return self.Config.from_dict({**(loaded_config or {}), 'training': training})
        if config_dict is not None:
            return self.Config.from_dict({**config_dict, 'training': training})
        return self.Config(training=training)

    @staticmethod
    def _slurm_env_int(name):
        """Read a required integer SLURM environment variable with a clear error."""
        value = os.environ.get(name)
        if value is None:
            raise RuntimeError(
                f"Environment variable '{name}' is not set. Multi-GPU mode expects to be "
                "launched through SLURM; pass single_gpu=True otherwise."
            )
        return int(value)

    def setup_ddp(self, experiment_id, single_gpu=False):
        """Initialise the NCCL process group from SLURM variables.

        When ``single_gpu`` is truthy nothing is initialised.  Otherwise the
        rank/device/world-size attributes set in ``__init__`` are overwritten
        and ``init_process_group`` is called with a file-based rendezvous.

        Raises:
            RuntimeError: If a required SLURM variable is missing.
        """
        if single_gpu:
            print("Running in single thread, DDP not enabled.")
            return

        local_rank = self._slurm_env_int("SLURM_LOCALID")
        process_id = self._slurm_env_int("SLURM_PROCID")
        world_size = self._slurm_env_int("SLURM_NNODES") * torch.cuda.device_count()

        self.process_id = process_id
        self.is_main_node = process_id == 0
        self.device = torch.device(local_rank)
        self.world_size = world_size

        # No separator is inserted after dist_file_subfolder, so a non-empty
        # value is expected to carry its own trailing "/".
        dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"

        torch.cuda.set_device(local_rank)
        init_process_group(
            backend="nccl",
            rank=process_id,
            world_size=world_size,
            init_method=f"file://{dist_file_path}",
        )
        print(f"[GPU {process_id}] READY")

    def setup_wandb(self):
        """Start (or resume) the wandb run on the main node, if a project is configured."""
        if not (self.is_main_node and self.config.wandb_project is not None):
            return

        self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
        wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())

        if self.info.total_steps > 0:
            wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
        else:
            wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")

    def _checkpoint_path(self, item_id, full_path, extension, id_name):
        """Resolve the checkpoint file for ``item_id`` unless ``full_path`` is given explicitly."""
        if full_path is not None:
            return full_path
        if item_id is None:
            raise ValueError(
                f"This method expects either '{id_name}' or 'full_path' to be defined"
            )
        return f"{self.config.checkpoint_path}/{self.config.experiment_id}/{item_id}.{extension}"

    # LOAD UTILITIES ----------
    def load_model(self, model, model_id=None, full_path=None, strict=True):
        """Load a state dict into ``model`` when a checkpoint is available.

        Args:
            model: Object exposing ``load_state_dict``.
            model_id: Checkpoint name inside the experiment folder.
            full_path: Explicit checkpoint path; overrides ``model_id``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``model`` (unchanged if ``load_or_fail`` returned ``None``).

        Raises:
            ValueError: If neither ``model_id`` nor ``full_path`` is given.
        """
        full_path = self._checkpoint_path(model_id, full_path, self.config.checkpoint_extension, "model_id")

        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            model.load_state_dict(checkpoint, strict=strict)
            del checkpoint

        return model

    def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Load optimizer state, scattering it across ranks for FSDP models.

        Loading is best-effort: any exception is printed and the optimizer is
        returned as-is, so a missing/incompatible state does not abort a run.

        Raises:
            ValueError: If neither ``optim_id`` nor ``full_path`` is given.
        """
        full_path = self._checkpoint_path(optim_id, full_path, "pt", "optim_id")

        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            try:
                if fsdp_model is not None:
                    # Only the main node passes the full state, unless nothing is sharded.
                    provides_full_state = (
                        self.is_main_node
                        or self.fsdp_defaults["sharding_strategy"] == ShardingStrategy.NO_SHARD
                    )
                    sharded_optimizer_state_dict = FSDP.scatter_full_optim_state_dict(  # <---- FSDP
                        checkpoint if provides_full_state else None,
                        fsdp_model,
                    )
                    optim.load_state_dict(sharded_optimizer_state_dict)
                    del checkpoint, sharded_optimizer_state_dict
                else:
                    optim.load_state_dict(checkpoint)
            # pylint: disable=broad-except
            except Exception as e:
                print("!!! Failed loading optimizer, skipping... Exception:", e)

        return optim

    # SAVE UTILITIES ----------
    def save_info(self, info, suffix=""):
        """Persist ``info`` as ``info<suffix>.json`` (main node only).

        Args:
            info: The ``Info`` object to save; falls back to ``self.info``
                when ``None`` is passed.
            suffix: Inserted between ``info`` and ``.json`` in the file name.
        """
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
        create_folder_if_necessary(full_path)
        if self.is_main_node:
            safe_save(vars(self.info if info is None else info), full_path)

    def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
        """Save the model's state dict (main node writes the file).

        For FSDP models the full state dict is gathered on every rank's call
        using ``fsdp_fullstate_save_policy`` before the main node saves it.

        Raises:
            ValueError: If neither ``model_id`` nor ``full_path`` is given.
        """
        full_path = self._checkpoint_path(model_id, full_path, self.config.checkpoint_extension, "model_id")
        create_folder_if_necessary(full_path)
        if is_fsdp:
            with FSDP.summon_full_params(model):
                pass
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
            ):
                checkpoint = model.state_dict()
            if self.is_main_node:
                safe_save(checkpoint, full_path)
            del checkpoint
        elif self.is_main_node:
            checkpoint = model.state_dict()
            safe_save(checkpoint, full_path)
            del checkpoint

    def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Save the optimizer state (main node writes the file).

        Raises:
            ValueError: If neither ``optim_id`` nor ``full_path`` is given.
        """
        full_path = self._checkpoint_path(optim_id, full_path, "pt", "optim_id")
        create_folder_if_necessary(full_path)
        if fsdp_model is not None:
            # Called on every rank; only the main node writes the result.
            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
            if self.is_main_node:
                safe_save(optim_statedict, full_path)
            del optim_statedict
        elif self.is_main_node:
            checkpoint = optim.state_dict()
            safe_save(checkpoint, full_path)
            del checkpoint
    # -----

    def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
        # Temporary setup, will be overriden by setup_ddp if required
        self.device = device
        self.process_id = 0
        self.is_main_node = True
        self.world_size = 1
        # ----

        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
        self.info: self.Info = self.setup_info()

    @staticmethod
    def _print_section(title, payload):
        """Print a titled YAML dump followed by a separator line and a blank line."""
        print(title)
        print(yaml.dump(payload, default_flow_style=False))
        print("------------------------------------")
        print()

    @staticmethod
    def _type_names(dto):
        """Map each DTO field to the class name of its value."""
        return {k: type(v).__name__ for k, v in dto.to_dict().items()}

    @staticmethod
    def _describe_model(model):
        """One-line summary of a model entry, with trainable parameter count for nn.Modules."""
        if isinstance(model, nn.Module):
            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            detail = f"trainable params {trainable}"
        else:
            detail = "Not a nn.Module"
        return f"{type(model).__name__} - {detail}"

    def __call__(self, single_gpu=False):
        """Run the whole job: distributed/wandb setup, the setup hooks, then ``train``.

        Args:
            single_gpu: When truthy, no process group is created or destroyed.
        """
        self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu)  # this will change the device to the CUDA rank
        self.setup_wandb()
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        if self.is_main_node:
            print()
            self._print_section("**STARTIG JOB WITH CONFIG:**", self.config.to_dict())
            self._print_section("**INFO:**", vars(self.info))

        # SETUP STUFF
        extras = self.setup_extras_pre()
        assert extras is not None, "setup_extras_pre() must return a DTO"

        data = self.setup_data(extras)
        assert data is not None, "setup_data() must return a DTO"
        if self.is_main_node:
            self._print_section("**DATA:**", self._type_names(data))

        models = self.setup_models(extras)
        assert models is not None, "setup_models() must return a DTO"
        if self.is_main_node:
            self._print_section(
                "**MODELS:**",
                {k: self._describe_model(v) for k, v in models.to_dict().items()},
            )

        optimizers = self.setup_optimizers(extras, models)
        assert optimizers is not None, "setup_optimizers() must return a DTO"
        if self.is_main_node:
            self._print_section("**OPTIMIZERS:**", self._type_names(optimizers))

        schedulers = self.setup_schedulers(extras, models, optimizers)
        assert schedulers is not None, "setup_schedulers() must return a DTO"
        if self.is_main_node:
            self._print_section("**SCHEDULERS:**", self._type_names(schedulers))

        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
        assert post_extras is not None, "setup_extras_post() must return a DTO"
        # Entries returned by the post hook override the pre hook's entries.
        extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
        if self.is_main_node:
            self._print_section("**EXTRAS:**", {k: f"{v}" for k, v in extras.to_dict().items()})
        # -------

        # TRAIN
        if self.is_main_node:
            print("**TRAINING STARTING...**")
        self.train(data, extras, models, optimizers, schedulers)

        # Same truthiness test as setup_ddp, so a process group that was
        # created is always torn down.
        if not single_gpu:
            barrier()
            destroy_process_group()
        if self.is_main_node:
            print()
            print("------------------------------------")
            print()
            print("**TRAINING COMPLETE**")
            if self.config.wandb_project is not None:
                wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
core/data/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import subprocess
3
+ import yaml
4
+ import os
5
+ from .bucketeer import Bucketeer
6
+
7
class MultiFilter():
    """Callable predicate that validates a sample's JSON metadata against a set of rules.

    Args:
        rules: Mapping of key -> callable.  A plain key passes ``x_json[key]``
            to the callable; a tuple key passes one positional argument per
            listed field.
        default: Verdict returned when a sample cannot be evaluated (missing
            ``'json'`` entry, missing field, rule raising, ...).
    """

    def __init__(self, rules, default=False):
        self.rules = rules
        self.default = default

    def __call__(self, x):
        """Return True when every rule accepts the sample, ``self.default`` on any failure."""
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            # Evaluate every rule (no short-circuit) so a failing rule is
            # always noticed regardless of rule order.
            validations = [
                rule(*[x_json[field] for field in key]) if isinstance(key, tuple) else rule(x_json[key])
                for key, rule in self.rules.items()
            ]
            return all(validations)
        except Exception:  # best-effort filter: an unreadable sample gets the default verdict
            return self.default
27
+
28
class MultiGetter():
    """Callable that extracts values from JSON metadata according to a set of rules.

    ``rules`` maps a key (or a tuple of keys) to a callable; a tuple key feeds
    one positional argument per listed field.
    """

    def __init__(self, rules):
        self.rules = rules

    def __call__(self, x_json):
        """Apply every rule; return the single result directly, otherwise the list of results."""
        metadata = json.loads(x_json) if isinstance(x_json, bytes) else x_json

        def apply(key, rule):
            if isinstance(key, tuple):
                return rule(*[metadata[field] for field in key])
            return rule(metadata[key])

        results = [apply(key, rule) for key, rule in self.rules.items()]
        return results[0] if len(results) == 1 else results
45
+
46
def setup_webdataset_path(paths, cache_path=None):
    """Build the ``pipe:aws s3 cp ...`` URL for a set of tar shards.

    Args:
        paths: One path or a list of paths.  Entries ending in ``.tar`` are
            used as-is; any other entry is listed recursively with
            ``aws s3 ls`` and every ``.tar`` file found is added.
        cache_path: Optional YAML file.  If it exists the shard list is read
            from it; otherwise the resolved list is written to it.  With
            ``None`` no cache is read or written.

    Returns:
        A string of the form ``pipe:aws s3 cp { <comma-separated shards> } -``.

    Raises:
        subprocess.CalledProcessError: If the listing pipeline exits non-zero.
    """
    import shlex  # local import: only needed to quote the path for the shell pipeline

    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as file:
            # An empty cache file parses to None; treat it as "no shards".
            tar_paths = yaml.safe_load(file) or []
    else:
        if isinstance(paths, str):
            paths = [paths]
        tar_paths = []
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            # "s3://bucket/prefix" -> "s3://bucket"; listed keys are relative to the bucket.
            bucket = "/".join(path.split("/")[:3])
            # The path is quoted so spaces or shell metacharacters cannot alter the command.
            command = f"aws s3 ls {shlex.quote(path)} --recursive | awk '{{print $4}}'"
            result = subprocess.run(command, stdout=subprocess.PIPE, shell=True, check=True)
            listed = result.stdout.decode('utf-8').split()
            tar_paths += [f"{bucket}/{f}" for f in listed if f.endswith(".tar")]

        # Previously this wrote unconditionally and crashed on the default cache_path=None.
        if cache_path is not None:
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)

    tar_paths_str = ",".join(f"{p}" for p in tar_paths)
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
core/data/bucketeer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ from torchtools.transforms import SmartCrop
5
+ import math
6
+
7
class Bucketeer():
    """Iterator that groups images into aspect-ratio buckets and yields same-size batches.

    Every incoming sample is resized and cropped to the bucket size whose
    aspect ratio (and, with several densities, pixel count) is closest to the
    image, then queued; a batch is emitted once a bucket holds ``batch_size``
    samples.

    Args:
        dataloader: Iterable with a ``batch_size`` attribute whose items are
            iterables of dicts containing an ``'images'`` tensor.
        density: Target pixel count, or an iterable of pixel counts.
        factor: Bucket sides are rounded down to a multiple of this value.
        ratios: Width/height ratios to bucket by.  The list is copied, never
            modified.
        reverse_list: Also add the reciprocal of every ratio.
        randomize_p, randomize_q: Forwarded to ``SmartCrop`` (``'smart'`` mode only).
        crop_mode: ``'center'``, ``'random'`` or ``'smart'``.
        p_random_ratio: Stored on the instance; not used by this class.
        interpolate_nearest: Resize with NEAREST instead of antialiased BILINEAR.
    """

    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode

        # Work on a copy so neither the caller's list nor the shared default is mutated.
        self.ratios = list(ratios)
        if reverse_list:
            for r in list(self.ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)

        # Accept both a single density (the documented default) and several densities.
        try:
            densities = list(density)
        except TypeError:
            densities = [density]
        if not densities:
            raise ValueError("density must contain at least one value")

        # Sizes are (height, width): for ratio r = w/h and area dd, h = sqrt(dd/r), w = sqrt(dd*r).
        self.sizes = {
            dd: [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]
            for dd in densities
        }

        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {size: [] for sizes in self.sizes.values() for size in sizes}
        # output_size is overwritten before each smart crop in __next__; the
        # largest density is only used as the initial value here.
        self.smartcrop = SmartCrop(int(max(densities)**0.5), randomize_p, randomize_q) if self.crop_mode == 'smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def __iter__(self):
        """Return self so the bucketeer can be used directly in a ``for`` loop."""
        return self

    def get_available_batch(self):
        """Pop and return ``batch_size`` samples from the first full bucket, or None."""
        for size, items in self.buckets.items():
            if len(items) >= self.batch_size:
                self.buckets[size] = items[self.batch_size:]
                return items[:self.batch_size]
        return None

    def get_closest_size(self, x):
        """Return the (height, width) bucket closest to tensor ``x`` in ratio, then in area."""
        w, h = x.size(-1), x.size(-2)
        best_size_idx = int(np.argmin([abs(w/h - r) for r in self.ratios]))
        candidates = [sizes[best_size_idx] for sizes in self.sizes.values()]
        # min() keeps the first candidate on ties, like the strict '<' scan it replaces.
        return min(candidates, key=lambda s: abs(w*h - s[0]*s[1]))

    def get_resize_size(self, orig_size, tgt_size):
        """Return the shorter-edge length to resize to so that ``tgt_size`` can be cropped out.

        Both arguments are (height, width) pairs.
        """
        same_orientation = (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0
        if same_orientation:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            return max(alt_min, min(tgt_size))
        alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
        return max(alt_max, max(tgt_size))

    def __next__(self):
        """Return the next collated batch; StopIteration propagates from the wrapped dataloader."""
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)

                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)

                self.buckets[size].append({**{'images': img}, **{k: dct[k] for k in dct if k != 'images'}})
            batch = self.get_available_batch()

        # Collate: tensors are stacked along a new batch dimension, everything else stays a list.
        out = {k: [item[k] for item in batch] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
core/data/bucketeer_deg.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ from torchtools.transforms import SmartCrop
5
+ import math
6
+
7
class Bucketeer():
    """Iterator that groups images into aspect-ratio buckets and yields same-size batches.

    Every incoming sample is resized and cropped to the bucket size whose
    aspect ratio (and, with several densities, pixel count) is closest to the
    image, then queued; a batch is emitted once a bucket holds ``batch_size``
    samples.

    Args:
        dataloader: Iterable with a ``batch_size`` attribute whose items are
            iterables of dicts containing an ``'images'`` tensor.
        density: Target pixel count, or an iterable of pixel counts.
        factor: Bucket sides are rounded down to a multiple of this value.
        ratios: Width/height ratios to bucket by.  The list is copied, never
            modified.
        reverse_list: Also add the reciprocal of every ratio.
        randomize_p, randomize_q: Forwarded to ``SmartCrop`` (``'smart'`` mode only).
        crop_mode: ``'center'``, ``'random'`` or ``'smart'``.
        p_random_ratio: Stored on the instance; not used by this class.
        interpolate_nearest: Resize with NEAREST instead of antialiased BILINEAR.
    """

    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode

        # Work on a copy so neither the caller's list nor the shared default is mutated.
        self.ratios = list(ratios)
        if reverse_list:
            for r in list(self.ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)

        # Accept both a single density (the documented default) and several densities.
        try:
            densities = list(density)
        except TypeError:
            densities = [density]
        if not densities:
            raise ValueError("density must contain at least one value")

        # Sizes are (height, width): for ratio r = w/h and area dd, h = sqrt(dd/r), w = sqrt(dd*r).
        self.sizes = {
            dd: [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]
            for dd in densities
        }

        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {size: [] for sizes in self.sizes.values() for size in sizes}
        # output_size is overwritten before each smart crop in __next__; the
        # largest density is only used as the initial value here.
        self.smartcrop = SmartCrop(int(max(densities)**0.5), randomize_p, randomize_q) if self.crop_mode == 'smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def __iter__(self):
        """Return self so the bucketeer can be used directly in a ``for`` loop."""
        return self

    def get_available_batch(self):
        """Pop and return ``batch_size`` samples from the first full bucket, or None."""
        for size, items in self.buckets.items():
            if len(items) >= self.batch_size:
                self.buckets[size] = items[self.batch_size:]
                return items[:self.batch_size]
        return None

    def get_closest_size(self, x):
        """Return the (height, width) bucket closest to tensor ``x`` in ratio, then in area."""
        w, h = x.size(-1), x.size(-2)
        best_size_idx = int(np.argmin([abs(w/h - r) for r in self.ratios]))
        candidates = [sizes[best_size_idx] for sizes in self.sizes.values()]
        # min() keeps the first candidate on ties, like the strict '<' scan it replaces.
        return min(candidates, key=lambda s: abs(w*h - s[0]*s[1]))

    def get_resize_size(self, orig_size, tgt_size):
        """Return the shorter-edge length to resize to so that ``tgt_size`` can be cropped out.

        Both arguments are (height, width) pairs.
        """
        same_orientation = (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0
        if same_orientation:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            return max(alt_min, min(tgt_size))
        alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
        return max(alt_max, max(tgt_size))

    def __next__(self):
        """Return the next collated batch; StopIteration propagates from the wrapped dataloader."""
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)

                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)

                self.buckets[size].append({**{'images': img}, **{k: dct[k] for k in dct if k != 'images'}})
            batch = self.get_available_batch()

        # Collate: tensors are stacked along a new batch dimension, everything else stays a list.
        out = {k: [item[k] for item in batch] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
core/data/deg_kair_utils/test.bmp ADDED
core/data/deg_kair_utils/test.png ADDED

Git LFS Details

  • SHA256: f36b49d23d9bb206f8c3ec537b57bd44ed008f3f632706ac02eba73e175fa3d0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
core/data/deg_kair_utils/utils_alignfaces.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Apr 24 15:43:29 2017
4
+ @author: zhaoy
5
+ """
6
+ import cv2
7
+ import numpy as np
8
+ from skimage import transform as trans
9
+
10
+ # reference facial points, a list of coordinates (x,y)
11
+ REFERENCE_FACIAL_POINTS = [
12
+ [30.29459953, 51.69630051],
13
+ [65.53179932, 51.50139999],
14
+ [48.02519989, 71.73660278],
15
+ [33.54930115, 92.3655014],
16
+ [62.72990036, 92.20410156]
17
+ ]
18
+
19
+ DEFAULT_CROP_SIZE = (96, 112)
20
+
21
+
22
+ def _umeyama(src, dst, estimate_scale=True, scale=1.0):
23
+ """Estimate N-D similarity transformation with or without scaling.
24
+ Parameters
25
+ ----------
26
+ src : (M, N) array
27
+ Source coordinates.
28
+ dst : (M, N) array
29
+ Destination coordinates.
30
+ estimate_scale : bool
31
+ Whether to estimate scaling factor.
32
+ Returns
33
+ -------
34
+ T : (N + 1, N + 1)
35
+ The homogeneous similarity transformation matrix. The matrix contains
36
+ NaN values only if the problem is not well-conditioned.
37
+ References
38
+ ----------
39
+ .. [1] "Least-squares estimation of transformation parameters between two
40
+ point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
41
+ """
42
+
43
+ num = src.shape[0]
44
+ dim = src.shape[1]
45
+
46
+ # Compute mean of src and dst.
47
+ src_mean = src.mean(axis=0)
48
+ dst_mean = dst.mean(axis=0)
49
+
50
+ # Subtract mean from src and dst.
51
+ src_demean = src - src_mean
52
+ dst_demean = dst - dst_mean
53
+
54
+ # Eq. (38).
55
+ A = dst_demean.T @ src_demean / num
56
+
57
+ # Eq. (39).
58
+ d = np.ones((dim,), dtype=np.double)
59
+ if np.linalg.det(A) < 0:
60
+ d[dim - 1] = -1
61
+
62
+ T = np.eye(dim + 1, dtype=np.double)
63
+
64
+ U, S, V = np.linalg.svd(A)
65
+
66
+ # Eq. (40) and (43).
67
+ rank = np.linalg.matrix_rank(A)
68
+ if rank == 0:
69
+ return np.nan * T
70
+ elif rank == dim - 1:
71
+ if np.linalg.det(U) * np.linalg.det(V) > 0:
72
+ T[:dim, :dim] = U @ V
73
+ else:
74
+ s = d[dim - 1]
75
+ d[dim - 1] = -1
76
+ T[:dim, :dim] = U @ np.diag(d) @ V
77
+ d[dim - 1] = s
78
+ else:
79
+ T[:dim, :dim] = U @ np.diag(d) @ V
80
+
81
+ if estimate_scale:
82
+ # Eq. (41) and (42).
83
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
84
+ else:
85
+ scale = scale
86
+
87
+ T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
88
+ T[:dim, :dim] *= scale
89
+
90
+ return T, scale
91
+
92
+
93
class FaceWarpException(Exception):
    """Error raised by the face-warping helpers for invalid arguments.

    The string form is the original message prefixed with this module's path.
    """

    def __str__(self):
        # ``super()`` (called) is required here; the bare ``super`` type would
        # format the exception's repr instead of its message.
        return 'In File {}:{}'.format(
            __file__, super().__str__())
97
+
98
+
99
def get_reference_facial_points(output_size=None,
                                inner_padding_factor=0.0,
                                outer_padding=(0, 0),
                                default_square=False):
    """Return the five reference facial points adapted to an output crop.

    Starts from ``REFERENCE_FACIAL_POINTS`` / ``DEFAULT_CROP_SIZE`` and then
    optionally squares the crop, pads the inner region, rescales to the
    requested size and adds the outer padding.

    Args:
        output_size: (w, h) of the target crop, or None to deduce it from the
            paddings.
        inner_padding_factor: value in [0, 1]; the crop grows by
            ``2 * factor * crop_size`` in each dimension.
        outer_padding: (w_pad, h_pad) added around the rescaled inner region.
        default_square: if True, first pad the default crop to a square.

    Returns:
        (5, 2) numpy array of reference point coordinates.

    Raises:
        FaceWarpException: if the arguments are inconsistent (see messages).
    """
    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    # ``is not None`` instead of truthiness so array-like sizes are accepted.
    if (output_size is not None and
            output_size[0] == tmp_crop_size[0] and
            output_size[1] == tmp_crop_size[1]):
        print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
        return tmp_5pts

    # tuple(...) lets lists/arrays compare equal to (0, 0) as well.
    if (inner_padding_factor == 0 and
            tuple(outer_padding) == (0, 0)):
        if output_size is None:
            print('No paddings to do: return default reference points')
            return tmp_5pts
        else:
            raise FaceWarpException(
                'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
            and output_size is None):
        # Cast the scaled crop size (an array), not the scalar factor: a plain
        # Python float has no ``astype`` and the old expression always raised.
        output_size = (tmp_crop_size *
                       (1 + inner_padding_factor * 2)).astype(np.int32)
        # NOTE(review): outer_padding is added once here but subtracted twice
        # below; with a non-zero outer_padding the aspect check may then fail.
        # Left unchanged pending confirmation of the intended behaviour.
        output_size += np.array(outer_padding)
        print(' deduced from paddings, output_size = ', output_size)

    if not (outer_padding[0] < output_size[0]
            and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
                                'and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    # The padded crop and the region inside the outer padding must share the
    # same aspect ratio, otherwise a single scale factor cannot map them.
    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding)'
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)

    return reference_5point
182
+
183
+
184
def get_affine_transform_matrix(src_pts, dst_pts):
    """Least-squares 2x3 affine matrix mapping ``src_pts`` onto ``dst_pts``.

    Args:
        src_pts: (K, 2) numpy array of source points.
        dst_pts: (K, 2) numpy array of destination points.

    Returns:
        (2, 3) float32 matrix. If the homogeneous source points have rank 3
        the full affine solution is returned; for rank 2 the translation
        column is zero; otherwise the identity transform is returned.
    """
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_h = np.hstack([src_pts, ones])
    dst_h = np.hstack([dst_pts, ones])

    # rcond=None selects NumPy's documented machine-precision cutoff and
    # avoids the FutureWarning older NumPy emits when rcond is omitted.
    A, _residuals, rank, _singular = np.linalg.lstsq(src_h, dst_h, rcond=None)

    if rank == 3:
        return np.float32([
            [A[0, 0], A[1, 0], A[2, 0]],
            [A[0, 1], A[1, 1], A[2, 1]]
        ])
    if rank == 2:
        return np.float32([
            [A[0, 0], A[1, 0], 0],
            [A[0, 1], A[1, 1], 0]
        ])
    # Degenerate point set: fall back to the identity transform.
    return np.float32([[1, 0, 0], [0, 1, 0]])
205
+
206
+
207
def warp_and_crop_face(src_img,
                       facial_pts,
                       reference_pts=None,
                       crop_size=(96, 112),
                       align_type='smilarity'):  # smilarity cv2_affine affine
    """Warp ``src_img`` so that ``facial_pts`` land on ``reference_pts``.

    Args:
        src_img: image passed to ``cv2.warpAffine``.
        facial_pts: (K, 2) or (2, K) landmark coordinates in ``src_img``.
        reference_pts: (K, 2) or (2, K) target coordinates; when None they are
            derived from ``crop_size``.
        crop_size: (w, h) of the output image.
        align_type: 'cv2_affine' (first three points, exact affine),
            'affine' (least-squares affine) or anything else for the
            similarity transform estimated by ``_umeyama``.

    Returns:
        (face_img, tfm_inv): the warped crop and the 2x3 matrix mapping crop
        coordinates back to the source image.

    Raises:
        FaceWarpException: if the point arrays have invalid or mismatched
            shapes.
    """
    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            reference_pts = get_reference_facial_points(crop_size,
                                                        0,
                                                        (0, 0),
                                                        False)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    # Compare strings by value: ``is`` only worked by accident of interning
    # and raises a SyntaxWarning on current Python versions.
    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        params, scale = _umeyama(src_pts, ref_pts)
        tfm = params[:2, :]

        # The inverse reuses the reciprocal of the forward scale instead of
        # re-estimating it.
        params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
        tfm_inv = params[:2, :]

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
core/data/deg_kair_utils/utils_blindsr.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from core.data.deg_kair_utils import utils_image as util
7
+
8
+ import random
9
+ from scipy import ndimage
10
+ import scipy
11
+ import scipy.stats as ss
12
+ from scipy.interpolate import interp2d
13
+ from scipy.linalg import orth
14
+
15
+
16
+
17
+
18
+ """
19
+ # --------------------------------------------
20
+ # Super-Resolution
21
+ # --------------------------------------------
22
+ #
23
+ # Kai Zhang ([email protected])
24
+ # https://github.com/cszn
25
+ # From 2019/03--2021/08
26
+ # --------------------------------------------
27
+ """
28
+
29
def modcrop_np(img, sf):
    '''Crop a copy of ``img`` so its first two dimensions are multiples of ``sf``.

    Args:
        img: numpy image with two or three dimensions
        sf: scale factor

    Return:
        cropped image (a slice of a copy; the input is not modified)
    '''
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = np.copy(img)
    return duplicate[:keep0, :keep1, ...]
41
+
42
+
43
+ """
44
+ # --------------------------------------------
45
+ # anisotropic Gaussian kernels
46
+ # --------------------------------------------
47
+ """
48
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    side = k.shape[0]
    # The big kernel has side 3 * side - 2.
    big_side = 3 * side - 2
    big_k = np.zeros((big_side, big_side))
    # Accumulate one scaled copy of k per entry of k, offset by twice the
    # entry's position (row-major order).
    for row, col in np.ndindex(side, side):
        big_k[2 * row:2 * row + side, 2 * col:2 * col + side] += k[row, col] * k
    # Trim the border, where the values are very small, to keep the kernel compact.
    margin = side // 2
    trimmed = big_k[margin:-margin, margin:-margin]
    # Normalize to 1
    return trimmed / trimmed.sum()
62
+
63
+
64
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.

    Returns:
        k     : kernel
    """
    # Rotate the unit x-axis by theta to obtain the principal direction.
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    v = np.dot(rotation, np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    # Covariance = V D V^-1.
    VD = np.dot(V, D)
    Sigma = np.dot(VD, np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
84
+
85
+
86
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian pdf on a ``size`` x ``size`` grid and normalise it.

    Args:
        mean: mean of the Gaussian, passed to ``multivariate_normal.pdf``
        cov: 2x2 covariance matrix
        size: side length of the kernel

    Returns:
        (size, size) kernel whose entries sum to 1
    """
    offset = size / 2.0 + 0.5
    kernel = np.zeros([size, size])
    for row, col in np.ndindex(size, size):
        # The point is given as (x, y), i.e. (column, row) relative to the centre.
        point = [col - offset + 1, row - offset + 1]
        kernel[row, col] = ss.multivariate_normal.pdf(point, mean=mean, cov=cov)

    return kernel / np.sum(kernel)
97
+
98
+
99
def _lerp_axis(arr, pos, axis):
    """Linearly interpolate ``arr`` at fractional indices ``pos`` along ``axis``.

    ``pos`` must already lie inside ``[0, arr.shape[axis] - 1]``.
    """
    lo = np.floor(pos).astype(np.intp)
    hi = np.minimum(lo + 1, arr.shape[axis] - 1)
    frac_shape = [1] * arr.ndim
    frac_shape[axis] = -1
    frac = (pos - lo).reshape(frac_shape)
    return np.take(arr, lo, axis=axis) * (1.0 - frac) + np.take(arr, hi, axis=axis) * frac


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The image is resampled bilinearly at coordinates offset by
    ``(sf - 1) / 2`` pixels; coordinates are clamped to the image border.
    The resampling is done with NumPy because ``scipy.interpolate.interp2d``
    (used previously) has been removed from SciPy.

    Args:
        x: 2-D or 3-D (channels last) numpy array
        sf: scale factor
        upper_left: shift direction; True samples at ``index + shift``,
            False at ``index - shift``

    Returns:
        For 2-D input a new float64 array. For 3-D input the array ``x``
        itself, overwritten in place (this matches the previous behaviour).
        Arrays of any other rank are returned unchanged.
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    if not upper_left:
        shift = -shift

    # Sampling positions, clamped so no extrapolation is needed.
    cols = np.clip(np.arange(0, w, 1.0) + shift, 0, w - 1)
    rows = np.clip(np.arange(0, h, 1.0) + shift, 0, h - 1)

    if x.ndim not in (2, 3):
        return x

    # Bilinear interpolation is separable: rows first, then columns.
    resampled = _lerp_axis(np.asarray(x, dtype=np.float64), rows, axis=0)
    resampled = _lerp_axis(resampled, cols, axis=1)

    if x.ndim == 2:
        return resampled
    x[...] = resampled
    return x
126
+
127
+
128
def blur(x, k):
    '''Filter every image of a batch with its own kernel (same-size output).

    x: image, NxcxHxW
    k: kernel, Nx1xhxw

    Each kernel is applied to all ``c`` channels of its image through a
    grouped ``conv2d`` (cross-correlation); borders use replicate padding.

    Returns:
        tensor of shape NxcxHxW
    '''
    n, c = x.shape[:2]
    kh, kw = k.shape[-2], k.shape[-1]
    # F.pad takes (left, right, top, bottom): the width padding comes from the
    # kernel width and the height padding from the kernel height. The total
    # padding is kernel_size - 1 per axis so the output keeps HxW for any
    # kernel shape (identical to the old padding for odd square kernels).
    top, left = (kh - 1) // 2, (kw - 1) // 2
    bottom, right = kh - 1 - top, kw - 1 - left
    x = torch.nn.functional.pad(x, pad=(left, right, top, bottom), mode='replicate')
    # One kernel per (image, channel) pair, laid out as conv groups.
    k = k.repeat(1, c, 1, 1)
    k = k.reshape(-1, 1, kh, kw)
    x = x.reshape(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
    x = x.reshape(n, c, x.shape[2], x.shape[3])

    return x
143
+
144
+
145
+
146
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Generate a random anisotropic Gaussian kernel.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Returns:
        kernel of shape ``k_size`` whose entries sum to 1
    """
    # Random eigenvalues and rotation angle of the covariance matrix; the
    # draws happen in the order lambda_1, lambda_2, theta, noise.
    var_span = max_var - min_var
    lambda_1 = min_var + np.random.rand() * var_span
    lambda_2 = min_var + np.random.rand() * var_span
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigenvalues and rotation, then its inverse.
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    sigma = rotation @ np.diag([lambda_1, lambda_2]) @ rotation.T
    inv_sigma = np.linalg.inv(sigma)[None, None, :, :]

    # Expected position (shifting kernel for aligned image).
    centre = k_size // 2 - 0.5*(scale_factor - 1)
    centre = centre[None, None, :, None]

    # Grid of pixel coordinates as column vectors.
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Gaussian value for every pixel, modulated by the multiplicative noise.
    offsets = coords - centre
    offsets_t = offsets.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(offsets_t @ inv_sigma @ offsets)) * (1 + noise)

    # Normalize the kernel and return
    return raw_kernel / np.sum(raw_kernel)
186
+
187
+
188
def fspecial_gaussian(hsize, sigma):
    """Return a normalised ``hsize`` x ``hsize`` isotropic Gaussian kernel.

    Args:
        hsize: side length of the square kernel
        sigma: standard deviation of the Gaussian

    Returns:
        (hsize, hsize) array; entries below ``eps * max`` are zeroed and the
        kernel is divided by its sum when that sum is non-zero.
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
    arg = -(x*x + y*y)/(2*std*std)
    h = np.exp(arg)
    # np.finfo replaces scipy.finfo, which SciPy no longer exposes.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h/sumh
    return h
200
+
201
+
202
def fspecial_laplacian(alpha):
    """Return the 3x3 Laplacian kernel for shape parameter ``alpha``.

    ``alpha`` is clamped to [0, 1] before use.
    """
    alpha = max(0, min(alpha, 1))
    denom = alpha + 1
    corner = alpha/denom
    edge = (1-alpha)/denom
    centre = -4/denom
    return np.array([
        [corner, edge, corner],
        [edge, centre, edge],
        [corner, edge, corner],
    ])
209
+
210
+
211
def fspecial(filter_type, *args, **kwargs):
    '''Build a 'gaussian' or 'laplacian' filter kernel.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Remaining arguments are forwarded to the matching builder. Any other
    ``filter_type`` yields None.
    '''
    if filter_type == 'gaussian':
        builder = fspecial_gaussian
    elif filter_type == 'laplacian':
        builder = fspecial_laplacian
    else:
        return None
    return builder(*args, **kwargs)
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
def bicubic_degradation(x, sf=3):
    '''Downscale ``x`` by ``sf`` with ``util.imresize_np``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor

    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1/sf)
239
+
240
+
241
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve is the same function as the deprecated
    # ndimage.filters.convolve; the kernel gets a singleton channel axis so
    # every channel is filtered independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
264
+
265
+
266
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve is the same function as the deprecated
    # ndimage.filters.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
290
+
291
+
292
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image
    '''
    # ndimage.convolve is the same function as the deprecated
    # ndimage.filters.convolve; borders wrap around.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # Keep every sf-th pixel starting at the top-left corner.
    st = 0
    return x[st::sf, st::sf, ...]
307
+
308
+
309
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur (made odd if even). Default: 50.
        threshold (int): residual level (on a 0-255 scale) above which a pixel is sharpened. Default: 10.
    """
    # The Gaussian kernel size is bumped to the next odd number when even.
    if radius % 2 == 0:
        radius += 1
    ksize = (radius, radius)

    blurred = cv2.GaussianBlur(img, ksize, 0)
    residual = img - blurred

    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, ksize, 0)

    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
333
+
334
+
335
def add_blur(img, sf=4):
    """Blur ``img`` with a randomly drawn Gaussian kernel.

    With probability 0.5 an anisotropic Gaussian is used, otherwise an
    isotropic one; the kernel side is ``2 * randint(2, 11) + 3``.

    Args:
        img: HxWxC image
        sf: scale factor; larger values allow wider kernels

    Returns:
        blurred image (borders are mirrored)
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2*sf
    if random.random() < 0.5:
        l1 = wd2*random.random()
        l2 = wd2*random.random()
        k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
    # ndimage.convolve is the same function as the deprecated
    # ndimage.filters.convolve.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
347
+
348
+
349
def add_resize(img, sf=4):
    """Randomly rescale ``img`` and clip the result to [0, 1].

    A single uniform draw picks the mode: above 0.8 the image is enlarged by
    a factor in [1, 2], below 0.7 it is shrunk by a factor in [0.5/sf, 1],
    otherwise the size is kept. The cv2 interpolation flag is chosen at
    random from 1, 2 and 3.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5/sf, 1)
    else:
        factor = 1.0
    new_size = (int(factor*img.shape[1]), int(factor*img.shape[0]))
    resized = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
361
+
362
+
363
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Return ``img`` with random Gaussian noise added, clipped to [0, 1].

    One of three noise types is drawn: independent per-channel noise, noise
    shared across channels, or channel-correlated noise with a random 3x3
    covariance (this last case requires three channels).

    Args:
        img: HxWxC float image in [0, 1]
        noise_level1: lower bound of the noise level (0-255 scale)
        noise_level2: upper bound of the noise level (0-255 scale)

    Returns:
        new noisy image; the input array is left untouched
    """
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    # Work on a private copy: the in-place additions below used to write
    # unclipped noise into the caller's array.
    img = img.copy()
    if rnum > 0.6:   # add color Gaussian noise
        img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
    elif rnum < 0.4: # add grayscale Gaussian noise
        img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:            # add noise
        L = noise_level2/255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3,3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
378
+
379
+
380
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Return ``img`` with multiplicative (speckle) noise, clipped to [0, 1].

    The image is clipped first, then ``img * noise`` is added, where the
    noise is per-channel, shared across channels, or channel-correlated
    (the last case requires three channels).
    """
    noise_level = random.randint(noise_level1, noise_level2)
    # np.clip yields a fresh array, so the in-place updates below do not
    # touch the caller's data.
    img = np.clip(img, 0.0, 1.0)
    pick = random.random()
    if pick > 0.6:
        noise = np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
    elif pick < 0.4:
        noise = np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2/255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3,3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
    img += img*noise
    return np.clip(img, 0.0, 1.0)
396
+
397
+
398
def add_Poisson_noise(img):
    """Return ``img`` with Poisson noise, clipped to [0, 1].

    The image is first quantised to 256 levels. With probability 0.5 the
    noise is drawn per channel; otherwise it is drawn on the luma
    (0.299, 0.587, 0.114 weights) and added to every channel.
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10**(2*random.random()+2.0)  # 10**[2, 4]
    per_channel = random.random() < 0.5
    if per_channel:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        luma = np.dot(img[...,:3], [0.299, 0.587, 0.114])
        luma = np.clip((luma * 255.0).round(), 0, 255) / 255.
        luma_noise = np.random.poisson(luma * vals).astype(np.float32) / vals - luma
        img += luma_noise[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
410
+
411
+
412
def add_JPEG_noise(img):
    """Round-trip ``img`` through in-memory JPEG compression.

    The quality factor is drawn uniformly from 30..95. The image is converted
    with ``util.single2uint`` and RGB->BGR before encoding and back afterwards.
    """
    quality_factor = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _ok, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
419
+
420
+
421
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut aligned random patches from a low-/high-quality image pair.

    Args:
        lq: HxWxC low-quality image
        hq: high-quality image, ``sf`` times larger than ``lq``
        sf: scale factor between the two images
        lq_patchsize: side of the square low-quality patch

    Returns:
        (lq_patch, hq_patch); the high-quality patch has side
        ``lq_patchsize * sf`` and starts at the scaled offset.
    """
    height, width = lq.shape[:2]
    top = random.randint(0, height-lq_patchsize)
    left = random.randint(0, width-lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    hq_top, hq_left = int(top * sf), int(left * sf)
    hq_side = lq_patchsize*sf
    hq_patch = hq[hq_top:hq_top + hq_side, hq_left:hq_left + hq_side, :]
    return lq_patch, hq_patch
430
+
431
+
432
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]

    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf.
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize*sf or w < lq_patchsize*sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:   # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
        else:
            img = util.imresize_np(img, 1/2, True)
        img = np.clip(img, 0.0, 1.0)
        # Half of the x4 reduction is done; the remaining steps work at x2.
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        # Step 3 reads the size recorded by step 2, so 2 must run before 3.
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            # Remember the size before downsample2; downsample3 targets
            # exactly 1/sf of it.
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1,2*sf)
                img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted/k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve is the same function as the deprecated
                # ndimage.filters.convolve.
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop (uses the original scale factor, not the reduced one)
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
520
+
521
+
522
+
523
+
524
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of shuffling all 13 degradation steps; otherwise
        only the two noise groups (steps 2-5 and 9-12) are shuffled internally
    use_sharp: sharpening the img
    lq_patchsize: side of the returned low-quality patch
    isp_model: camera ISP model, optional

    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize*sf or w < lq_patchsize*sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    n_steps = 13
    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(n_steps), n_steps)
    else:
        shuffle_order = list(range(n_steps))
        # local shuffle for noise, JPEG is always the last one
        for start, stop in ((2, 6), (9, 13)):
            shuffle_order[start:stop] = random.sample(shuffle_order[start:stop], stop - start)

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    # Steps 7-12 repeat steps 0-5; step 6 is the intermediate JPEG pass.
    for i in shuffle_order:
        if i in (0, 7):
            img = add_blur(img, sf=sf)
        elif i in (1, 8):
            img = add_resize(img, sf=sf)
        elif i in (2, 9):
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i in (3, 10):
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i in (4, 11):
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i in (5, 12):
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        else:
            print('check the shuffle!')

    # resize to desired size
    target_size = (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0]))
    img = cv2.resize(img, target_size, interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    return random_crop(img, hq, sf, lq_patchsize)
609
+
610
+
611
+
612
if __name__ == '__main__':
    # Demo: degrade the test image 20 times and save LQ/HQ side by side.
    sf = 4
    img = util.uint2single(util.imread_uint('utils/test.png', 3))

    for i in range(20):
        img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
        print(i)
        # Nearest-neighbour upscale of the LQ patch so both halves match in size.
        upscaled_size = (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0]))
        lq_nearest = cv2.resize(util.single2uint(img_lq), upscaled_size, interpolation=0)
        img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i)+'.png')
623
+
624
+ # for i in range(10):
625
+ # img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
626
+ # print(i)
627
+ # lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
628
+ # img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
629
+ # util.imsave(img_concat, str(i)+'.png')
630
+
631
+ # run utils/utils_blindsr.py
core/data/deg_kair_utils/utils_bnorm.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ """
6
+ # --------------------------------------------
7
+ # Batch Normalization
8
+ # --------------------------------------------
9
+
10
+ # Kai Zhang ([email protected])
11
+ # https://github.com/cszn
12
+ # 01/Jan/2019
13
+ # --------------------------------------------
14
+ """
15
+
16
+
17
+ # --------------------------------------------
18
+ # remove/delete specified layer
19
+ # --------------------------------------------
20
def deleteLayer(model, layer_type=nn.BatchNorm2d):
    ''' Kai Zhang, 11/Jan/2019.
    Recursively remove every child module that is an instance of
    `layer_type` from `model` (modified in place, returns None).
    '''
    for name, child in list(model.named_children()):
        if isinstance(child, layer_type):
            model._modules.pop(name)
        # descend regardless, so nested containers are cleaned as well
        deleteLayer(child, layer_type)
27
+
28
+
29
+ # --------------------------------------------
30
+ # merge bn, "conv+bn" --> "conv"
31
+ # --------------------------------------------
32
def merge_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    merge all 'Conv+BN' (or 'TConv+BN', 'Linear+BN') into 'Conv' (or 'TConv', 'Linear')
    based on https://github.com/pytorch/pytorch/pull/901

    A BatchNorm1d/2d child is folded into the child that immediately
    precedes it when that child is a Conv2d, ConvTranspose2d or Linear:
    the running statistics (and affine parameters, if any) are baked into
    the preceding layer's weight and bias, and the BN child is removed.
    `model` is modified in place; nothing is returned.
    '''
    prev_m = None
    for k, m in list(model.named_children()):
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)) and \
                isinstance(prev_m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):

            w = prev_m.weight.data

            if prev_m.bias is None:
                # size the bias from the BN layer: nn.Linear has no `out_channels`
                zeros = torch.zeros(m.num_features, dtype=w.dtype, device=w.device)
                prev_m.bias = nn.Parameter(zeros)
            b = prev_m.bias.data

            # Output channels sit on dim 1 for transposed convs, dim 0 otherwise.
            # Build a broadcast shape matching the weight rank (2-D for Linear,
            # 4-D for convolutions).
            out_dim = 1 if isinstance(prev_m, nn.ConvTranspose2d) else 0
            bshape = [1] * w.dim()
            bshape[out_dim] = -1

            invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
            w.mul_(invstd.view(bshape))
            b.add_(-m.running_mean).mul_(invstd)
            if m.affine:
                w.mul_(m.weight.data.view(bshape))
                b.mul_(m.weight.data).add_(m.bias.data)

            del model._modules[k]
        prev_m = m
        merge_bn(m)
64
+
65
+
66
+ # --------------------------------------------
67
+ # add bn, "conv" --> "conv+bn"
68
+ # --------------------------------------------
69
def add_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    Wrap every Conv2d / ConvTranspose2d / Linear child of `model` in
    nn.Sequential(layer, BatchNorm), recursively and in place.

    Convolutions get a BatchNorm2d sized by `out_channels`; Linear layers
    get a BatchNorm1d sized by `out_features` (nn.Linear has no
    `out_channels`, so the previous code raised AttributeError for it).
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
        elif isinstance(m, nn.Linear):
            b = nn.BatchNorm1d(m.out_features, momentum=0.1, affine=True)
        else:
            b = None
        if b is not None:
            b.weight.data.fill_(1)
            model._modules[k] = nn.Sequential(m, b)
        # recurse into the original child (containers hold further layers)
        add_bn(m)
79
+
80
+
81
+ # --------------------------------------------
82
+ # tidy model after removing bn
83
+ # --------------------------------------------
84
def tidy_sequential(model):
    ''' Kai Zhang, 11/Jan/2019.
    Replace each nn.Sequential child that holds exactly one module by that
    module, recursively and in place.
    '''
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Sequential) and len(child) == 1:
            model._modules[name] = child[0]
        # recursion uses the original child, as before
        tidy_sequential(child)
core/data/deg_kair_utils/utils_deblur.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import scipy
4
+ from scipy import fftpack
5
+ import torch
6
+
7
+ from math import cos, sin
8
+ from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
9
+ from numpy.random import randn, rand
10
+ from scipy.signal import convolve2d
11
+ import cv2
12
+ import random
13
+ # import utils_image as util
14
+
15
+ '''
16
+ modified by Kai Zhang (github: https://github.com/cszn)
17
+ 03/03/2019
18
+ '''
19
+
20
+
21
def get_uperleft_denominator(img, kernel):
    '''
    Frequency-domain terms built from the kernel's OTF.

    img: HxWxC
    kernel: hxw
    denominator: HxWx1   (|OTF|^2)
    upperleft: HxWxC     (conj(OTF) * FFT2(img))
    '''
    otf = psf2otf(kernel, img.shape[:2])
    img_spectrum = np.fft.fft2(img, axes=[0, 1])
    upperleft = np.expand_dims(np.conj(otf), axis=2) * img_spectrum
    denominator = np.expand_dims(np.abs(otf)**2, axis=2)
    return upperleft, denominator
32
+
33
+
34
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    Torch counterpart of get_uperleft_denominator; complex values are
    stored as a trailing axis of size 2 (real, imaginary).

    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW
    upperleft: NxCxHxWx2
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    otf_re, otf_im = otf[..., 0], otf[..., 1]
    denominator = otf_re**2+otf_im**2  # Nx1xHxW
    upperleft = cmul(cconj(otf), rfft(img))  # Nx1xHxWx2 * NxCxHxWx2
    return upperleft, denominator
45
+
46
+
47
def c2c(x):
    # complex numpy value -> float32 torch tensor with a trailing (real, imag) axis
    real_part = np.float32(x.real)
    imag_part = np.float32(x.imag)
    return torch.from_numpy(np.stack([real_part, imag_part], axis=-1))
49
+
50
+
51
def r2c(x):
    # real tensor -> (real, imag) pair tensor with a zero imaginary part
    imag = torch.zeros_like(x)
    return torch.stack((x, imag), dim=-1)
53
+
54
+
55
def cdiv(x, y):
    # complex division x / y; both carry (real, imag) on the last axis
    x_re, x_im = x[..., 0], x[..., 1]
    y_re, y_im = y[..., 0], y[..., 1]
    y_norm2 = y_re**2 + y_im**2
    real = (x_re*y_re+x_im*y_im)/y_norm2
    imag = (x_im*y_re-x_re*y_im)/y_norm2
    return torch.stack([real, imag], -1)
60
+
61
+
62
def cabs(x):
    # modulus of (real, imag) pairs stored on the last axis
    squared_norm = x[..., 0]**2+x[..., 1]**2
    return torch.pow(squared_norm, 0.5)
64
+
65
+
66
def cmul(t1, t2):
    '''
    complex multiplication of (real, imag) pair tensors
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    a_re, a_im = t1[..., 0], t1[..., 1]
    b_re, b_im = t2[..., 0], t2[..., 1]
    real = a_re * b_re - a_im * b_im
    imag = a_re * b_im + a_im * b_re
    return torch.stack([real, imag], dim=-1)
75
+
76
+
77
def cconj(t, inplace=False):
    '''
    # complex's conjugation
    t: NxCxHxWx2
    output: NxCxHxWx2
    With inplace=True the input tensor itself is negated and returned;
    otherwise a clone is modified.
    '''
    out = t if inplace else t.clone()
    out[..., 1] *= -1
    return out
86
+
87
+
88
def rfft(t):
    '''
    Full (two-sided) 2-D FFT of a real tensor over its last two dims.

    t: ...xHxW (real)
    output: ...xHxWx2 (real, imag on the last axis)

    torch.rfft was removed in PyTorch 1.8; fall back to torch.fft.fft2
    when the legacy function is not available.
    '''
    if hasattr(torch, 'rfft'):
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t))
90
+
91
+
92
def irfft(t):
    '''
    Inverse of the full (two-sided) 2-D real FFT.

    t: ...xHxWx2 (real, imag on the last axis)
    output: ...xHxW (real)

    torch.irfft was removed in PyTorch 1.8; fall back to torch.fft.ifft2
    and keep the real part when the legacy function is not available.
    '''
    if hasattr(torch, 'irfft'):
        return torch.irfft(t, 2, onesided=False)
    # view_as_complex needs a contiguous (..., 2) layout
    return torch.fft.ifft2(torch.view_as_complex(t.contiguous())).real
94
+
95
+
96
def fft(t):
    '''
    2-D complex-to-complex FFT over the two dims preceding the last axis.

    t: ...xHxWx2 (real, imag on the last axis)
    output: ...xHxWx2

    Since PyTorch 1.8 `torch.fft` is a module rather than a function, so
    the legacy call is used only while `torch.fft` is still callable.
    '''
    if callable(torch.fft):
        return torch.fft(t, 2)
    spectrum = torch.fft.fft2(torch.view_as_complex(t.contiguous()))
    return torch.view_as_real(spectrum)
98
+
99
+
100
def ifft(t):
    '''
    2-D complex-to-complex inverse FFT over the two dims preceding the last axis.

    t: ...xHxWx2 (real, imag on the last axis)
    output: ...xHxWx2

    torch.ifft was removed in PyTorch 1.8; fall back to torch.fft.ifft2
    when the legacy function is not available.
    '''
    if hasattr(torch, 'ifft'):
        return torch.ifft(t, 2)
    signal = torch.fft.ifft2(torch.view_as_complex(t.contiguous()))
    return torch.view_as_real(signal)
102
+
103
+
104
def p2o(psf, shape):
    '''
    Convert point-spread functions to optical transfer functions.

    # psf: NxCxhxw
    # shape: [H,W] (tuple, list or torch.Size)
    # otf: NxCxHxWx2 (real, imag on the last axis)
    '''
    # tuple() on both sides so a list `shape` can be concatenated too
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # circularly shift so the kernel centre moves to element (0, 0)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    # torch.rfft was removed in PyTorch 1.8; use torch.fft.fft2 when it is gone
    if hasattr(torch, 'rfft'):
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # zero imaginary parts that are within the estimated FFT round-off error
    psf_dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(psf_dims * torch.log2(psf_dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
118
+
119
+
120
+
121
+ # otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
122
def otf2psf(otf, outsize=None):
    """
    Convert an optical transfer function back to a point-spread function.

    The inverse FFT is taken over axes (0, 1), every axis is circularly
    shifted by half its length, and, when `outsize` is given, the result is
    cropped to that size. The imaginary part is dropped when it is within
    the estimated round-off error.
    """
    in_shape = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    for axis, length in enumerate(in_shape):
        psf = np.roll(psf, np.floor(length / 2).astype(int), axis=axis)

    if outsize is not None:
        out_shape = np.array(outsize)
        n = max(np.size(out_shape), np.size(in_shape))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        col_out = out_shape.flatten().reshape((np.size(out_shape), 1))
        col_in = in_shape.flatten().reshape((np.size(in_shape), 1))
        col_out = np.pad(col_out, ((0, max(0, n - np.size(col_out))), (0, 0)), mode="constant")
        col_in = np.pad(col_in, ((0, max(0, n - np.size(col_in))), (0, 0)), mode="constant")

        margin = (col_in - col_out) / 2
        if np.any(margin < 0):
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        first = np.floor(margin).astype(int)
        last = (col_in - np.ceil(margin)).astype(int)
        # `first` is a column vector, so this iterates over axes 0 and 1
        for axis in range(len(first.shape)):
            psf = np.take(psf, range(first[axis][0], last[axis][0]), axis=axis)

    n_ops = np.sum(otf.size * np.log2(otf.shape))
    return np.real_if_close(psf, tol=n_ops)
150
+
151
+
152
+ # psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
153
def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.

    The PSF is zero-padded (bottom/right) to `shape`, circularly shifted
    so that its central pixel lands on element (0, 0), and transformed
    with a 2-D FFT, so the OTF is not influenced by PSF off-centering.

    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : sequence of int, optional
        Output shape of the OTF array; defaults to the PSF shape

    Returns
    -------
    otf : `numpy.ndarray`
        OTF array (an all-zero array of `shape` for an all-zero PSF)

    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    target = np.array(psf.shape if shape is None else shape)
    if np.all(psf == 0):
        # return np.zeros_like(psf)
        return np.zeros(target)
    if psf.ndim == 1:
        psf = psf.reshape((1, psf.shape[0]))
    kernel_shape = psf.shape
    padded = zero_pad(psf, target, position='corner')
    for axis, axis_size in enumerate(kernel_shape):
        padded = np.roll(padded, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(padded, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the imaginary part if within roundoff error
    n_ops = np.sum(padded.size * np.log2(padded.shape))
    return np.real_if_close(otf, tol=n_ops)
200
+
201
+
202
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros

    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered

    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image (the input itself when the shapes already match)

    Raises
    ------
    ValueError
        If `shape` has a non-positive entry, is smaller than the image, or
        (for 'center') differs from the image shape by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    # np.alltrue was removed in NumPy 2.0; array_equal is the supported spelling
    if np.array_equal(imshape, shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(tuple(shape), dtype=image.dtype)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[offx:offx + imshape[0], offy:offy + imshape[1]] = image
    return pad_img
242
+
243
+
244
+ '''
245
+ Reducing boundary artifacts
246
+ '''
247
+
248
+
249
def opt_fft_size(n):
    '''
    Kai Zhang (github: https://github.com/cszn)
    03/03/2019
    # opt_fft_size.m
    # compute an optimal data length for Fourier transforms
    # written by Sunghyun Cho ([email protected])

    For every entry of `n` return the smallest length >= that entry of the
    form 2^a * 3^b * 5^c * 7^d, optionally times 11 or 13, up to 2048.
    Entries above 2048 map to -1. The result is a float array.
    '''
    LUT_size = 2048
    lut = np.zeros(LUT_size)

    # mark every admissible length with its own value
    p2 = 1
    while p2 <= LUT_size:
        p3 = p2
        while p3 <= LUT_size:
            p5 = p3
            while p5 <= LUT_size:
                p7 = p5
                while p7 <= LUT_size:
                    for candidate in (p7, p7*11, p7*13):
                        if candidate <= LUT_size:
                            lut[candidate-1] = candidate
                    p7 *= 7
                p5 *= 5
            p3 *= 3
        p2 *= 2

    # sweep from the top so each gap takes the next admissible length above it
    next_good = 1
    for length in range(LUT_size, 0, -1):
        if lut[length-1] != 0:
            next_good = length
        else:
            lut[length-1] = next_good

    m = np.zeros(len(n))
    for pos, value in enumerate(n):
        m[pos] = lut[value-1] if value <= LUT_size else -1
    return m
297
+
298
+
299
def wrap_boundary_liu(img, img_size):
    """
    Reducing boundary artifacts in image deconvolution
    Renting Liu, Jiaya Jia
    ICIP 2008

    Applies wrap_boundary to a 2-D image, or to every channel of an
    HxWxC image (channels are processed independently and re-stacked on
    axis 2).

    Raises
    ------
    ValueError
        If `img` is neither 2-D nor 3-D.
    """
    if img.ndim == 2:
        return wrap_boundary(img, img_size)
    if img.ndim == 3:
        # iterate over the actual channel count rather than assuming 3
        channels = [wrap_boundary(img[:, :, i], img_size) for i in range(img.shape[2])]
        return np.stack(channels, 2)
    raise ValueError("wrap_boundary_liu: expected a 2-D or 3-D image, got ndim=%d" % img.ndim)
312
+
313
+
314
def wrap_boundary(img, img_size):

    """
    python code from:
    https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
    Reducing boundary artifacts in image deconvolution
    Renting Liu, Jiaya Jia
    ICIP 2008

    Extend a 2-D image to `img_size` by appending a strip B to the right,
    a strip A below and a corner block C. Each piece is filled by
    solve_min_laplacian from boundary values taken from the image (and,
    for C, from A and B). The result has shape
    (H + H_w, W + W_w) == (img_size[0], img_size[1]).
    """
    (H, W) = np.shape(img)
    # number of rows / columns to add
    H_w = int(img_size[0]) - H
    W_w = int(img_size[1]) - W

    # ret = np.zeros((img_size[0], img_size[1]));
    alpha = 1
    HG = img[:, :]

    # r_A: bottom strip; first row copies the image's last row, last row its first row
    r_A = np.zeros((alpha*2+H_w, W))
    r_A[:alpha, :] = HG[-alpha:, :]
    r_A[-alpha:, :] = HG[:alpha, :]
    # NOTE(review): H_w == 1 divides by zero here - confirm callers never pad by one row
    a = np.arange(H_w)/(H_w-1)
    # linear blend down the left and right edges of the strip
    # r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
    r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
    # r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
    r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]

    # r_B: right strip, built the same way along the columns
    r_B = np.zeros((H, alpha*2+W_w))
    r_B[:, :alpha] = HG[:, -alpha:]
    r_B[:, -alpha:] = HG[:, :alpha]
    a = np.arange(W_w)/(W_w-1)
    r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
    r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]

    # fill the strip interiors from their boundaries
    if alpha == 1:
        A2 = solve_min_laplacian(r_A[alpha-1:, :])
        B2 = solve_min_laplacian(r_B[:, alpha-1:])
        r_A[alpha-1:, :] = A2
        r_B[:, alpha-1:] = B2
    else:
        A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
        r_A[alpha-1:-alpha+1, :] = A2
        B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
        r_B[:, alpha-1:-alpha+1] = B2
    A = r_A
    B = r_B

    # r_C: corner block whose boundary comes from the edges of A and B
    r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
    r_C[:alpha, :] = B[-alpha:, :]
    r_C[-alpha:, :] = B[:alpha, :]
    r_C[:, :alpha] = A[:, -alpha:]
    r_C[:, -alpha:] = A[:, :alpha]

    if alpha == 1:
        C2 = C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
        r_C[alpha-1:, alpha-1:] = C2
    else:
        C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
        r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
    C = r_C
    # return C
    # trim the pieces to H_w rows / W_w columns before assembling
    # NOTE(review): A keeps its first row while B and C drop theirs - matches the referenced source, confirm intent
    A = A[alpha-1:-alpha-1, :]
    B = B[:, alpha:-alpha]
    C = C[alpha:-alpha, alpha:-alpha]
    # layout: [[img, B], [A, C]]
    ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
    return ret
379
+
380
+
381
def solve_min_laplacian(boundary_image):
    """
    Fill the interior of `boundary_image` from its border values using
    type-1 discrete sine transforms.

    The input array is modified in place: its interior is zeroed, then
    overwritten with the solution, and the same array is returned.
    """
    (H, W) = np.shape(boundary_image)

    # Laplacian
    f = np.zeros((H, W))
    # boundary image contains image intensities at boundaries
    boundary_image[1:-1, 1:-1] = 0
    # interior row / column indices (1 .. H-2 and 1 .. W-2)
    j = np.arange(2, H)-1
    k = np.arange(2, W)-1
    f_bp = np.zeros((H, W))
    # 5-point Laplacian stencil of the boundary-only image at interior points
    f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]

    del(j, k)
    f1 = f - f_bp  # subtract boundary points contribution
    del(f_bp, f)

    # DST Sine Transform algo starts here
    f2 = f1[1:-1,1:-1]
    del(f1)

    # compute sine tranform
    # (single-column / single-row interiors pass axis=0 explicitly)
    if f2.shape[1] == 1:
        tt = fftpack.dst(f2, type=1, axis=0)/2
    else:
        tt = fftpack.dst(f2, type=1)/2

    if tt.shape[0] == 1:
        f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
    else:
        f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2)
    del(f2)

    # compute Eigen Values
    [x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
    denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)

    # divide
    f3 = f2sin/denom
    del(f2sin, x, y)

    # compute Inverse Sine Transform
    if f3.shape[0] == 1:
        tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
    else:
        tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
    del(f3)
    if tt.shape[1] == 1:
        img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
    else:
        img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
    del(tt)

    # put solution in inner points; outer points obtained from boundary image
    # (img_direct aliases the input array, so the caller's array is updated too)
    img_direct = boundary_image
    img_direct[1:-1, 1:-1] = 0
    img_direct[1:-1, 1:-1] = img_tt
    return img_direct
438
+
439
+
440
+ """
441
+ Created on Thu Jan 18 15:36:32 2018
442
+ @author: italo
443
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
444
+ """
445
+
446
+ """
447
+ Syntax
448
+ h = fspecial(type)
449
+ h = fspecial('average',hsize)
450
+ h = fspecial('disk',radius)
451
+ h = fspecial('gaussian',hsize,sigma)
452
+ h = fspecial('laplacian',alpha)
453
+ h = fspecial('log',hsize,sigma)
454
+ h = fspecial('motion',len,theta)
455
+ h = fspecial('prewitt')
456
+ h = fspecial('sobel')
457
+ """
458
+
459
+
460
def fspecial_average(hsize=3):
    """Smoothing filter: an hsize x hsize kernel with equal weights summing to 1."""
    weight_sum = hsize**2
    return np.ones((hsize, hsize))/weight_sum
463
+
464
+
465
def fspecial_disk(radius):
    """Disk filter (not implemented).

    Raises
    ------
    NotImplementedError
        Always. The previous `raise(NotImplemented)` raised a TypeError,
        because NotImplemented is not an exception class, and the
        unfinished draft after it was unreachable.
    """
    raise NotImplementedError("fspecial_disk is not implemented")
485
+
486
+
487
def fspecial_gaussian(hsize, sigma):
    """
    Rotationally symmetric Gaussian low-pass kernel (MATLAB fspecial('gaussian')).

    hsize: side length of the square kernel
    sigma: standard deviation
    returns: hsize x hsize array normalised to sum 1; values below
             eps * max are zeroed before normalisation
    """
    half = (hsize-1.0)/2.0
    coords = np.arange(-half, half+1)
    [x, y] = np.meshgrid(coords, coords)
    h = np.exp(-(x*x + y*y)/(2*sigma*sigma))
    # scipy.finfo was a NumPy alias that current SciPy no longer provides
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h/sumh
    return h
499
+
500
+
501
def fspecial_laplacian(alpha):
    """3x3 Laplacian kernel; `alpha` is clamped to [0, 1]."""
    # list form kept on purpose: `min` is numpy's min in this module
    alpha = max([0, min([alpha,1])])
    scale = alpha+1
    corner = alpha/scale
    edge = (1-alpha)/scale
    center = -4/scale
    return np.array([[corner, edge, corner],
                     [edge, center, edge],
                     [corner, edge, corner]])
508
+
509
+
510
def fspecial_log(hsize, sigma):
    """Laplacian-of-Gaussian filter (not implemented).

    Raises NotImplementedError; `raise(NotImplemented)` raised a TypeError
    because NotImplemented is not an exception class.
    """
    raise NotImplementedError("fspecial_log is not implemented")
512
+
513
+
514
def fspecial_motion(motion_len, theta):
    """Motion-blur filter (not implemented).

    Raises NotImplementedError; `raise(NotImplemented)` raised a TypeError
    because NotImplemented is not an exception class.
    """
    raise NotImplementedError("fspecial_motion is not implemented")
516
+
517
+
518
def fspecial_prewitt():
    """3x3 Prewitt kernel."""
    rows = [
        [1, 1, 1],
        [0, 0, 0],
        [-1, -1, -1],
    ]
    return np.array(rows)
520
+
521
+
522
def fspecial_sobel():
    """3x3 Sobel kernel."""
    rows = [
        [1, 2, 1],
        [0, 0, 0],
        [-1, -2, -1],
    ]
    return np.array(rows)
524
+
525
+
526
def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Dispatch to the fspecial_<filter_type> builder, forwarding the remaining
    arguments. Returns None for an unrecognised filter type.
    '''
    builders = (
        ('average', fspecial_average),
        ('disk', fspecial_disk),
        ('gaussian', fspecial_gaussian),
        ('laplacian', fspecial_laplacian),
        ('log', fspecial_log),
        ('motion', fspecial_motion),
        ('prewitt', fspecial_prewitt),
        ('sobel', fspecial_sobel),
    )
    for name, builder in builders:
        if filter_type == name:
            return builder(*args, **kwargs)
    return None
547
+
548
+
549
def fspecial_gauss(size, sigma):
    # normalised 2-D Gaussian sampled on an integer grid centred at 0
    lo = -size // 2 + 1
    hi = size // 2 + 1
    x, y = mgrid[lo:hi, lo:hi]
    squared_radius = x ** 2 + y ** 2
    g = exp(-(squared_radius / (2.0 * sigma ** 2)))
    return g / g.sum()
553
+
554
+
555
def blurkernel_synthesis(h=37, w=None):
    '''
    Synthesise a random motion-blur kernel of about h x w (w defaults to h).
    https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py

    A kernel is rasterised from a random trajectory, centre-padded (or
    cropped) to the requested size, stretched and centre-cropped with
    probability 1/4, and normalised to sum 1.
    '''
    w = h if w is None else w
    trajectory = randomTrajectory(250)
    k = None
    # kernelFromTrajectory returns None for an empty kernel; retry until it succeeds
    while k is None:
        k = kernelFromTrajectory(trajectory)

    # center pad to (h, w)
    pad_h = (h - k.shape[0]) // 2
    pad_w = (w - k.shape[1]) // 2
    if pad_h < 0 or pad_w < 0:
        # crop the width with w (was 0:h, wrong for non-square kernels)
        k = k[0:h, 0:w]
    else:
        k = np.pad(k, [(pad_h,), (pad_w,)], "constant")

    rows, cols = k.shape
    if np.random.randint(0, 4) == 1:
        # cv2.resize takes dsize as (width, height); width is drawn first so the
        # random sequence for square kernels is unchanged
        new_w = random.randint(cols, 5*cols)
        new_h = random.randint(rows, 5*rows)
        k = cv2.resize(k, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        top = (new_h - rows) // 2
        left = (new_w - cols) // 2
        k = k[top: top + rows, left: left + cols]

    if np.sum(k) < 0.1:
        # NOTE(review): this fallback is h x h even when w != h - confirm intent
        k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
    k = k / np.sum(k)
    # import matplotlib.pyplot as plt
    # plt.imshow(k, interpolation="nearest", cmap="gray")
    # plt.show()
    return k
585
+
586
+
587
def kernelFromTrajectory(x):
    # Rasterise a trajectory (row 0 and row 1 of x are used) into a square,
    # odd-sized kernel, smooth it with a 3x3 Gaussian and normalise to sum 1.
    # Returns None when no sample falls into any bin.
    # NOTE: log/rand/round/min/arange/zeros/sum are the numpy names imported
    # at module level; max is the Python builtin.
    side = 5 - log(rand()) / 0.15
    side = round(min([side, 27])).astype(int)
    side = side + 1 - side % 2  # force an odd size
    k = zeros((side, side))

    x_lo, x_hi = min(x[0]), max(x[0])
    y_lo, y_hi = min(x[1]), max(x[1])
    xthr = arange(x_lo, x_hi, (x_hi - x_lo) / side)
    ythr = arange(y_lo, y_hi, (y_hi - y_lo) / side)

    xs, ys = x[0, :], x[1, :]
    for i in range(1, xthr.size):
        in_x_bin = (xs >= xthr[i - 1]) & (xs < xthr[i])
        for j in range(1, ythr.size):
            in_y_bin = (ys >= ythr[j - 1]) & (ys < ythr[j])
            k[i - 1, j - 1] = sum(in_x_bin & in_y_bin)

    if sum(k) == 0:
        return
    k = k / sum(k)
    k = convolve2d(k, fspecial_gauss(3, 1), "same")
    return k / sum(k)
616
+
617
+
618
def randomTrajectory(T):
    # Random 3-D walk with T samples: velocities and rotation angles drift
    # randomly and each step is the rotated velocity. Returns a 3xT array of
    # positions (column 0 stays at the origin).
    pos = zeros((3, T))
    vel = randn(3, T)
    ang = zeros((3, T))
    trans_gain = 1 / 1
    rot_gain = 2 * pi / T
    for t in range(1, T):
        prev = t - 1
        F_rot = randn(3) / (t + 1) + ang[:, prev]
        F_trans = randn(3) / (t + 1)
        ang[:, t] = ang[:, prev] + rot_gain * F_rot
        vel[:, t] = vel[:, prev] + trans_gain * F_trans
        step = rot3D(vel[:, t], ang[:, t])
        pos[:, t] = pos[:, prev] + step
    return pos
633
+
634
+
635
def rot3D(x, r):
    # Rotate x by angles r[0], r[1], r[2] about the x, y and z axes
    # (applied in that order): result = Rz @ Ry @ Rx @ x.
    cx, sx = cos(r[0]), sin(r[0])
    cy, sy = cos(r[1]), sin(r[1])
    cz, sz = cos(r[2]), sin(r[2])
    Rx = array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    rotation = Rz @ Ry @ Rx
    return rotation @ x
642
+
643
+
644
if __name__ == '__main__':
    # Quick manual checks of a few helpers, then display a synthesised kernel.
    print(opt_fft_size([111]))

    print(fspecial('gaussian', 5, 1))

    otf = p2o(torch.zeros(1,1,4,4).float(),(14,14))
    print(otf.shape)

    kernel = blurkernel_synthesis(11)
    import matplotlib.pyplot as plt
    plt.imshow(kernel, interpolation="nearest", cmap="gray")
    plt.show()
core/data/deg_kair_utils/utils_dist.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ # ----------------------------------
11
+ # init
12
+ # ----------------------------------
13
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialise torch.distributed for the given launcher.

    Sets the multiprocessing start method to 'spawn' when none is set yet,
    then dispatches to the 'pytorch' or 'slurm' initialiser.

    Raises:
        ValueError: if `launcher` is neither 'pytorch' nor 'slurm'.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher not in ('pytorch', 'slurm'):
        raise ValueError(f'Invalid launcher type: {launcher}')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    else:
        _init_dist_slurm(backend, **kwargs)
22
+
23
+
24
def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to GPU (RANK % visible GPU count) and join the process group."""
    global_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    device_index = global_rank % gpu_count
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
29
+
30
+
31
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    The master port is `port` when given, otherwise the existing
    ``MASTER_PORT`` environment variable, otherwise ``29500``.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    env = os.environ
    proc_id = int(env['SLURM_PROCID'])
    ntasks = int(env['SLURM_NTASKS'])
    node_list = env['SLURM_NODELIST']

    num_gpus = torch.cuda.device_count()
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)

    # the first hostname of the allocation acts as master address
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')

    if port is not None:
        env['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in env:
        # 29500 is torch.distributed default port
        env['MASTER_PORT'] = '29500'

    env['MASTER_ADDR'] = addr
    env['WORLD_SIZE'] = str(ntasks)
    env['LOCAL_RANK'] = str(local_rank)
    env['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
60
+
61
+
62
+
63
+ # ----------------------------------
64
+ # get rank and world_size
65
+ # ----------------------------------
66
def get_dist_info():
    """Return (rank, world_size); (0, 1) when torch.distributed is unavailable or uninitialised."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
78
+
79
+
80
def get_rank():
    """Return this process's rank, or 0 when torch.distributed is unavailable or uninitialised."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if ready else 0
88
+
89
+
90
def get_world_size():
    """Return the number of processes, or 1 when torch.distributed is unavailable or uninitialised."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
98
+
99
+
100
def master_only(func):
    """Decorator: run `func` only on rank 0; on other ranks the call returns None."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank = get_dist_info()[0]
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
109
+
110
+
111
+
112
+
113
+
114
+
115
+ # ----------------------------------
116
+ # operation across ranks
117
+ # ----------------------------------
118
def reduce_sum(tensor):
    """Return the element-wise sum of `tensor` over all ranks.

    The input is returned unchanged when torch.distributed is unavailable
    or uninitialised; otherwise a clone is all-reduced and returned.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed
129
+
130
+
131
def gather_grad(params):
    """Average the gradients of `params` across all ranks, in place (no-op for a single process)."""
    world_size = get_world_size()
    if world_size == 1:
        return

    for param in params:
        if param.grad is None:
            continue
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data.div_(world_size)
141
+
142
+
143
def all_gather(data):
    """Gather an arbitrary picklable object from every rank.

    Args:
        data: any picklable object.

    Returns:
        list: one entry per rank ([data] when running as a single process).
    """
    # pickle is not imported at module level in this file
    import pickle

    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # exchange payload sizes first; keep the local size as a plain int so it
    # can be compared and used as a tensor dimension below
    local_size = tensor.numel()
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, torch.IntTensor([local_size]).to('cuda'))
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = [torch.ByteTensor(size=(max_size,)).to('cuda') for _ in size_list]

    # all_gather needs equally sized tensors, so pad up to the largest payload
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    # NOTE: pickle.loads executes arbitrary code on malicious input; this is
    # only acceptable because the payloads come from peer ranks of the same job.
    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        payload = gathered.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(payload))

    return data_list
176
+
177
+
178
def reduce_loss_dict(loss_dict):
    """Reduce a dict of loss tensors onto rank 0.

    The dict is returned unchanged for fewer than two processes. Otherwise
    the losses are stacked in sorted-key order and reduced to rank 0, which
    divides them by the world size; every rank returns a new dict built
    from its stacked tensor.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[key] for key in keys], 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = dict(zip(keys, losses))

    return reduced_losses
201
+
core/data/deg_kair_utils/utils_googledownload.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+
6
+ '''
7
+ borrowed from
8
+ https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
9
+ '''
10
+
11
+
12
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
    idx = 0
    while idx < len(prefixes):
        if abs(size) < 1024.0:
            return f'{size:3.1f} {prefixes[idx]}{suffix}'
        size /= 1024.0
        idx += 1
    # Anything still >= 1024 after 'Z' is reported with the 'Y' prefix.
    return f'{size:3.1f} Y{suffix}'
25
+
26
+
27
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    query = {'id': file_id}
    session = requests.Session()

    response = session.get(URL, params=query, stream=True)
    token = get_confirm_token(response)
    if token:
        # Repeat the request with the confirmation token attached.
        query['confirm'] = token
        response = session.get(URL, params=query, stream=True)

    # Probe with a tiny range request; the total size is the part after '/'
    # in the Content-Range header, when the server sends one.
    probe = session.get(
        URL, params=query, stream=True, headers={'Range': 'bytes=0-2'})
    file_size = None
    if 'Content-Range' in probe.headers:
        file_size = int(probe.headers['Content-Range'].split('/')[1])

    save_response_content(response, save_path, file_size)
56
+
57
+
58
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    'download_warning', or None when the response has no such cookie."""
    matches = (
        value
        for name, value in response.cookies.items()
        if name.startswith('download_warning')
    )
    return next(matches, None)
63
+
64
+
65
def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: object exposing ``iter_content(chunk_size)``.
        destination (str): Path of the file to write (opened in 'wb' mode).
        file_size (int | None): Total size in bytes. When given, a tqdm
            progress bar is shown; when None, no progress is reported.
        chunk_size (int): Number of bytes requested per chunk.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                # Count the bytes actually received: the final chunk (and any
                # keep-alive chunk) is shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                         f'/ {readable_file_size}')
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    finally:
        # Close the bar even when the download or the write fails.
        if pbar is not None:
            pbar.close()
88
+
89
+
90
if __name__ == "__main__":
    # Example usage when run as a script: download one Google Drive file
    # (by its id) and store it as 'BSRGAN.pth' in the working directory.
    file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
    save_path = 'BSRGAN.pth'
    download_file_from_google_drive(file_id, save_path)
core/data/deg_kair_utils/utils_image.py ADDED
@@ -0,0 +1,1016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ from torchvision.utils import make_grid
8
+ from datetime import datetime
9
+ # import torchvision.transforms as transforms
10
+ import matplotlib.pyplot as plt
11
+ from mpl_toolkits.mplot3d import Axes3D
12
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
13
+
14
+
15
+ '''
16
+ # --------------------------------------------
17
+ # Kai Zhang (github: https://github.com/cszn)
18
+ # 03/Mar/2019
19
+ # --------------------------------------------
20
+ # https://github.com/twhui/SRGAN-pyTorch
21
+ # https://github.com/xinntao/BasicSR
22
+ # --------------------------------------------
23
+ '''
24
+
25
+
26
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-sensitive: only the spellings listed in
    ``IMG_EXTENSIONS`` are accepted.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
31
+
32
+
33
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
35
+
36
+
37
def imshow(x, title=None, cbar=False, figsize=None):
    """Display ``x`` in a new matplotlib figure with a gray colormap.

    Args:
        x: array-like image; singleton dimensions are squeezed away.
        title: optional figure title (skipped when falsy).
        cbar: add a colorbar when truthy.
        figsize: forwarded to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    picture = np.squeeze(x)
    plt.imshow(picture, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
45
+
46
+
47
def surf(Z, cmap='rainbow', figsize=None):
    """Show a 2-D array as a 3-D surface plot.

    Args:
        Z: 2-D array of heights; ``Z[i, j]`` is drawn at x=j, y=i.
        cmap: matplotlib colormap name.
        figsize: forwarded to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    rows, cols = Z.shape[:2]
    # np.meshgrid (default 'xy' indexing) returns grids of shape
    # (len(y), len(x)), so x must span the columns and y the rows for X and Y
    # to match Z. The previous code swapped them, which only worked for
    # square inputs; square inputs are plotted exactly as before.
    X, Y = np.meshgrid(np.arange(0, cols, 1), np.arange(0, rows, 1))
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    plt.show()
58
+
59
+
60
+ '''
61
+ # --------------------------------------------
62
+ # get image paths
63
+ # --------------------------------------------
64
+ '''
65
+
66
+
67
def get_image_paths(dataroot):
    """Collect sorted image paths under one directory or a list of them.

    Args:
        dataroot: a directory path (str) or a list of directory paths.

    Returns:
        list of paths; for a list input the per-directory sorted lists are
        concatenated in input order. ``None`` for any other input type.
    """
    if isinstance(dataroot, str):
        return sorted(_get_paths_from_images(dataroot))
    if isinstance(dataroot, list):
        collected = []
        for root in dataroot:
            collected += sorted(_get_paths_from_images(root))
        return collected
    return None
76
+
77
+
78
def _get_paths_from_images(path):
    """Recursively list image files (per ``is_image_file``) below ``path``.

    Asserts that ``path`` is a directory and that at least one image file
    was found. Directories and file names are visited in sorted order.
    """
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = [
        os.path.join(dirpath, fname)
        for dirpath, _, fnames in sorted(os.walk(path))
        for fname in sorted(fnames)
        if is_image_file(fname)
    ]
    assert images, '{:s} has no valid image file'.format(path)
    return images
88
+
89
+
90
+ '''
91
+ # --------------------------------------------
92
+ # split large images into small images
93
+ # --------------------------------------------
94
+ '''
95
+
96
+
97
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Args:
        img: numpy image; the first two axes are tiled (2-D or 3-D input).
        p_size: side length of each patch.
        p_overlap: overlap between neighbouring patches.
        p_max: the image is split only when both of its first two
            dimensions exceed this value.

    Returns:
        list of patches; ``[img]`` when the image is not split.
    """
    # NOTE: despite the names, `w` is the first axis and `h` the second.
    w, h = img.shape[:2]
    if not (w > p_max and h > p_max):
        return [img]

    stride = p_size - p_overlap
    # The `np.int` alias was removed in NumPy 1.24; the builtin int is the
    # same dtype.
    w1 = list(np.arange(0, w - p_size, stride, dtype=int))
    h1 = list(np.arange(0, h - p_size, stride, dtype=int))
    # Always add a final patch flush with the far border.
    w1.append(w - p_size)
    h1.append(h - p_size)

    # Slicing only the first two axes keeps any trailing channel axis intact.
    return [img[i:i + p_size, j:j + p_size] for i in w1 for j in h1]
114
+
115
+
116
def imssave(imgs, img_path):
    """Save a list of images as numbered PNG files next to ``img_path``.

    imgs: list, N images of size WxHxC
    Image i is written as '<stem>_<i:04d>.png' in the directory of
    ``img_path``. The first three channels of 3-D images are reversed
    before handing them to ``cv2.imwrite``.
    """
    out_dir = os.path.dirname(img_path)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    for idx, patch in enumerate(imgs):
        if patch.ndim == 3:
            patch = patch[:, :, [2, 1, 0]]
        target = os.path.join(out_dir, stem + '_{:04d}'.format(idx) + '.png')
        cv2.imwrite(target, patch)
126
+
127
+
128
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
    """
    Split the large images from original_dataroot into small overlapping
    images of size (p_size)x(p_size) and save them into taget_dataroot;
    only images larger than (p_max)x(p_max) are split.

    Args:
        original_dataroot: directory (or list of directories) to read from
        taget_dataroot: directory the patches are written to
        n_channels: forwarded to imread_uint
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    for src_path in get_image_paths(original_dataroot):
        image = imread_uint(src_path, n_channels=n_channels)
        tiles = patches_from_image(image, p_size, p_overlap, p_max)
        dst_path = os.path.join(taget_dataroot, os.path.basename(src_path))
        imssave(tiles, dst_path)
149
+
150
+ '''
151
+ # --------------------------------------------
152
+ # makedir
153
+ # --------------------------------------------
154
+ '''
155
+
156
+
157
def mkdir(path):
    """Create directory ``path`` (with parents) unless something already
    exists at that location."""
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the existence check and this call.
        os.makedirs(path, exist_ok=True)
160
+
161
+
162
def mkdirs(paths):
    """Create one directory (str input) or every directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
168
+
169
+
170
def mkdir_and_rename(path):
    """Create a fresh directory at ``path``.

    If ``path`` already exists it is first renamed to
    '<path>_archived_<timestamp>' and a message is printed.
    """
    already_there = os.path.exists(path)
    if already_there:
        new_name = ''.join((path, '_archived_', get_timestamp()))
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
176
+
177
+
178
+ '''
179
+ # --------------------------------------------
180
+ # read image from path
181
+ # opencv is fast, but read BGR numpy image
182
+ # --------------------------------------------
183
+ '''
184
+
185
+
186
+ # --------------------------------------------
187
+ # get uint8 image of size HxWxn_channels (RGB)
188
+ # --------------------------------------------
189
def imread_uint(path, n_channels=3):
    """Read an image with OpenCV.

    Args:
        path: image file path.
        n_channels: 1 for a grayscale HxWx1 result, 3 for an HxWx3 RGB
            result (a grayscale file is replicated to three channels).

    Returns:
        numpy array as decoded by ``cv2.imread``, HxWx1 or HxWx3.

    Raises:
        ValueError: ``n_channels`` is neither 1 nor 3.
        IOError: ``cv2.imread`` could not read ``path`` (it returned None).
    """
    # input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {}'.format(n_channels))

    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
    else:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G

    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError('Failed to read image: {}'.format(path))

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
202
+
203
+
204
+ # --------------------------------------------
205
+ # matlab's imwrite
206
+ # --------------------------------------------
207
def imsave(img, img_path):
    """Squeeze ``img`` and write it to ``img_path`` with ``cv2.imwrite``.

    For a 3-D result the first three channels are reversed first.
    """
    out = np.squeeze(img)
    if out.ndim == 3:
        out = out[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, out)
212
+
213
def imwrite(img, img_path):
    """Squeeze ``img`` and write it to ``img_path`` with ``cv2.imwrite``.

    For a 3-D result the first three channels are reversed first.
    """
    out = np.squeeze(img)
    if out.ndim == 3:
        out = out[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, out)
218
+
219
+
220
+
221
+ # --------------------------------------------
222
+ # get single image of size HxWxn_channels (BGR)
223
+ # --------------------------------------------
224
def read_img(path):
    """Read an image with cv2 and scale it by 1/255.

    Returns:
        float32 numpy array, HWC, channel order as delivered by OpenCV;
        2-D images get a trailing channel axis and any channels beyond the
        third are dropped.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    scaled = raw.astype(np.float32) / 255.
    if scaled.ndim == 2:
        scaled = np.expand_dims(scaled, axis=2)
    # some images have 4 channels
    if scaled.shape[2] > 3:
        scaled = scaled[:, :, :3]
    return scaled
235
+
236
+
237
+ '''
238
+ # --------------------------------------------
239
+ # image format conversion
240
+ # --------------------------------------------
241
+ # numpy(single) <---> numpy(uint)
242
+ # numpy(single) <---> tensor
243
+ # numpy(uint) <---> tensor
244
+ # --------------------------------------------
245
+ '''
246
+
247
+
248
+ # --------------------------------------------
249
+ # numpy(single) [0, 1] <---> numpy(uint)
250
+ # --------------------------------------------
251
+
252
+
253
def uint2single(img):
    """Divide ``img`` by 255 and return the result as float32."""
    scaled = img / 255.
    return np.float32(scaled)
256
+
257
+
258
def single2uint(img):
    """Clip ``img`` to [0, 1], scale by 255, round and return uint8."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 255.).round()
    return np.uint8(rounded)
261
+
262
+
263
def uint162single(img):
    """Divide ``img`` by 65535 and return the result as float32."""
    scaled = img / 65535.
    return np.float32(scaled)
266
+
267
+
268
def single2uint16(img):
    """Clip ``img`` to [0, 1], scale by 65535, round and return uint16."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 65535.).round()
    return np.uint16(rounded)
271
+
272
+
273
+ # --------------------------------------------
274
+ # numpy(uint) (HxWxC or HxW) <---> tensor
275
+ # --------------------------------------------
276
+
277
+
278
+ # convert uint to 4-dimensional torch tensor
279
def uint2tensor4(img):
    """Convert an HxW or HxWxC array to a 1xCxHxW float tensor divided by 255."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    scaled = chw.float().div(255.)
    return scaled.unsqueeze(0)
283
+
284
+
285
+ # convert uint to 3-dimensional torch tensor
286
def uint2tensor3(img):
    """Convert an HxW or HxWxC array to a CxHxW float tensor divided by 255."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.)
290
+
291
+
292
+ # convert 2/3/4-dimensional torch tensor to uint
293
def tensor2uint(img):
    """Convert a 2/3/4-dimensional torch tensor to a uint8 numpy image.

    The tensor is squeezed, clamped to [0, 1], scaled by 255 and rounded.
    A 3-D result is transposed from CxHxW to HxWxC.

    The input tensor is left untouched: the clamp is out-of-place, whereas
    the previous in-place ``clamp_`` wrote through the view returned by
    ``.data.squeeze().float()`` into the caller's tensor.
    """
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img*255.0).round())
298
+
299
+
300
+ # --------------------------------------------
301
+ # numpy(single) (HxWxC) <---> tensor
302
+ # --------------------------------------------
303
+
304
+
305
+ # convert single (HxWxC) to 3-dimensional torch tensor
306
def single2tensor3(img):
    """Convert an HxWxC array to a CxHxW float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1).float()
308
+
309
+
310
+ # convert single (HxWxC) to 4-dimensional torch tensor
311
def single2tensor4(img):
    """Convert an HxWxC array to a 1xCxHxW float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    chw = as_tensor.permute(2, 0, 1).float()
    return chw.unsqueeze(0)
313
+
314
+
315
+ # convert torch tensor to single
316
def tensor2single(img):
    """Squeeze a tensor and return it as a float numpy array.

    A 3-D result is transposed from CxHxW to HxWxC.
    """
    arr = img.data.squeeze().float().cpu().numpy()
    return np.transpose(arr, (1, 2, 0)) if arr.ndim == 3 else arr
322
+
323
+ # convert torch tensor to single
324
def tensor2single3(img):
    """Squeeze a tensor and return it as a float numpy array.

    A 3-D result is transposed from CxHxW to HxWxC; a 2-D result gets a
    trailing channel axis (HxWx1).
    """
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        return np.transpose(arr, (1, 2, 0))
    if arr.ndim == 2:
        return np.expand_dims(arr, axis=2)
    return arr
331
+
332
+
333
def single2tensor5(img):
    """Permute a 4-D array with axes (2, 0, 1, 3), cast to float and add a
    leading axis, giving a 5-D tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    permuted = as_tensor.permute(2, 0, 1, 3).float()
    return permuted.unsqueeze(0)
335
+
336
+
337
def single32tensor5(img):
    """Cast an array to a float tensor and prepend two singleton axes."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img)).float()
    return as_tensor.unsqueeze(0).unsqueeze(0)
339
+
340
+
341
def single42tensor4(img):
    """Permute a 4-D array with axes (2, 0, 1, 3) and return a float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1, 3).float()
343
+
344
+
345
+ # from skimage.io import imread, imsave
346
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    The input tensor is not modified: the clamp is out-of-place, whereas the
    previous in-place ``clamp_`` overwrote a CPU float input.

    Raises:
        TypeError: the squeezed tensor is not 2-D, 3-D or 4-D.
    '''
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # Tile the batch into a roughly square grid.
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
371
+
372
+
373
+ '''
374
+ # --------------------------------------------
375
+ # Augmentation, flip and/or rotate
376
+ # --------------------------------------------
377
+ # The following two are enough.
378
+ # (1) augment_img: numpy image of WxHxC or WxH
379
+ # (2) augment_img_tensor4: tensor image 1xCxWxH
380
+ # --------------------------------------------
381
+ '''
382
+
383
+
384
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate augmentations to a numpy image.
    ``mode`` selects the transform (0 returns the input unchanged); any
    other value returns None.
    '''
    if mode == 0:
        return img
    if mode == 1:
        return np.flipud(np.rot90(img))
    if mode == 2:
        return np.flipud(img)
    if mode == 3:
        return np.rot90(img, k=3)
    if mode == 4:
        return np.flipud(np.rot90(img, k=2))
    if mode == 5:
        return np.rot90(img)
    if mode == 6:
        return np.rot90(img, k=2)
    if mode == 7:
        return np.flipud(np.rot90(img, k=3))
    return None
403
+
404
+
405
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate augmentations over dims 2 and 3 of a
    tensor. ``mode`` 0 returns the input unchanged; values outside 0..7
    return None.
    '''
    if mode == 0:
        return img
    if mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    if mode == 2:
        return img.flip([2])
    if mode == 3:
        return img.rot90(3, [2, 3])
    if mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    if mode == 5:
        return img.rot90(1, [2, 3])
    if mode == 6:
        return img.rot90(2, [2, 3])
    if mode == 7:
        return img.rot90(3, [2, 3]).flip([2])
    return None
424
+
425
+
426
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Augment a tensor by round-tripping through ``augment_img`` on numpy.
    3-D tensors are moved to (1, 2, 0) axis order and 4-D tensors to
    (2, 3, 1, 0) before augmentation, then permuted back. The result has
    the same type as ``img``.
    '''
    n_dims = len(img.size())
    arr = img.data.cpu().numpy()
    if n_dims == 3:
        arr = np.transpose(arr, (1, 2, 0))
    elif n_dims == 4:
        arr = np.transpose(arr, (2, 3, 1, 0))

    arr = augment_img(arr, mode=mode)

    out = torch.from_numpy(np.ascontiguousarray(arr))
    if n_dims == 3:
        out = out.permute(2, 0, 1)
    elif n_dims == 4:
        out = out.permute(3, 2, 0, 1)

    return out.type_as(img)
443
+
444
+
445
def augment_img_np3(img, mode=0):
    """Apply one of eight flip/transpose augmentations to a 3-D numpy image.

    ``mode`` 0 returns the input unchanged; the other modes combine a flip
    of axis 0, a flip of axis 1 and a swap of axes 0 and 1. Values outside
    0..7 return None.
    """
    if mode == 0:
        return img
    if mode == 1:
        return img.transpose(1, 0, 2)
    if mode == 2:
        return img[::-1, :, :]
    if mode == 3:
        return img[::-1, :, :].transpose(1, 0, 2)
    if mode == 4:
        return img[:, ::-1, :]
    if mode == 5:
        return img[:, ::-1, :].transpose(1, 0, 2)
    if mode == 6:
        return img[:, ::-1, :][::-1, :, :]
    if mode == 7:
        return img[:, ::-1, :][::-1, :, :].transpose(1, 0, 2)
    return None
471
+
472
+
473
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply the same random flip/transpose to every image in ``img_list``.

    Args:
        img_list: list of 3-D numpy images.
        hflip: allow a flip of axis 1 (applied with probability 0.5).
        rot: allow a flip of axis 0 and a swap of axes 0 and 1 (each
            applied independently with probability 0.5).

    Returns:
        list with the augmented images, in input order.
    """
    # The three draws happen once so all images share one transform.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_swap = rot and random.random() < 0.5

    augmented = []
    for img in img_list:
        if do_hflip:
            img = img[:, ::-1, :]
        if do_vflip:
            img = img[::-1, :, :]
        if do_swap:
            img = img.transpose(1, 0, 2)
        augmented.append(img)
    return augmented
489
+
490
+
491
+ '''
492
+ # --------------------------------------------
493
+ # modcrop and shave
494
+ # --------------------------------------------
495
+ '''
496
+
497
+
498
def modcrop(img_in, scale):
    """Crop a copy of ``img_in`` so height and width are multiples of ``scale``.

    Args:
        img_in: Numpy, HWC or HW
        scale: divisor for the first two dimensions.

    Raises:
        ValueError: the image is neither 2-D nor 3-D.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[:2]
    return img[:H - H % scale, :W - W % scale]
512
+
513
+
514
def shave(img_in, border=0):
    """Return a copy of ``img_in`` (HWC or HW) with ``border`` pixels removed
    from each side of the first two dimensions."""
    cropped = np.copy(img_in)
    rows, cols = cropped.shape[:2]
    return cropped[border:rows - border, border:cols - border]
520
+
521
+
522
+ '''
523
+ # --------------------------------------------
524
+ # image processing process on numpy image
525
+ # channel_convert(in_c, tar_type, img_list):
526
+ # rgb2ycbcr(img, only_y=True):
527
+ # bgr2ycbcr(img, only_y=True):
528
+ # ycbcr2rgb(img):
529
+ # --------------------------------------------
530
+ '''
531
+
532
+
533
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 results are rounded, float results
        are scaled back to [0, 1].

    The input array is not modified. The previous code discarded the result
    of ``img.astype`` and then ran ``img *= 255.``, which scaled the
    caller's float array in place.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data untouched
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
555
+
556
+
557
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; values are clipped to [0, 255] before
        uint8 results are rounded or float results are scaled back to [0, 1].

    The input array is not modified. The previous code discarded the result
    of ``img.astype`` and then ran ``img *= 255.``, which scaled the
    caller's float array in place.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data untouched
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    rlt = np.clip(rlt, 0, 255)
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
576
+
577
+
578
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 results are rounded, float results
        are scaled back to [0, 1].

    The input array is not modified. The previous code discarded the result
    of ``img.astype`` and then ran ``img *= 255.``, which scaled the
    caller's float array in place.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data untouched
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
600
+
601
+
602
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and Y.

    Args:
        in_c: number of input channels (3 or 1).
        tar_type: 'gray' or 'y' for 3-channel input, 'RGB' for 1-channel
            input (which is expanded with COLOR_GRAY2BGR).
        img_list: list of numpy images.

    Returns:
        list of converted images; ``img_list`` itself for any other
        combination of ``in_c`` and ``tar_type``.
    """
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        grays = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(g, axis=2) for g in grays]
    if in_c == 3 and tar_type == 'y':  # BGR to y
        lumas = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(y, axis=2) for y in lumas]
    if in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    return img_list
614
+
615
+
616
+ '''
617
+ # --------------------------------------------
618
+ # metric, PSNR, SSIM and PSNRB
619
+ # --------------------------------------------
620
+ '''
621
+
622
+
623
+ # --------------------------------------------
624
+ # PSNR
625
+ # --------------------------------------------
626
def calculate_psnr(img1, img2, border=0):
    """Compute PSNR between two images with range [0, 255].

    Args:
        img1, img2: numpy images of identical shape.
        border: pixels cropped from each side of the first two dimensions
            before comparison.

    Returns:
        float: PSNR in dB; ``inf`` when the images are identical.

    Raises:
        ValueError: the shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    rows, cols = img1.shape[:2]
    crop = (slice(border, rows - border), slice(border, cols - border))

    diff = img1[crop].astype(np.float64) - img2[crop].astype(np.float64)
    mse = np.mean(diff**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
642
+
643
+
644
+ # --------------------------------------------
645
+ # SSIM
646
+ # --------------------------------------------
647
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    border: pixels cropped from each side of the first two dimensions.

    2-D images and HxWx1 images are scored directly; HxWx3 images get the
    mean of the three per-channel scores.

    Raises:
        ValueError: the shapes differ, the image is not 2-D/3-D, or a 3-D
            image has a channel count other than 1 or 3 (this last case
            used to return None silently).
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        n_channels = img1.shape[2]
        if n_channels == 3:
            ssims = [ssim(img1[:, :, i], img2[:, :, i]) for i in range(3)]
            return np.array(ssims).mean()
        if n_channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        raise ValueError(
            'Wrong input image dimensions: expected 1 or 3 channels, got {:d}.'.format(n_channels))
    raise ValueError('Wrong input image dimensions.')
672
+
673
+
674
def ssim(img1, img2):
    """Mean SSIM of two single-channel images.

    Uses an 11x11 Gaussian window (sigma 1.5) and the constants
    C1 = (0.01 * 255)**2 and C2 = (0.03 * 255)**2; a 5-pixel margin is
    cropped from every filtered map before averaging.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def _local_mean(arr):
        # Gaussian-weighted local mean, keeping only the "valid" region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu1 = _local_mean(a)
    mu2 = _local_mean(b)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _local_mean(a**2) - mu1_sq
    sigma2_sq = _local_mean(b**2) - mu2_sq
    sigma12 = _local_mean(a * b) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
695
+
696
+
697
def _blocking_effect_factor(im):
    """Compute the blocking effect factor used by ``calculate_psnrb``.

    ``im`` is indexed as a 4-D tensor; dims 2 and 3 are treated as the
    spatial axes and squared neighbour differences are summed over dims
    3, 2 and 1, leaving one value per entry of dim 0.

    The factor compares the mean squared difference across 8-pixel block
    boundaries with the mean squared difference at all other neighbouring
    positions, and is set to 0 wherever the boundary difference does not
    exceed the non-boundary one.
    """
    block_size = 8

    # Indices of the last column/row of each 8-pixel block (7, 15, ...).
    block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
    block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)

    # Squared differences between pixels on either side of a block boundary.
    horizontal_block_difference = (
        (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
        3).sum(2).sum(1)
    vertical_block_difference = (
        (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
        2).sum(1)

    # All remaining neighbour positions (not on a block boundary).
    nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
    nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)

    horizontal_nonblock_difference = (
        (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
        3).sum(2).sum(1)
    vertical_nonblock_difference = (
        (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
        3).sum(2).sum(1)

    # Normalise each sum by the number of pixel pairs it covers.
    n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
    n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
    boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
        n_boundary_horiz + n_boundary_vert)

    n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
    n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
    nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
        n_nonboundary_horiz + n_nonboundary_vert)

    scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
    bef = scaler * (boundary_difference - nonboundary_difference)

    # No penalty when block boundaries are not rougher than the interior.
    bef[boundary_difference <= nonboundary_difference] = 0
    return bef
735
+
736
+
737
def calculate_psnrb(img1, img2, border=0):
    """Calculate PSNR-B (PSNR with a blocking-effect penalty).
    Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
    # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    Args:
        img1 (ndarray): Images with range [0, 255]. The blocking effect
            factor is computed on this image.
        img2 (ndarray): Images with range [0, 255].
        border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
    Returns:
        float: PSNR-B averaged over the channels.
    Raises:
        ValueError: the two images have different shapes.
    """

    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    # Give 2-D (grayscale) inputs a trailing channel axis.
    if img1.ndim == 2:
        img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)

    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    # HxWxC in [0, 255] -> 1xCxHxW tensor in [0, 1]
    img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
    img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.

    total = 0
    for c in range(img1.shape[1]):
        mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
        bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])

        mse = mse.view(mse.shape[0], -1).mean(1)
        # The blocking effect factor is added to the MSE before taking the log.
        total += 10 * torch.log10(1 / (mse + bef))

    return float(total) / img1.shape[1]
777
+
778
+ '''
779
+ # --------------------------------------------
780
+ # matlab's bicubic imresize (numpy and torch) [0, 1]
781
+ # --------------------------------------------
782
+ '''
783
+
784
+
785
+ # matlab 'imresize' function, now only support 'bicubic'
786
def cubic(x):
    """Evaluate the piecewise cubic kernel on a tensor.

    For |x| <= 1 the value is 1.5|x|^3 - 2.5|x|^2 + 1; for 1 < |x| <= 2 it
    is -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2; elsewhere it is 0.
    """
    a1 = torch.abs(x)
    a2 = a1**2
    a3 = a1**3
    inner_poly = 1.5*a3 - 2.5*a2 + 1
    inner_mask = (a1 <= 1).type_as(a1)
    outer_poly = -0.5*a3 + 2.5*a2 - 4*a1 + 2
    outer_mask = ((a1 > 1)*(a1 <= 2)).type_as(a1)
    return inner_poly * inner_mask + outer_poly * outer_mask
792
+
793
+
794
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute resampling weights and input indices for one axis.

    Args:
        in_length: length of the input axis.
        out_length: length of the output axis.
        scale: resize factor.
        kernel: not used in this function; the cubic kernel is always applied.
        kernel_width: kernel support width; it is divided by ``scale`` when
            downscaling with antialiasing.
        antialiasing: widen the kernel when ``scale < 1``.

    Returns:
        (weights, indices, sym_len_s, sym_len_e): weights and indices have
        one row per output position; indices are shifted by
        ``sym_len_s - 1``; sym_len_s and sym_len_e are ints derived from
        how far the indices reach beyond the start and end of the input.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # NOTE(review): the checks below test whether the count of zeros in the
    # first/last column is non-zero, not whether the whole column is zero —
    # confirm this matches the intent described above.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Extent by which the indices reach before the first / past the last
    # input sample.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
847
+
848
+
849
+ # --------------------------------------------
850
+ # imresize for tensor image [0, 1]
851
+ # --------------------------------------------
852
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with the bicubic kernel from ``calculate_weights_indices``.

    The same ``scale`` is applied to height and width.

    Args:
        img: torch tensor laid out as CHW or HW (documented upstream as
            holding values in [0, 1]). The tensor is not modified.
        scale: resize factor; output sizes are ``ceil(size * scale)``.
        antialiasing: forwarded to ``calculate_weights_indices``.

    Returns:
        A ``torch.FloatTensor`` in CHW layout, or HW when the input was HW.
        Values are not rounded or clipped.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place on purpose: ``unsqueeze_`` would change the shape of
        # the tensor object owned by the caller.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Choosing the resize order by the smaller scale factor is not supported;
    # height is always processed first.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # ---- process H dimension ----
    # Pad by mirroring: the image goes in the middle, flipped border rows fill the ends.
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Each output row is a weighted sum of ``kernel_width`` consecutive padded rows.
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # ---- process W dimension ----
    # Same mirrored padding, now along the width of the intermediate result.
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])

    if need_squeeze:
        # Remove only the channel axis added above; a bare ``squeeze_()``
        # would also drop a height or width of 1.
        out_2 = out_2.squeeze(0)
    return out_2
920
+
921
+
922
+ # --------------------------------------------
923
+ # imresize for numpy image [0, 1]
924
+ # --------------------------------------------
925
def imresize_np(img, scale, antialiasing=True):
    """Resize a NumPy image with the bicubic kernel from ``calculate_weights_indices``.

    The same ``scale`` is applied to height and width.

    Args:
        img: NumPy array laid out as HWC or HW (documented upstream as
            holding values in [0, 1]).
        scale: resize factor; output sizes are ``ceil(size * scale)``.
        antialiasing: forwarded to ``calculate_weights_indices``.

    Returns:
        float32 NumPy array in HWC layout, or HW when the input was HW.
        Values are not rounded or clipped.
    """
    img = torch.from_numpy(img)
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img = img.unsqueeze(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Choosing the resize order by the smaller scale factor is not supported;
    # height is always processed first.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # ---- process H dimension ----
    # Pad by mirroring: the image goes in the middle, flipped border rows fill the ends.
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Each output row is a weighted sum of ``kernel_width`` consecutive padded rows.
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # ---- process W dimension ----
    # Same mirrored padding, now along the width of the intermediate result.
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])

    if need_squeeze:
        # Remove only the channel axis added above; a bare ``squeeze_()``
        # would also drop a height or width of 1.
        out_2 = out_2.squeeze(2)

    return out_2.numpy()
996
+
997
+
998
if __name__ == '__main__':
    # Manual smoke test, run only when the module is executed as a script.
    # Only the image load is active; the remaining experiments are kept
    # commented out for reference.
    # NOTE(review): `imread_uint` is defined outside this excerpt; presumably
    # it loads 'test.bmp' with 3 channels - confirm against its definition.
    img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
    # imshow(single2uint(img_bicubic))
    #
    # img_tensor = single2tensor4(img)
    # for i in range(8):
    #     imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))

    # patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
    # imssave(patches,'a.png')
1010
+
1011
+
1012
+
1013
+
1014
+
1015
+
1016
+
core/data/deg_kair_utils/utils_lmdb.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import lmdb
3
+ import sys
4
+ from multiprocessing import Pool
5
+ from os import path as osp
6
+ from tqdm import tqdm
7
+
8
+
9
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    The progress bar, the meta file and the lmdb environment are released
    even when reading or writing an image fails; in that case the pending
    (uncommitted) transaction is aborted and the exception propagates.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path. Must end with '.lmdb'.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
        SystemExit: If ``lmdb_path`` already exists (exit status 1).
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        try:
            for path, key in zip(img_path_list, keys):
                pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
            pool.close()
            pool.join()
        finally:
            pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # Estimate the map size from the encoded size of the first image,
        # with a 10x safety margin.
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    try:
        with open(osp.join(lmdb_path, 'meta_info.txt'), 'w') as txt_file:
            txn = env.begin(write=True)
            try:
                for idx, (path, key) in enumerate(zip(img_path_list, keys)):
                    pbar.update(1)
                    pbar.set_description(f'Write {key}')
                    key_byte = key.encode('ascii')
                    if multiprocessing_read:
                        img_byte = dataset[key]
                        h, w, c = shapes[key]
                    else:
                        _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
                        h, w, c = img_shape

                    txn.put(key_byte, img_byte)
                    # write meta information
                    txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
                    if idx % batch == 0:
                        txn.commit()
                        txn = env.begin(write=True)
            except BaseException:
                # Drop the partially filled batch instead of leaving the
                # write transaction open.
                txn.abort()
                raise
            txn.commit()
    finally:
        pbar.close()
        env.close()
    print('\nFinish writing lmdb.')
130
+
131
+
132
def read_img_worker(path, key, compress_level):
    """Read image worker.

    The image is loaded with OpenCV; if OpenCV returns ``None`` the file is
    loaded with PIL instead and its channels are reordered to OpenCV's
    convention before PNG encoding.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # deal with `libpng error: Read Error`
    if img is None:
        print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
        from PIL import Image
        import numpy as np
        # Materialise the pixels inside the ``with`` so the file handle is closed.
        with Image.open(path) as pil_img:
            img = np.array(pil_img)
        # PIL yields RGB(A) while OpenCV expects BGR(A). Only reorder colour
        # images: 2-D (grayscale) arrays have no channel axis to index, and
        # a 4-channel image must keep its alpha plane.
        if img.ndim == 3:
            if img.shape[2] == 3:
                img = img[:, :, [2, 1, 0]]
            elif img.shape[2] == 4:
                img = img[:, :, [2, 1, 0, 3]]

    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))
163
+
164
+
165
class LmdbMaker():
    """LMDB Maker.

    Incrementally writes encoded images into an lmdb environment and records
    one line per image in ``meta_info.txt`` inside the lmdb folder.

    The instance can be used as a context manager, which guarantees that
    ``close()`` runs::

        with LmdbMaker('data.lmdb') as maker:
            maker.put(img_byte, key, img_shape)

    Args:
        lmdb_path (str): Lmdb save path. Must end with '.lmdb'.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
        SystemExit: If ``lmdb_path`` already exists (exit status 1).
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        # Number of images written so far; drives the periodic commit.
        self.counter = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the environment and the meta file; exceptions from
        # the ``with`` body are not suppressed.
        self.close()
        return False

    def put(self, img_byte, key, img_shape):
        """Store one encoded image under ``key`` and log its meta information.

        Args:
            img_byte: encoded image bytes to store.
            key (str): lmdb key; must be ASCII-encodable.
            img_shape: ``(h, w, c)`` written to ``meta_info.txt``.
        """
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Commit the pending transaction and release the env and meta file."""
        try:
            self.txn.commit()
        finally:
            # Release both handles even if the final commit fails.
            self.env.close()
            self.txt_file.close()
core/data/deg_kair_utils/utils_logger.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import datetime
3
+ import logging
4
+
5
+
6
+ '''
7
+ # --------------------------------------------
8
+ # Kai Zhang (github: https://github.com/cszn)
9
+ # 03/Mar/2019
10
+ # --------------------------------------------
11
+ # https://github.com/xinntao/BasicSR
12
+ # --------------------------------------------
13
+ '''
14
+
15
+
16
def log(*args, **kwargs):
    """Print ``args`` prefixed with the current local time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS:``; all positional and
    keyword arguments are forwarded unchanged to ``print``.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(timestamp, *args, **kwargs)
18
+
19
+
20
+ '''
21
+ # --------------------------------------------
22
+ # logger
23
+ # --------------------------------------------
24
+ '''
25
+
26
+
27
def logger_info(logger_name, log_path='default_logger.log'):
    """Configure the named logger to write to a file and to the console.

    Nothing is configured when ``Logger.hasHandlers()`` already reports a
    handler for the logger (that check also considers ancestor loggers).
    Otherwise the level is set to INFO and a ``FileHandler`` (append mode,
    ``log_path``) plus a ``StreamHandler`` are attached, sharing one format.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        log_path: file that receives the log records.
    """
    logger = logging.getLogger(logger_name)
    if logger.hasHandlers():
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')

    file_handler = logging.FileHandler(log_path, mode='a')
    file_handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
47
+
48
+
49
+ '''
50
+ # --------------------------------------------
51
+ # print to file and std_out simultaneously
52
+ # --------------------------------------------
53
+ '''
54
+
55
+
56
class logger_print(object):
    """File-like object that duplicates everything written to it.

    Each message goes to the ``sys.stdout`` that was active when the object
    was created and is also appended to ``log_path``. Typical use is
    ``sys.stdout = logger_print('run.log')``.

    Args:
        log_path: file opened in append mode for the copy of the output.
    """

    def __init__(self, log_path="default.log"):
        self.terminal = sys.stdout
        self.log = open(log_path, 'a')

    def write(self, message):
        """Write ``message`` to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)  # write the message

    def flush(self):
        """Flush both targets so buffered text reaches the terminal and the file."""
        self.terminal.flush()
        self.log.flush()
core/data/deg_kair_utils/utils_mat.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import scipy.io as spio
4
+ import pandas as pd
5
+
6
+
7
def loadmat(filename):
    '''
    Load a .mat file into nested dictionaries.

    Use this instead of calling spio.loadmat directly: entries that are
    still mat-objects are converted by _check_keys, and the result is then
    wrapped by dict_to_nonedict.
    '''
    raw = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
    converted = _check_keys(raw)
    return dict_to_nonedict(converted)
16
+
17
+ def _check_keys(dict):
18
+ '''
19
+ checks if entries in dictionary are mat-objects. If yes
20
+ todict is called to change them to nested dictionaries
21
+ '''
22
+ for key in dict:
23
+ if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
24
+ dict[key] = _todict(dict[key])
25
+ return dict
26
+
27
+ def _todict(matobj):
28
+ '''
29
+ A recursive function which constructs from matobjects nested dictionaries
30
+ '''
31
+ dict = {}
32
+ for strg in matobj._fieldnames:
33
+ elem = matobj.__dict__[strg]
34
+ if isinstance(elem, spio.matlab.mio5_params.mat_struct):
35
+ dict[strg] = _todict(elem)
36
+ else:
37
+ dict[strg] = elem
38
+ return dict
39
+
40
+
41
def dict_to_nonedict(opt):
    """Recursively convert dictionaries into ``NoneDict`` instances.

    Args:
        opt: arbitrary object. Dictionaries (including those nested inside
            lists) are converted; every other value is returned unchanged.

    Returns:
        A ``NoneDict`` for dict input, a new list for list input, otherwise
        ``opt`` itself.
    """
    if isinstance(opt, dict):
        # Pass the mapping positionally: ``NoneDict(**mapping)`` raises
        # TypeError as soon as a key is not a string.
        return NoneDict({key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt


class NoneDict(dict):
    """``dict`` subclass whose missing keys evaluate to ``None``."""

    def __missing__(self, key):
        # Called by dict.__getitem__ for absent keys; the key is not stored.
        return None
56
+
57
+
58
def mat2json(mat_path=None, filepath = None):
    """
    Convert a .mat file to JSON text and optionally write it to a file.

    Parameters
    ----------
    mat_path: Str
        Path/filename of the .mat file to read.
    filepath: Str
        Pass a truthy value to also save the JSON to disk; when omitted
        nothing is written.
        NOTE(review): the value is only tested for truthiness. The output
        file is always named after the base name of ``mat_path`` with a
        '.json' extension and is created relative to the current working
        directory - confirm whether writing to ``filepath`` was intended.

    Returns
    -------
    The JSON text produced by ``pandas.Series.to_json()`` for the loaded
    contents (not a dictionary).

    Examples
    --------
    >>> mat2json(blah blah)
    """

    contents = loadmat(mat_path)
    # Drop the bookkeeping fields that cannot be serialised to JSON.
    for meta_key in ('__header__', '__version__', '__globals__'):
        contents.pop(meta_key)
    # jsonize the contents - orientation is 'index'
    json_text = pd.Series(contents).to_json()

    if filepath:
        base_name = os.path.splitext(os.path.split(mat_path)[1])[0]
        with open(base_name + '.json', 'w') as f:
            f.write(json_text)
    return json_text
core/data/deg_kair_utils/utils_matconvnet.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ from collections import OrderedDict
5
+
6
+ # import scipy.io as io
7
+ import hdf5storage
8
+
9
+ """
10
+ # --------------------------------------------
11
+ # Convert matconvnet SimpleNN model into pytorch model
12
+ # --------------------------------------------
13
+ # Kai Zhang ([email protected])
14
+ # https://github.com/cszn
15
+ # 28/Nov/2019
16
+ # --------------------------------------------
17
+ """
18
+
19
+
20
def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
    """Modified version of https://github.com/albanie/pytorch-mcn
    Adjust memory layout and load weights as torch tensor

    Args:
        x (ndarray): a numpy array, corresponding to a set of network weights
            stored in column major order
        squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
            singleton dimensions). So after converting to pytorch layout
            (C_out, C_in, H, W), if the shape is (A, B, 1, 1) it will be
            reshaped to a matrix with shape (A, B).
        in_features (int :: None): used to reshape weights for a linear block.
        out_features (int :: None): used to reshape weights for a linear block.

    Returns:
        torch.tensor: a permuted set of weights, matching the pytorch layout
            convention
    """
    rank = x.ndim
    if rank == 4:
        # (H, W, C_in, C_out) -> (C_out, C_in, H, W)
        x = x.transpose((3, 2, 0, 1))
        # NOTE: the upstream file carried commented-out channel permutations
        # here for the pixel-shuffle layers of FFDNet and SRMD; none of them
        # is applied.
    elif rank == 3:  # add by Kai
        # Append a singleton output-channel axis, then permute as for rank 4.
        x = x[..., None].transpose((3, 2, 0, 1))
    elif rank == 2 and x.shape[1] == 1:
        # Column vector -> flat vector.
        x = x.flatten()

    if squeeze:
        if in_features and out_features:
            x = x.reshape((out_features, in_features))
        x = np.squeeze(x)
    return torch.from_numpy(np.ascontiguousarray(x))
66
+
67
+
68
def save_model(network, save_path):
    """Save the network's state dict to ``save_path`` with all tensors on the CPU.

    Args:
        network: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_path: destination passed to ``torch.save``.
    """
    weights = network.state_dict()
    # Replace values in place so the state-dict object itself is what gets saved.
    for name in list(weights.keys()):
        weights[name] = weights[name].cpu()
    torch.save(weights, save_path)
73
+
74
+
75
if __name__ == '__main__':
    # One-off conversion script: reads 'models/modelcolor.mat', collects the
    # weights and biases of every layer tagged 'conv' for 25 sub-models
    # (13 layer slots each are inspected) and saves them as nested
    # OrderedDicts in 'model_zoo/modelcolor.pth'.
    # NOTE(review): the nesting of the .mat structure is taken as-is from the
    # index chains below; it cannot be verified from this file.

    # from utils import utils_logger
    # import logging
    # utils_logger.logger_info('a', 'a.log')
    # logger = logging.getLogger('a')
    #
    # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
    mcn = hdf5storage.loadmat('models/modelcolor.mat')


    #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])

    mat_net = OrderedDict()
    for idx in range(25):
        mat_net[str(idx)] = OrderedDict()
        # Index of the current conv layer within this sub-model; -1 = none seen yet.
        count = -1

        print(idx)
        for i in range(13):

            if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':

                count += 1
                w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
                # print(w.shape)
                w = weights2tensor(w)
                # print(w.shape)

                b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
                b = weights2tensor(b)
                print(b.shape)

                # Keys use count*2, i.e. only even positions 'model.0', 'model.2', ...
                mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
                mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b

    torch.save(mat_net, 'model_zoo/modelcolor.pth')
113
+
114
+
115
+
116
+ # from models.network_dncnn import IRCNN as net
117
+ # network = net(in_nc=3, out_nc=3, nc=64)
118
+ # state_dict = network.state_dict()
119
+ #
120
+ # #show_kv(state_dict)
121
+ #
122
+ # for i in range(len(mcn['net'][0][0][0])):
123
+ # print(mcn['net'][0][0][0][i][0][0][0][0])
124
+ #
125
+ # count = -1
126
+ # mat_net = OrderedDict()
127
+ # for i in range(len(mcn['net'][0][0][0])):
128
+ # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
129
+ #
130
+ # count += 1
131
+ # w = mcn['net'][0][0][0][i][0][1][0][0]
132
+ # print(w.shape)
133
+ # w = weights2tensor(w)
134
+ # print(w.shape)
135
+ #
136
+ # b = mcn['net'][0][0][0][i][0][1][0][1]
137
+ # b = weights2tensor(b)
138
+ # print(b.shape)
139
+ #
140
+ # mat_net['model.{:d}.weight'.format(count*2)] = w
141
+ # mat_net['model.{:d}.bias'.format(count*2)] = b
142
+ #
143
+ # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
144
+ #
145
+ #
146
+ #
147
+ # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
148
+ # def show_kv(net):
149
+ # for k, v in net.items():
150
+ # print(k)
151
+ #
152
+ # show_kv(crt_net)
153
+
154
+
155
+ # from models.network_dncnn import DnCNN as net
156
+ # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
157
+
158
+ # from models.network_srmd import SRMD as net
159
+ # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
160
+ # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
161
+ #
162
+ # from models.network_rrdb import RRDB as net
163
+ # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
164
+ #
165
+ # state_dict = network.state_dict()
166
+ # for key, param in state_dict.items():
167
+ # print(key)
168
+ # from models.network_imdn import IMDN as net
169
+ # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
170
+ # state_dict = network.state_dict()
171
+ # mat_net = OrderedDict()
172
+ # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
173
+ # mat_net[key] = param2
174
+ # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
175
+ #
176
+
177
+ # net_old = torch.load('net_old.pth')
178
+ # def show_kv(net):
179
+ # for k, v in net.items():
180
+ # print(k)
181
+ #
182
+ # show_kv(net_old)
183
+ # from models.network_dpsr import MSRResNet_prior as net
184
+ # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
185
+ # state_dict = network.state_dict()
186
+ # net_new = OrderedDict()
187
+ # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
188
+ # net_new[key] = param_old
189
+ # torch.save(net_new, 'net_new.pth')
190
+
191
+
192
+ # print(key)
193
+ # print(param.size())
194
+
195
+
196
+
197
+ # run utils/utils_matconvnet.py
core/data/deg_kair_utils/utils_model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ from utils import utils_image as util
5
+ import re
6
+ import glob
7
+ import os
8
+
9
+
10
+ '''
11
+ # --------------------------------------------
12
+ # Model
13
+ # --------------------------------------------
14
+ # Kai Zhang (github: https://github.com/cszn)
15
+ # 03/Mar/2019
16
+ # --------------------------------------------
17
+ '''
18
+
19
+
20
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Find the checkpoint with the highest iteration number in ``save_dir``.

    Checkpoints are files named ``<iteration>_<net_type>.pth``. Files that
    match the ``*_<net_type>.pth`` glob but do not start with a pure
    iteration number (e.g. ``best_G.pth``) are ignored.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no checkpoint was found)
        init_path: model path (``pretrained_path`` when no checkpoint was found)
    # ---------------------------------------
    """
    # Match the file name only, with ``net_type`` and the dot escaped, so
    # digits in directory names or regex metacharacters cannot interfere.
    name_pattern = re.compile(r'(\d+)_{}\.pth'.format(re.escape(net_type)))

    iter_exist = []
    for file_ in glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))):
        match = name_pattern.fullmatch(os.path.basename(file_))
        if match:
            iter_exist.append(int(match.group(1)))

    if iter_exist:
        init_iter = max(iter_exist)
        init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
    else:
        init_iter = 0
        init_path = pretrained_path
    return init_iter, init_path
49
+
50
+
51
+ def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
52
+ '''
53
+ # ---------------------------------------
54
+ # Kai Zhang (github: https://github.com/cszn)
55
+ # 03/Mar/2019
56
+ # ---------------------------------------
57
+ Args:
58
+ model: trained model
59
+ L: input Low-quality image
60
+ mode:
61
+ (0) normal: test(model, L)
62
+ (1) pad: test_pad(model, L, modulo=16)
63
+ (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
64
+ (3) x8: test_x8(model, L, modulo=1) ^_^
65
+ (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
66
+ refield: effective receptive filed of the network, 32 is enough
67
+ useful when split, i.e., mode=2, 4
68
+ min_size: min_sizeXmin_size image, e.g., 256X256 image
69
+ useful when split, i.e., mode=2, 4
70
+ sf: scale factor for super-resolution, otherwise 1
71
+ modulo: 1 if split
72
+ useful when pad, i.e., mode=1
73
+
74
+ Returns:
75
+ E: estimated image
76
+ # ---------------------------------------
77
+ '''
78
+ if mode == 0:
79
+ E = test(model, L)
80
+ elif mode == 1:
81
+ E = test_pad(model, L, modulo, sf)
82
+ elif mode == 2:
83
+ E = test_split(model, L, refield, min_size, sf, modulo)
84
+ elif mode == 3:
85
+ E = test_x8(model, L, modulo, sf)
86
+ elif mode == 4:
87
+ E = test_split_x8(model, L, refield, min_size, sf, modulo)
88
+ return E
89
+
90
+
91
+ '''
92
+ # --------------------------------------------
93
+ # normal (0)
94
+ # --------------------------------------------
95
+ '''
96
+
97
+
98
+ def test(model, L):
99
+ E = model(L)
100
+ return E
101
+
102
+
103
+ '''
104
+ # --------------------------------------------
105
+ # pad (1)
106
+ # --------------------------------------------
107
+ '''
108
+
109
+
110
+ def test_pad(model, L, modulo=16, sf=1):
111
+ h, w = L.size()[-2:]
112
+ paddingBottom = int(np.ceil(h/modulo)*modulo-h)
113
+ paddingRight = int(np.ceil(w/modulo)*modulo-w)
114
+ L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
115
+ E = model(L)
116
+ E = E[..., :h*sf, :w*sf]
117
+ return E
118
+
119
+
120
+ '''
121
+ # --------------------------------------------
122
+ # split (function)
123
+ # --------------------------------------------
124
+ '''
125
+
126
+
127
+ def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
128
+ """
129
+ Args:
130
+ model: trained model
131
+ L: input Low-quality image
132
+ refield: effective receptive filed of the network, 32 is enough
133
+ min_size: min_sizeXmin_size image, e.g., 256X256 image
134
+ sf: scale factor for super-resolution, otherwise 1
135
+ modulo: 1 if split
136
+
137
+ Returns:
138
+ E: estimated result
139
+ """
140
+ h, w = L.size()[-2:]
141
+ if h*w <= min_size**2:
142
+ L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
143
+ E = model(L)
144
+ E = E[..., :h*sf, :w*sf]
145
+ else:
146
+ top = slice(0, (h//2//refield+1)*refield)
147
+ bottom = slice(h - (h//2//refield+1)*refield, h)
148
+ left = slice(0, (w//2//refield+1)*refield)
149
+ right = slice(w - (w//2//refield+1)*refield, w)
150
+ Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
151
+
152
+ if h * w <= 4*(min_size**2):
153
+ Es = [model(Ls[i]) for i in range(4)]
154
+ else:
155
+ Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
156
+
157
+ b, c = Es[0].size()[:2]
158
+ E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
159
+
160
+ E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
161
+ E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
162
+ E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
163
+ E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
164
+ return E
165
+
166
+
167
+ '''
168
+ # --------------------------------------------
169
+ # split (2)
170
+ # --------------------------------------------
171
+ '''
172
+
173
+
174
def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Mode 2 entry point: forward every argument to `test_split_fn`."""
    return test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
177
+
178
+
179
+ '''
180
+ # --------------------------------------------
181
+ # x8 (3)
182
+ # --------------------------------------------
183
+ '''
184
+
185
+
186
def test_x8(model, L, modulo=1, sf=1):
    """Average the padded-model output over the 8 `augment_img_tensor4` modes.

    Each augmented copy of `L` (mode 0..7) is run through `test_pad`; the
    result is then passed through `util.augment_img_tensor4` again with the
    same mode, except modes 3 and 5, which use mode 8 - i (presumably the
    inverse transform -- confirm against `util`).
    """
    estimates = [
        test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf)
        for i in range(8)
    ]
    restored = [
        util.augment_img_tensor4(estimate, mode=8 - i if i in (3, 5) else i)
        for i, estimate in enumerate(estimates)
    ]
    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
196
+
197
+
198
+ '''
199
+ # --------------------------------------------
200
+ # split and x8 (4)
201
+ # --------------------------------------------
202
+ '''
203
+
204
+
205
def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Average the split-model output over the 8 `augment_img_tensor4` modes.

    Same scheme as `test_x8`, but each augmented copy is processed with
    `test_split_fn`. Outputs are passed through `util.augment_img_tensor4`
    again with the same mode, except modes 3 and 5, which use mode 8 - i.
    """
    estimates = [
        test_split_fn(
            model,
            util.augment_img_tensor4(L, mode=i),
            refield=refield,
            min_size=min_size,
            sf=sf,
            modulo=modulo,
        )
        for i in range(8)
    ]
    restored = [
        util.augment_img_tensor4(estimate, mode=8 - i if i in (3, 5) else i)
        for i, estimate in enumerate(estimates)
    ]
    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
215
+
216
+
217
+ '''
218
+ # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
219
+ # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
220
+ # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
221
+ '''
222
+
223
+
224
+ '''
225
+ # --------------------------------------------
226
+ # print
227
+ # --------------------------------------------
228
+ '''
229
+
230
+
231
+ # --------------------------------------------
232
+ # print model
233
+ # --------------------------------------------
234
def print_model(model):
    """Print the text returned by `describe_model(model)`."""
    print(describe_model(model))
237
+
238
+
239
+ # --------------------------------------------
240
+ # print params
241
+ # --------------------------------------------
242
def print_params(model):
    """Print the text returned by `describe_params(model)`."""
    print(describe_params(model))
245
+
246
+
247
+ '''
248
+ # --------------------------------------------
249
+ # information
250
+ # --------------------------------------------
251
+ '''
252
+
253
+
254
+ # --------------------------------------------
255
+ # model information
256
+ # --------------------------------------------
257
def info_model(model):
    """Return the text produced by `describe_model(model)`."""
    return describe_model(model)
260
+
261
+
262
+ # --------------------------------------------
263
+ # params information
264
+ # --------------------------------------------
265
def info_params(model):
    """Return the text produced by `describe_params(model)`."""
    return describe_params(model)
268
+
269
+
270
+ '''
271
+ # --------------------------------------------
272
+ # description
273
+ # --------------------------------------------
274
+ '''
275
+
276
+
277
+ # --------------------------------------------
278
+ # model name and total number of parameters
279
+ # --------------------------------------------
280
def describe_model(model):
    """Return a text summary: class name, parameter count and `str(model)`.

    A `torch.nn.DataParallel` wrapper is unwrapped first, so the summary
    describes the wrapped module.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    n_params = sum(p.numel() for p in net.parameters())
    # The leading empty entry yields the blank first line of the message.
    sections = [
        '',
        'models name: {}'.format(net.__class__.__name__),
        'Params number: {}'.format(n_params),
        'Net structure:\n{}'.format(str(net)),
    ]
    return '\n'.join(sections) + '\n'
288
+
289
+
290
+ # --------------------------------------------
291
+ # parameters description
292
+ # --------------------------------------------
293
def describe_params(model):
    """Return a table with mean/min/max/std/shape for every state-dict entry.

    A `torch.nn.DataParallel` wrapper is unwrapped first. Entries whose name
    contains 'num_batches_tracked' are skipped. Values are converted to float
    before the statistics are computed.

    Returns:
        str: a leading blank line, one header line, one line per entry.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    # The header needs six fields; the original format string had only five
    # placeholders, so the 'param_name' title was silently dropped.
    header = ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format(
        'mean', 'min', 'max', 'std', 'shape', 'param_name')
    rows = ['', header]
    for name, param in model.state_dict().items():
        if 'num_batches_tracked' in name:
            continue
        v = param.data.clone().float()
        rows.append(' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(
            v.mean(), v.min(), v.max(), v.std(), v.shape, name))
    return '\n'.join(rows) + '\n'
303
+
304
+
305
if __name__ == '__main__':
    # Smoke test: run a single-conv network through all five test modes.

    class Net(torch.nn.Module):
        """Minimal network: one 3x3 convolution that keeps the spatial size."""

        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            return x

    # The unused torch.cuda.Event timers were removed: nothing read them, and
    # constructing them is expected to fail on CPU-only installs.
    model = Net()
    model = model.eval()
    print_model(model)
    print_params(model)
    x = torch.randn((2, 3, 401, 401))
    torch.cuda.empty_cache()
    with torch.no_grad():
        for mode in range(5):
            y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
            print(y.shape)
329
+
330
+ # run utils/utils_model.py
core/data/deg_kair_utils/utils_modelsummary.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+
5
+ '''
6
+ ---- 1) FLOPs: floating point operations
7
+ ---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
8
+ ---- 3) #Conv2d: the number of ‘Conv2d’ layers
9
+ # --------------------------------------------
10
+ # Kai Zhang (github: https://github.com/cszn)
11
+ # 21/July/2020
12
+ # --------------------------------------------
13
+ # Reference
14
+ https://github.com/sovrasov/flops-counter.pytorch.git
15
+
16
+ # If you use this code, please consider the following citation:
17
+
18
+ @inproceedings{zhang2020aim, %
19
+ title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
20
+ author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
21
+ booktitle={European Conference on Computer Vision Workshops},
22
+ year={2020}
23
+ }
24
+ # --------------------------------------------
25
+ '''
26
+
27
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Run one forward pass with FLOPs hooks attached and return the total.

    Args:
        model: module to measure; counting methods are attached to it in place
        input_res: tuple with the input size without the batch dimension
        print_per_layer_stat: if true, print the per-module breakdown
        input_constructor: optional callable mapping `input_res` to a dict of
            keyword arguments for the model

    Returns:
        The value of `compute_average_flops_cost()` after the forward pass.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    counted = add_flops_counting_methods(model)
    counted.eval().start_flops_count()

    if not input_constructor:
        # Dummy batch of one image, placed on the device of the last parameter.
        target_device = list(counted.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(target_device)
        _ = counted(dummy)
    else:
        model_kwargs = input_constructor(input_res)
        _ = counted(**model_kwargs)

    if print_per_layer_stat:
        print_model_with_flops(counted)
    total = counted.compute_average_flops_cost()
    counted.stop_flops_count()
    return total
47
+
48
def get_model_activation(model, input_res, input_constructor=None):
    """Run one forward pass with activation hooks attached.

    Args:
        model: module to measure; counting methods are attached to it in place
        input_res: tuple with the input size without the batch dimension
        input_constructor: optional callable mapping `input_res` to a dict of
            keyword arguments for the model

    Returns:
        (activation_count, num_conv) from `compute_average_activation_cost()`.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    counted = add_activation_counting_methods(model)
    counted.eval().start_activation_count()

    if not input_constructor:
        # Dummy batch of one image, placed on the device of the last parameter.
        target_device = list(counted.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(target_device)
        _ = counted(dummy)
    else:
        model_kwargs = input_constructor(input_res)
        _ = counted(**model_kwargs)

    totals = counted.compute_average_activation_cost()
    counted.stop_activation_count()
    activation_count, num_conv = totals
    return activation_count, num_conv
65
+
66
+
67
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Return the FLOPs count and the trainable-parameter count of `model`.

    Args:
        model: module to measure; counting methods are attached to it in place
        input_res: tuple with the input size without the batch dimension
        print_per_layer_stat: if true, print the per-module breakdown
        as_strings: if true, return both values formatted as strings
        input_constructor: optional callable mapping `input_res` to a dict of
            keyword arguments for the model

    Returns:
        (flops, params), as formatted strings or as raw numbers.
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    counted = add_flops_counting_methods(model)
    counted.eval().start_flops_count()

    if not input_constructor:
        # Unlike get_model_flops, the dummy batch is not moved to a device.
        _ = counted(torch.FloatTensor(1, *input_res))
    else:
        model_kwargs = input_constructor(input_res)
        _ = counted(**model_kwargs)

    if print_per_layer_stat:
        print_model_with_flops(counted)
    flops_count = counted.compute_average_flops_cost()
    params_count = get_model_parameters_number(counted)
    counted.stop_flops_count()

    if not as_strings:
        return flops_count, params_count
    return flops_to_string(flops_count), params_to_string(params_count)
90
+
91
+
92
def flops_to_string(flops, units='GMac', precision=2):
    """Format a count as a string in GMac / MMac / KMac / Mac.

    Args:
        flops: the count to format
        units: 'GMac', 'MMac' or 'KMac' to force a unit; None to pick the
            largest unit for which `flops // 10**k > 0`; any other value
            yields the unscaled count followed by ' Mac'
        precision: number of decimals passed to `round`
    """
    scales = (('GMac', 9), ('MMac', 6), ('KMac', 3))

    if units is None:
        # Automatic unit selection, largest unit first.
        for label, exponent in scales:
            if flops // 10 ** exponent > 0:
                return str(round(flops / 10. ** exponent, precision)) + ' ' + label
        return str(flops) + ' Mac'

    for label, exponent in scales:
        if units == label:
            return str(round(flops / 10. ** exponent, precision)) + ' ' + units
    return str(flops) + ' Mac'
111
+
112
+
113
def params_to_string(params_num):
    """Format a parameter count as '<x> M', '<x> k' or the plain number."""
    millions = params_num // 10 ** 6
    if millions > 0:
        return str(round(params_num / 10 ** 6, 2)) + ' M'
    thousands = params_num // 10 ** 3
    if thousands:
        return str(round(params_num / 10 ** 3, 2)) + ' k'
    return str(params_num)
120
+
121
+
122
def print_model_with_flops(model, units='GMac', precision=3):
    """Print `model` with each module's repr prefixed by its FLOPs share.

    `extra_repr` of every module is temporarily replaced so that `print(model)`
    shows the accumulated count and its percentage of the total; the original
    `extra_repr` is restored afterwards.

    Args:
        model: module that already has `compute_average_flops_cost` attached
        units: unit passed to `flops_to_string`
        precision: precision passed to `flops_to_string`
    """
    total_flops = model.compute_average_flops_cost()
    # Nothing in this file sets `__batch_counter__`, so reading it directly
    # raised AttributeError while printing. Use it if a caller provides it and
    # otherwise report the raw accumulated counts.
    batch_count = getattr(model, '__batch_counter__', 1)

    def accumulate_flops(self):
        # Leaf (counted) modules report their own counter; containers sum
        # over their direct children.
        if is_supported_instance(self):
            return self.__flops__ / batch_count
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        # Guard against a zero total (e.g. no counted modules).
        share = accumulated_flops_cost / total_flops if total_flops else 0.0
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(share),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)
158
+
159
+
160
def get_model_parameters_number(model):
    """Return the number of elements in parameters that have `requires_grad` set."""
    trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
    return sum(trainable_sizes)
163
+
164
+
165
def add_flops_counting_methods(net_main_module):
    """Attach the FLOPs-counting functions to the module as bound methods.

    The functions are bound to this particular object (not to its class) so
    each receives the module as `self`. Counters are reset before returning.

    Returns:
        The same module object.
    """
    bound_helpers = (
        start_flops_count,
        stop_flops_count,
        reset_flops_count,
        compute_average_flops_cost,
    )
    for helper in bound_helpers:
        # Each attribute carries the same name as the function it wraps.
        setattr(net_main_module, helper.__name__, helper.__get__(net_main_module))

    net_main_module.reset_flops_count()
    return net_main_module
176
+
177
+
178
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns the sum of the `__flops__` counters of all supported sub-modules.
    No division by a batch or image count is performed here.

    """
    return sum(module.__flops__ for module in self.modules() if is_supported_instance(module))
193
+
194
+
195
def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Registers the FLOPs forward hooks on every sub-module by applying
    add_flops_counter_hook_function. Call it before running the network.

    """
    self.apply(add_flops_counter_hook_function)
205
+
206
+
207
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Removes the FLOPs forward hooks from every sub-module by applying
    remove_flops_counter_hook_function. The counters themselves are kept.

    """
    self.apply(remove_flops_counter_hook_function)
217
+
218
+
219
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Sets the `__flops__` counter of every supported sub-module to zero.

    """
    self.apply(add_flops_counter_variable_or_reset)
228
+
229
+
230
def add_flops_counter_hook_function(module):
    """Register the matching FLOPs forward hook on a supported module.

    Does nothing for unsupported modules or when a hook handle is already
    stored on the module. The handle is kept in `__flops_handle__`.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return

    # Checked in order; the first matching entry wins.
    hook_table = (
        ((nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d), conv_flops_counter_hook),
        ((nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6), relu_flops_counter_hook),
        (nn.Linear, linear_flops_counter_hook),
        ((nn.BatchNorm2d), bn_flops_counter_hook),
    )
    selected_hook = empty_flops_counter_hook
    for module_types, candidate_hook in hook_table:
        if isinstance(module, module_types):
            selected_hook = candidate_hook
            break

    module.__flops_handle__ = module.register_forward_hook(selected_hook)
246
+
247
+
248
def remove_flops_counter_hook_function(module):
    """Remove the FLOPs hook stored in `__flops_handle__`, if the module has one."""
    if not is_supported_instance(module):
        return
    if not hasattr(module, '__flops_handle__'):
        return
    module.__flops_handle__.remove()
    del module.__flops_handle__
253
+
254
+
255
def add_flops_counter_variable_or_reset(module):
    """Create or zero the `__flops__` counter on a supported module."""
    if not is_supported_instance(module):
        return
    module.__flops__ = 0
258
+
259
+
260
+ # ---- Internal functions
261
def is_supported_instance(module):
    """Return True if `module` is one of the layer types counted for FLOPs."""
    counted_types = (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    )
    return isinstance(module, counted_types)
272
+
273
+
274
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: add the multiply-accumulate count of a convolution.

    Per output position the cost is
    prod(kernel_size) * in_channels * (out_channels // groups); it is
    multiplied by batch size times the number of output positions.
    """
    n_batch = output.shape[0]
    spatial_dims = list(output.shape[2:])

    filters_per_group = conv_module.out_channels // conv_module.groups
    macs_per_position = (
        np.prod(list(conv_module.kernel_size)) * conv_module.in_channels * filters_per_group
    )
    n_positions = n_batch * np.prod(spatial_dims)

    conv_module.__flops__ += int(int(macs_per_position) * int(n_positions))
296
+
297
+
298
def relu_flops_counter_hook(module, input, output):
    """Forward hook: count one operation per output element."""
    n_elements = output.numel()
    module.__flops__ += int(n_elements)
303
+
304
+
305
def linear_flops_counter_hook(module, input, output):
    """Forward hook: add the multiply-accumulate count of a linear layer.

    Every output element costs `in_features` multiply-accumulates, where
    `in_features` is the size of the input's last dimension. This matches the
    previous result for 1-D and 2-D inputs and also covers inputs with extra
    leading dimensions, for which the previous code used `shape[1]` as the
    feature size and produced a wrong count.

    Args:
        module: the linear module; its `__flops__` counter is incremented
        input: tuple of forward inputs; only the first one is used
        output: the forward output tensor
    """
    features_in = input[0].shape[-1]
    module.__flops__ += int(output.numel() * features_in)
313
+
314
+
315
def bn_flops_counter_hook(module, input, output):
    """Forward hook: add batch * num_features * prod(output.shape[2:]).

    The count is doubled when the module has `affine` set.
    """
    # NOTE: the original author marked this formula as still needing a check.
    n_ops = output.shape[0] * module.num_features * np.prod(output.shape[2:])
    if module.affine:
        n_ops *= 2
    module.__flops__ += int(n_ops)
329
+
330
+
331
+ # ---- Count the number of convolutional layers and the activation
332
def add_activation_counting_methods(net_main_module):
    """Attach the activation-counting functions to the module as bound methods.

    The functions are bound to this particular object (not to its class) so
    each receives the module as `self`. Counters are reset before returning.

    Returns:
        The same module object.
    """
    bound_helpers = (
        start_activation_count,
        stop_activation_count,
        reset_activation_count,
        compute_average_activation_cost,
    )
    for helper in bound_helpers:
        # Each attribute carries the same name as the function it wraps.
        setattr(net_main_module, helper.__name__, helper.__get__(net_main_module))

    net_main_module.reset_activation_count()
    return net_main_module
343
+
344
+
345
def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Returns (activation_sum, num_conv): the sums of the `__activation__` and
    `__num_conv__` counters over all supported sub-modules.

    """
    counted_modules = [m for m in self.modules() if is_supported_instance_for_activation(m)]
    activation_sum = sum(m.__activation__ for m in counted_modules)
    num_conv = sum(m.__num_conv__ for m in counted_modules)
    return activation_sum, num_conv
361
+
362
+
363
def start_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Registers the activation forward hooks on every sub-module by applying
    add_activation_counter_hook_function. Call it before running the network.

    """
    self.apply(add_activation_counter_hook_function)
373
+
374
+
375
def stop_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Removes the activation forward hooks from every sub-module by applying
    remove_activation_counter_hook_function. The counters themselves are kept.

    """
    self.apply(remove_activation_counter_hook_function)
385
+
386
+
387
def reset_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Sets the `__activation__` and `__num_conv__` counters of every supported
    sub-module to zero.

    """
    self.apply(add_activation_counter_variable_or_reset)
396
+
397
+
398
def add_activation_counter_hook_function(module):
    """Register the activation-counting forward hook on a supported module.

    Does nothing for unsupported modules or when a hook handle is already
    stored on the module. The handle is kept in `__activation_handle__`.
    """
    if not is_supported_instance_for_activation(module):
        return
    if hasattr(module, '__activation_handle__'):
        return

    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        handle = module.register_forward_hook(conv_activation_counter_hook)
    module.__activation_handle__ = handle
406
+
407
+
408
def remove_activation_counter_hook_function(module):
    """Remove the hook stored in `__activation_handle__`, if the module has one."""
    if not is_supported_instance_for_activation(module):
        return
    if not hasattr(module, '__activation_handle__'):
        return
    module.__activation_handle__.remove()
    del module.__activation_handle__
413
+
414
+
415
def add_activation_counter_variable_or_reset(module):
    """Create or zero the `__activation__` and `__num_conv__` counters."""
    if not is_supported_instance_for_activation(module):
        return
    module.__activation__ = 0
    module.__num_conv__ = 0
419
+
420
+
421
def is_supported_instance_for_activation(module):
    """Return True if `module` is an nn.Conv2d or nn.ConvTranspose2d."""
    counted_types = (nn.Conv2d, nn.ConvTranspose2d)
    return isinstance(module, counted_types)
429
+
430
def conv_activation_counter_hook(module, input, output):
    """
    Forward hook: add the number of output elements to `__activation__` and
    count one more convolution call in `__num_conv__`.
    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
    :param module: module whose counters are updated
    :param input: forward inputs (unused)
    :param output: forward output tensor
    :return: None
    """
    n_outputs = output.numel()
    module.__activation__ += n_outputs
    module.__num_conv__ += 1
441
+
442
+
443
def empty_flops_counter_hook(module, input, output):
    """Forward hook for modules that contribute no FLOPs: adds zero to the counter."""
    module.__flops__ += 0
445
+
446
+
447
def upsample_flops_counter_hook(module, input, output):
    """Forward hook: add the product of the dimensions of `output[0]`.

    Only the first entry of `output` is inspected.
    """
    first_output = output[0]
    dims = first_output.shape
    n_elements = dims[0]
    for extent in dims[1:]:
        n_elements *= extent
    module.__flops__ += int(n_elements)
454
+
455
+
456
def pool_flops_counter_hook(module, input, output):
    """Forward hook: add the number of elements of the first input tensor."""
    first_input = input[0]
    module.__flops__ += int(np.prod(first_input.shape))
459
+
460
+
461
def dconv_flops_counter_hook(dconv_module, input, output):
    """Forward hook for a module with `weight` and `projection` 4-D tensors.

    Per output position the cost is
    k1**2 * in_channels * m_channels + k2**2 * out_channels * m_channels,
    with the sizes read from `weight.shape` and `projection.shape`; it is
    multiplied by the input batch size times prod(output.shape[2:]).
    """
    n_batch = input[0].shape[0]
    spatial_dims = list(output.shape[2:])

    mid_channels, in_channels, kernel_first, _, = dconv_module.weight.shape
    out_channels, _, kernel_second, _, = dconv_module.projection.shape

    macs_first = kernel_first ** 2 * in_channels * mid_channels
    macs_second = kernel_second ** 2 * out_channels * mid_channels
    n_positions = n_batch * np.prod(spatial_dims)

    dconv_module.__flops__ += int((macs_first + macs_second) * n_positions)
481
+
482
+
483
+
484
+
485
+
core/data/deg_kair_utils/utils_option.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ from datetime import datetime
4
+ import json
5
+ import re
6
+ import glob
7
+
8
+
9
+ '''
10
+ # --------------------------------------------
11
+ # Kai Zhang (github: https://github.com/cszn)
12
+ # 03/Mar/2019
13
+ # --------------------------------------------
14
+ # https://github.com/xinntao/BasicSR
15
+ # --------------------------------------------
16
+ '''
17
+
18
+
19
def get_timestamp():
    """Return the current local time formatted as '_yymmdd_HHMMSS'."""
    now = datetime.now()
    return format(now, '_%y%m%d_%H%M%S')
21
+
22
+
23
def parse(opt_path, is_train=True):
    """Load a JSON option file (with `//` comments) and fill in defaults.

    Args:
        opt_path: path of the JSON option file
        is_train: True for training, False for testing; selects which output
            directories are added under opt['path']

    Returns:
        OrderedDict with the parsed options plus derived paths and defaults.

    Side effects:
        Sets the CUDA_VISIBLE_DEVICES environment variable from opt['gpu_ids']
        and prints the GPU selection.

    Raises:
        KeyError: if a required section ('datasets', 'path', 'task',
            'n_channels', 'netG', 'gpu_ids', 'train') is missing.
    """
    # ----------------------------------------
    # remove comments starting with '//'
    # ----------------------------------------
    # NOTE: everything after the first '//' on a line is dropped, even inside
    # a string value (e.g. a URL).
    with open(opt_path, 'r') as f:
        json_str = ''.join(line.split('//')[0] + '\n' for line in f)

    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1

    opt.setdefault('scale', 1)

    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        for root_key in ('dataroot_H', 'dataroot_L'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = os.path.expanduser(dataset[root_key])

    # ----------------------------------------
    # path
    # ----------------------------------------
    for key, path in opt['path'].items():
        if path:
            opt['path'][key] = os.path.expanduser(path)

    path_task = os.path.join(opt['path']['root'], opt['task'])
    opt['path']['task'] = path_task
    opt['path']['log'] = path_task
    opt['path']['options'] = os.path.join(path_task, 'options')

    if is_train:
        opt['path']['models'] = os.path.join(path_task, 'models')
        opt['path']['images'] = os.path.join(path_task, 'images')
    else:  # test
        opt['path']['images'] = os.path.join(path_task, 'test_images')

    # ----------------------------------------
    # network
    # ----------------------------------------
    opt['netG']['scale'] = opt['scale']

    # ----------------------------------------
    # GPU devices
    # ----------------------------------------
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    # ----------------------------------------
    # default setting for distributeddataparallel
    # ----------------------------------------
    opt.setdefault('find_unused_parameters', True)
    opt.setdefault('use_static_graph', False)
    opt.setdefault('dist', False)
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))

    train = opt['train']
    has_netD = 'netD' in opt

    # ----------------------------------------
    # default setting for perceptual loss
    # ----------------------------------------
    train.setdefault('F_feature_layer', 34)  # 25; [2,7,16,25,34]
    train.setdefault('F_weights', 1.0)  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    train.setdefault('F_lossfn_type', 'l1')
    train.setdefault('F_use_input_norm', True)
    train.setdefault('F_use_range_norm', False)

    # ----------------------------------------
    # default setting for optimizer
    # ----------------------------------------
    train.setdefault('G_optimizer_type', "adam")
    train.setdefault('G_optimizer_betas', [0.9, 0.999])
    train.setdefault('G_scheduler_restart_weights', 1)
    train.setdefault('G_optimizer_wd', 0)
    train.setdefault('G_optimizer_reuse', False)
    if has_netD:
        train.setdefault('D_optimizer_reuse', False)

    # ----------------------------------------
    # default setting of strict for model loading
    # ----------------------------------------
    # The D/E defaults used to test membership in opt['path'] while writing to
    # opt['train'], so a user-provided train.D_param_strict / E_param_strict
    # was always overwritten with True. Check the dict that is written to.
    train.setdefault('G_param_strict', True)
    if has_netD:
        train.setdefault('D_param_strict', True)
    train.setdefault('E_param_strict', True)

    # ----------------------------------------
    # Exponential Moving Average
    # ----------------------------------------
    train.setdefault('E_decay', 0)

    # ----------------------------------------
    # default setting for discriminator
    # ----------------------------------------
    if has_netD:
        netD = opt['netD']
        netD.setdefault('net_type', 'discriminator_patchgan')  # discriminator_unet
        netD.setdefault('in_nc', 3)
        netD.setdefault('base_nc', 64)
        netD.setdefault('n_layers', 3)
        netD.setdefault('norm_type', 'spectral')

    return opt
170
+
171
+
172
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    Find the checkpoint '<iter>_<net_type>.pth' with the largest iteration.

    Files that match '*_<net_type>.pth' but do not have a purely numeric
    prefix (e.g. 'best_G.pth') are ignored; they used to raise IndexError.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no checkpoint is found)
        init_path: model path (pretrained_path when no checkpoint is found)
    """
    # Match the whole file name so digits elsewhere in the path cannot be
    # mistaken for the iteration, and escape net_type and the dot.
    name_pattern = re.compile(r"(\d+)_{}\.pth".format(re.escape(net_type)))

    latest = None  # (iteration, file name) of the best candidate so far
    for file_ in glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))):
        file_name = os.path.basename(file_)
        match = name_pattern.fullmatch(file_name)
        if match is None:
            continue
        iter_current = int(match.group(1))
        if latest is None or iter_current > latest[0]:
            latest = (iter_current, file_name)

    if latest is None:
        return 0, pretrained_path
    init_iter, file_name = latest
    return init_iter, os.path.join(save_dir, file_name)
195
+
196
+
197
+ '''
198
+ # --------------------------------------------
199
+ # convert the opt into json file
200
+ # --------------------------------------------
201
+ '''
202
+
203
+
204
def save(opt):
    """Dump `opt` as JSON into opt['path']['options'].

    The file name is the base name of opt['opt_path'] with the current
    timestamp inserted before the extension.
    """
    source_path = opt['opt_path']
    target_dir = opt['path']['options']
    stem, ext = os.path.splitext(os.path.basename(source_path))
    dump_path = os.path.join(target_dir, stem + get_timestamp() + ext)
    with open(dump_path, 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)
212
+
213
+
214
+ '''
215
+ # --------------------------------------------
216
+ # dict to string for logger
217
+ # --------------------------------------------
218
+ '''
219
+
220
+
221
def dict2str(opt, indent_l=1):
    """Render a (nested) dict as indented text for logging.

    Each level is indented by two spaces per `indent_l`; nested dicts are
    wrapped in 'key:[' ... ']' lines.
    """
    pad = ' ' * (indent_l * 2)
    pieces = []
    for key, value in opt.items():
        if not isinstance(value, dict):
            pieces.append(pad + key + ': ' + str(value) + '\n')
            continue
        pieces.append(pad + key + ':[\n')
        pieces.append(dict2str(value, indent_l + 1))
        pieces.append(pad + ']\n')
    return ''.join(pieces)
231
+
232
+
233
+ '''
234
+ # --------------------------------------------
235
+ # convert OrderedDict to NoneDict,
236
+ # return None for missing key
237
+ # --------------------------------------------
238
+ '''
239
+
240
+
241
def dict_to_nonedict(opt):
    """Recursively convert dicts (also inside lists) to NoneDict.

    Values that are neither dict nor list are returned unchanged.
    """
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    if isinstance(opt, dict):
        converted = {key: dict_to_nonedict(value) for key, value in opt.items()}
        return NoneDict(**converted)
    return opt
251
+
252
+
253
class NoneDict(dict):
    """dict whose item lookup returns None for a missing key instead of raising."""

    def __missing__(self, key):
        # Called by dict.__getitem__ when `key` is absent; nothing is stored.
        return None
core/data/deg_kair_utils/utils_params.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torchvision
4
+
5
+ from models import basicblock as B
6
+
7
def show_kv(net):
    """Print every key of ``net`` (anything exposing ``items()``), one per line."""
    for name, _ in net.items():
        print(name)
10
+
11
+ # should run train debug mode first to get an initial model
12
+ #crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
13
+ #
14
+ #for k, v in crt_net.items():
15
+ # print(k)
16
+ #for k, v in crt_net.items():
17
+ # if k in pretrained_net:
18
+ # crt_net[k] = pretrained_net[k]
19
+ # print('replace ... ', k)
20
+
21
+ # x2 -> x4
22
+ #crt_net['model.5.weight'] = pretrained_net['model.2.weight']
23
+ #crt_net['model.5.bias'] = pretrained_net['model.2.bias']
24
+ #crt_net['model.8.weight'] = pretrained_net['model.5.weight']
25
+ #crt_net['model.8.bias'] = pretrained_net['model.5.bias']
26
+ #crt_net['model.10.weight'] = pretrained_net['model.7.weight']
27
+ #crt_net['model.10.bias'] = pretrained_net['model.7.bias']
28
+ #torch.save(crt_net, '../pretrained_tmp.pth')
29
+
30
+ # x2 -> x3
31
+ '''
32
+ in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
33
+ new_filter = torch.Tensor(576, 64, 3, 3)
34
+ new_filter[0:256, :, :, :] = in_filter
35
+ new_filter[256:512, :, :, :] = in_filter
36
+ new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
37
+ crt_net['model.2.weight'] = new_filter
38
+
39
+ in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
40
+ new_bias = torch.Tensor(576)
41
+ new_bias[0:256] = in_bias
42
+ new_bias[256:512] = in_bias
43
+ new_bias[512:] = in_bias[0:576 - 512]
44
+ crt_net['model.2.bias'] = new_bias
45
+
46
+ torch.save(crt_net, '../pretrained_tmp.pth')
47
+ '''
48
+
49
+ # x2 -> x8
50
+ '''
51
+ crt_net['model.5.weight'] = pretrained_net['model.2.weight']
52
+ crt_net['model.5.bias'] = pretrained_net['model.2.bias']
53
+ crt_net['model.8.weight'] = pretrained_net['model.2.weight']
54
+ crt_net['model.8.bias'] = pretrained_net['model.2.bias']
55
+ crt_net['model.11.weight'] = pretrained_net['model.5.weight']
56
+ crt_net['model.11.bias'] = pretrained_net['model.5.bias']
57
+ crt_net['model.13.weight'] = pretrained_net['model.7.weight']
58
+ crt_net['model.13.bias'] = pretrained_net['model.7.bias']
59
+ torch.save(crt_net, '../pretrained_tmp.pth')
60
+ '''
61
+
62
+ # x3/4/8 RGB -> Y
63
+
64
def rgb2gray_net(net, only_input=True):
    """Collapse the three input channels of ``net['0.weight']`` into one.

    When ``only_input`` is true, the entry ``'0.weight'`` is replaced by the
    weighted sum ``0.2989*ch0 + 0.587*ch1 + 0.114*ch2`` of its channels along
    dim 1, with a singleton dim re-inserted at position 1.

    Args:
        net: mapping holding a ``'0.weight'`` tensor; modified in place.
        only_input: when false, ``net`` is returned untouched.

    Returns:
        The same ``net`` object.
    """
    if only_input:
        rgb = net['0.weight']
        gray = rgb[:, 0, :, :]*0.2989 + rgb[:, 1, :, :]*0.587 + rgb[:, 2, :, :]*0.114
        net['0.weight'] = gray.unsqueeze(1)

    # Disabled sketch of the matching output-layer conversion:
    # out_filter = pretrained_net['model.13.weight']
    # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
    #     out_filter[2, :, :, :] * 0.114
    # out_new_filter.unsqueeze_(0)
    # crt_net['model.13.weight'] = out_new_filter
    # out_bias = pretrained_net['model.13.bias']
    # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
    # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
    # crt_net['model.13.bias'] = out_new_bias

    # torch.save(crt_net, '../pretrained_tmp.pth')

    return net
85
+
86
+
87
+
88
if __name__ == '__main__':
    # Demo: collapse the RGB filters of torchvision's VGG-19 first layer into
    # one gray channel and install them in a replacement first layer.
    net = torchvision.models.vgg19(pretrained=True)
    for name, param in net.features.named_parameters():
        if name == '0.weight':
            in_new_filter = param[:, 0, :, :]*0.2989 + param[:, 1, :, :]*0.587 + param[:, 2, :, :]*0.114
            in_new_filter.unsqueeze_(1)
            print(in_new_filter.shape)
            print(in_new_filter[0, 0, 0, 0])
        if name == '0.bias':
            in_new_bias = param
            print(param[0])

    print(net.features[0])

    # presumably builds a 1-in / 64-out convolution; B.conv is project code - TODO confirm
    net.features[0] = B.conv(1, 64, mode='C')

    print(net.features[0])
    net.features[0].weight.data = in_new_filter
    net.features[0].bias.data = in_new_bias

    for name, param in net.features.named_parameters():
        if name == '0.weight':
            print(param[0, 0, 0, 0])
        if name == '0.bias':
            print(param[0])

    # transfer parameters of old model to new one
    # NOTE(review): `model_path` and `model` are not defined in the visible
    # part of this file; this section raises NameError unless they are
    # supplied elsewhere - confirm before running.
    model_old = torch.load(model_path)
    state_dict = model.state_dict()
    # Entries are paired purely by position, not by key name.
    for (key, param), (key2, _) in zip(model_old.items(), state_dict.items()):
        state_dict[key2] = param
        print([key, key2])
    torch.save(state_dict, 'model_new.pth')
124
+
125
+
126
+ # rgb2gray_net(net)
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
core/data/deg_kair_utils/utils_receptivefield.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # online calculation: https://fomoro.com/research/article/receptive-field-calculator#
4
+
5
+ # [filter size, stride, padding]
6
+ #Assume the two dimensions are the same
7
+ #Each kernel requires the following parameters:
8
+ # - k_i: kernel size
9
+ # - s_i: stride
10
+ # - p_i: padding (if padding is uneven, right padding will be higher than left padding; "SAME" option in tensorflow)
11
+ #
12
+ #Each layer i requires the following parameters to be fully represented:
13
+ # - n_i: number of feature (data layer has n_1 = imagesize )
14
+ # - j_i: distance (projected to image pixel distance) between center of two adjacent features
15
+ # - r_i: receptive field of a feature in layer i
16
+ # - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
17
+
18
+ import math
19
+
20
def outFromIn(conv, layerIn):
    """Propagate receptive-field bookkeeping through one layer.

    Args:
        conv: ``[kernel size, stride, padding]`` of the layer.
        layerIn: ``[n, jump, receptive size, start]`` of the layer's input.

    Returns:
        Tuple ``(n_out, j_out, r_out, start_out)`` for the layer's output.
    """
    n_in, j_in, r_in, start_in = (layerIn[i] for i in range(4))
    k, s, p = (conv[i] for i in range(3))

    n_out = math.floor((n_in - k + 2*p)/s) + 1
    # Padding actually consumed; the left share is the floor of half of it.
    pad_total = (n_out - 1)*s - n_in + k
    pad_left = math.floor(pad_total/2)

    return (
        n_out,
        j_in * s,
        r_in + (k - 1)*j_in,
        start_in + ((k - 1)/2 - pad_left)*j_in,
    )
38
+
39
def printLayer(layer, layer_name):
    """Print the layer name followed by its four receptive-field figures.

    Args:
        layer: sequence ``[n features, jump, receptive size, start]``.
        layer_name: label printed on the first line.
    """
    print(layer_name + ":")
    figures = (layer[0], layer[1], layer[2], layer[3])
    print(" n features: %s jump: %s receptive size: %s start: %s " % figures)
42
+
43
+
44
+
45
# Filled with one (n, jump, receptive size, start) tuple per layer when the
# module is run as a script.
layerInfos = []
if __name__ == '__main__':

    # One [kernel size, stride, padding] triple per layer.
    convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
    layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
    imsize = 128

    print ("-------Net summary------")
    currentLayer = [imsize, 1, 1, 0.5]
    printLayer(currentLayer, "input image")
    for layer_idx, conv in enumerate(convnet):
        currentLayer = outFromIn(conv, currentLayer)
        layerInfos.append(currentLayer)
        printLayer(currentLayer, layer_names[layer_idx])
59
+
60
+
61
+ # run utils/utils_receptivefield.py
62
+
core/data/deg_kair_utils/utils_regularizers.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ '''
6
+ # --------------------------------------------
7
+ # Kai Zhang (github: https://github.com/cszn)
8
+ # 03/Mar/2019
9
+ # --------------------------------------------
10
+ '''
11
+
12
+
13
+ # --------------------------------------------
14
+ # SVD Orthogonal Regularization
15
+ # --------------------------------------------
16
def regularizer_orth(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization
    # ----------------------------------------
    # Meant to be passed to torch.nn.Module.apply():
    #     net.apply(regularizer_orth)
    # For every module whose class name contains 'Conv', the weight is
    # flattened to a (f1*f2*c_in, c_out) matrix, decomposed with SVD,
    # singular values above 1.5 are lowered by 1e-4 and those below 0.5
    # raised by 1e-4, and the rebuilt matrix is written back to
    # m.weight.data. Other modules are left untouched.
    # ----------------------------------------
    """
    if 'Conv' not in m.__class__.__name__:
        return

    weight = m.weight.data.clone()
    c_out, c_in, f1, f2 = weight.size()
    flat = weight.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)

    u, s, v = torch.svd(flat)
    # Nudge out-of-range singular values towards the [0.5, 1.5] band.
    s[s > 1.5] -= 1e-4
    s[s < 0.5] += 1e-4

    rebuilt = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    m.weight.data = rebuilt.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
42
+
43
+
44
+ # --------------------------------------------
45
+ # SVD Orthogonal Regularization
46
+ # --------------------------------------------
47
def regularizer_orth2(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization, relative thresholds
    # ----------------------------------------
    # Meant to be passed to torch.nn.Module.apply():
    #     net.apply(regularizer_orth2)
    # Same procedure as regularizer_orth, except that the thresholds are
    # 1.5 and 0.5 times the mean singular value (computed once, before any
    # adjustment) instead of the absolute values 1.5 and 0.5.
    # ----------------------------------------
    """
    if 'Conv' not in m.__class__.__name__:
        return

    weight = m.weight.data.clone()
    c_out, c_in, f1, f2 = weight.size()
    flat = weight.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)

    u, s, v = torch.svd(flat)
    s_mean = s.mean()
    s[s > 1.5*s_mean] -= 1e-4
    s[s < 0.5*s_mean] += 1e-4

    rebuilt = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    m.weight.data = rebuilt.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
71
+
72
+
73
+
74
def regularizer_clip(m):
    """
    # ----------------------------------------
    # usage: net.apply(regularizer_clip)
    # For modules whose class name contains 'Conv' or 'Linear', every
    # weight (and bias, when present) entry above 1.5 is lowered by 1e-4
    # and every entry below -1.5 is raised by 1e-4.
    # ----------------------------------------
    """
    eps = 1e-4
    c_min = -1.5
    c_max = 1.5

    def _nudged(values):
        # Work on a copy; move out-of-range entries one step towards the band.
        out = values.clone()
        out[out > c_max] -= eps
        out[out < c_min] += eps
        return out

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        m.weight.data = _nudged(m.weight.data)
        if m.bias is not None:
            m.bias.data = _nudged(m.bias.data)

#    elif classname.find('BatchNorm2d') != -1:
#
#       rv = m.running_var.data.clone()
#       rm = m.running_mean.data.clone()
#
#        if m.affine:
#            m.weight.data
#            m.bias.data
core/data/deg_kair_utils/utils_sisr.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from utils import utils_image as util
3
+ import random
4
+
5
+ import scipy
6
+ import scipy.stats as ss
7
+ import scipy.io as io
8
+ from scipy import ndimage
9
+ from scipy.interpolate import interp2d
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ """
16
+ # --------------------------------------------
17
+ # Super-Resolution
18
+ # --------------------------------------------
19
+ #
20
+ # Kai Zhang ([email protected])
21
+ # https://github.com/cszn
22
+ # modified by Kai Zhang (github: https://github.com/cszn)
23
+ # 03/03/2020
24
+ # --------------------------------------------
25
+ """
26
+
27
+
28
+ """
29
+ # --------------------------------------------
30
+ # anisotropic Gaussian kernels
31
+ # --------------------------------------------
32
+ """
33
+
34
+
35
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel, as produced by gm_blur_kernel for the covariance
                built from theta, l1 and l2
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    # Unit x-axis rotated by theta.
    axis = np.dot(rotation, np.array([1., 0.]))
    basis = np.array([[axis[0], axis[1]], [axis[1], -axis[0]]])
    scales = np.array([[l1, 0], [0, l2]])
    covariance = np.dot(np.dot(basis, scales), np.linalg.inv(basis))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
54
+
55
+
56
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian density on a size x size grid and normalize it.

    Grid coordinates along each axis are ``i - (size / 2.0 + 0.5) + 1`` for
    ``i = 0 .. size-1``; entry ``[y, x]`` is the density at point ``(cx, cy)``.

    Args:
        mean: length-2 mean passed to ``scipy.stats.multivariate_normal.pdf``.
        cov: 2x2 covariance passed to the same call.
        size: side length of the kernel.

    Returns:
        ``size x size`` array that sums to 1.
    """
    center = size / 2.0 + 0.5
    coords = np.arange(size) - center + 1
    # 'xy' meshgrid: cx varies along columns, cy along rows, so k[y, x]
    # is evaluated at (cx, cy) exactly as in a per-pixel loop.
    cx, cy = np.meshgrid(coords, coords)
    points = np.stack([cx, cy], axis=-1)
    # One vectorized pdf call instead of size*size scalar calls.
    # The reshape restores the grid shape in case pdf squeezes its output.
    k = np.reshape(ss.multivariate_normal.pdf(points, mean=mean, cov=cov), (size, size))

    k = k / np.sum(k)
    return k
67
+
68
+
69
+ """
70
+ # --------------------------------------------
71
+ # calculate PCA projection matrix
72
+ # --------------------------------------------
73
+ """
74
+
75
+
76
def get_pca_matrix(x, dim_pca=15):
    """
    Args:
        x: data matrix, e.g. 225x10000 (one sample per column)
        dim_pca: number of components to keep, e.g. 15
    Returns:
        pca_matrix: e.g. 15x225; rows are the eigenvectors of x @ x.T that
        belong to the dim_pca largest eigenvalues (eigh sorts ascending)
    """
    scatter = np.dot(x, x.T)
    _, eigvecs = scipy.linalg.eigh(scatter)
    return eigvecs[:, -dim_pca:].T
89
+
90
+
91
def show_pca(x):
    """
    Plot each row of x with util.surf after reshaping it (column-major)
    into a matrix with int(sqrt(x.shape[1])) rows.
    x: PCA projection matrix, e.g., 15x225
    """
    for row in range(x.shape[0]):
        side = int(np.sqrt(x.shape[1]))
        component = np.reshape(x[row, :], (side, -1), order="F")
        util.surf(component)
98
+
99
+
100
def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
    """Estimate a PCA projection from random anisotropic Gaussian kernels.

    Draws ``num_samples`` kernels (theta in [0, pi), l1 in [0.1, 0.1+l_max),
    l2 in [0.1, l1)), flattens each in column-major order, computes the
    projection with ``get_pca_matrix`` and saves it to ``path`` under key 'p'.

    Returns:
        The PCA projection matrix.
    """
    kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
    for col in range(num_samples):
        # Keep the draw order: theta, l1, l2.
        theta = np.pi*np.random.rand(1)
        l1 = 0.1+l_max*np.random.rand(1)
        l2 = 0.1+(l1-0.1)*np.random.rand(1)

        kern = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
        kernels[:, col] = kern.flatten(order='F')

    pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
    io.savemat(path, {'p': pca_matrix})
    return pca_matrix
121
+
122
+
123
+ """
124
+ # --------------------------------------------
125
+ # shifted anisotropic Gaussian kernels
126
+ # --------------------------------------------
127
+ """
128
+
129
+
130
def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Random anisotropic Gaussian kernel, shifted for aligned downsampling.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Args:
        k_size: two-element size; the result has shape (k_size[1], k_size[0]).
        scale_factor: two-element scale factor used to shift the kernel mean.
        min_var, max_var: range the two covariance eigenvalues are drawn from.
        noise_level: multiplicative noise amplitude; each entry is scaled by
            a factor drawn from [1 - noise_level, 1 + noise_level).

    Returns:
        Kernel normalized to sum to 1.
    """
    # Accept lists/tuples as well as arrays for the two size arguments.
    k_size = np.asarray(k_size)
    scale_factor = np.asarray(scale_factor)

    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    # The kernel below has shape (k_size[1], k_size[0]); draw the noise in
    # that shape so non-square kernels broadcast (same draws when square).
    noise = -noise_level + np.random.rand(int(k_size[1]), int(k_size[0])) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    # (.., 1, 2) @ (2, 2) @ (.., 2, 1) -> (.., 1, 1); drop the two unit axes
    # explicitly so size-1 kernel axes are not squeezed away as well.
    quad = (ZZ_t @ INV_SIGMA @ ZZ)[..., 0, 0]
    raw_kernel = np.exp(-0.5 * quad) * (1 + noise)

    # Normalize the kernel and return
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
170
+
171
+
172
def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
    """Random anisotropic Gaussian kernel with a randomly chosen shift.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Note: the ``scale_factor`` argument is overwritten by a value drawn from
    {1, 2, 3, 4}, and ``noise_level`` is not used (noise is fixed to 0).

    Returns:
        Kernel of shape (k_size[1], k_size[0]) normalized to sum to 1.
    """
    sf = random.choice([1, 2, 3, 4])
    scale_factor = np.array([sf, sf])

    # Random eigenvalues and rotation angle of the covariance matrix.
    eig_1 = min_var + np.random.rand() * (max_var - min_var)
    eig_2 = min_var + np.random.rand() * (max_var - min_var)
    angle = np.random.rand() * np.pi
    noise = 0

    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    cov = rot @ np.diag([eig_1, eig_2]) @ rot.T
    inv_cov = np.linalg.inv(cov)[None, None, :, :]

    # Expected position (shifts the kernel for an aligned image).
    centre = (k_size // 2 - 0.5*(scale_factor - 1))[None, None, :, None]

    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    offsets = np.stack([grid_x, grid_y], 2)[:, :, :, None] - centre
    offsets_t = offsets.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(offsets_t @ inv_cov @ offsets)) * (1 + noise)

    return raw_kernel / np.sum(raw_kernel)
214
+
215
+
216
+ """
217
+ # --------------------------------------------
218
+ # degradation models
219
+ # --------------------------------------------
220
+ """
221
+
222
+
223
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image (util.imresize_np with scale 1/sf)
    '''
    return util.imresize_np(x, scale=1/sf)
233
+
234
+
235
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve is the supported entry point; the ndimage.filters
    # namespace is deprecated. The kernel gets a unit channel axis so every
    # channel is blurred independently with periodic ('wrap') boundaries.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
255
+
256
+
257
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve is the supported entry point; the ndimage.filters
    # namespace is deprecated. Periodic ('wrap') boundary handling.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
278
+
279
+
280
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image: every sf-th pixel of the blurred image,
        starting at the top-left pixel
    '''
    # ndimage.convolve is the supported entry point; the ndimage.filters
    # namespace is deprecated. Periodic ('wrap') boundary handling.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0
    return x[st::sf, st::sf, ...]
295
+
296
+
297
def modcrop_np(img, sf):
    '''
    Crop the first two axes of a copy of img to multiples of sf.
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image (the input array is not modified)
    '''
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = np.copy(img)
    return duplicate[:keep0, :keep1, ...]
308
+
309
+
310
+ '''
311
+ # =================
312
+ # Numpy
313
+ # =================
314
+ '''
315
+
316
+
317
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    Resamples ``x`` bilinearly at positions displaced by ``(sf-1)*0.5``
    pixels along both of the first two axes; sample positions are clamped
    to the valid index range.

    Args:
        x: WxHxC or WxH, image or kernel
        sf: scale factor
        upper_left: shift direction (sample at +shift when true, -shift
            otherwise)

    Returns:
        For 2-D input a new float array. For 3-D input the channels are
        written back into ``x`` in place and ``x`` itself is returned.
    """
    # Implemented with numpy only: scipy.interpolate.interp2d, which this
    # function relied on, has been removed from recent SciPy releases.
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    if not upper_left:
        shift = -shift
    cols = np.clip(np.arange(0, w, 1.0) + shift, 0, w - 1)
    rows = np.clip(np.arange(0, h, 1.0) + shift, 0, h - 1)

    # Bilinear sampling is separable: neighbours and weights per axis.
    r0 = np.floor(rows).astype(int)
    c0 = np.floor(cols).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    wr = (rows - r0)[:, None]
    wc = (cols - c0)[None, :]

    def _resample(plane):
        # Interpolate along columns on the two bracketing rows, then
        # blend the rows.
        top = plane[r0][:, c0] * (1 - wc) + plane[r0][:, c1] * wc
        bottom = plane[r1][:, c0] * (1 - wc) + plane[r1][:, c1] * wc
        return top * (1 - wr) + bottom * wr

    if x.ndim == 2:
        x = _resample(x)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])

    return x
344
+
345
+
346
+ '''
347
+ # =================
348
+ # pytorch
349
+ # =================
350
+ '''
351
+
352
+
353
def splits(a, sf):
    '''
    Split dims 2 and 3 into sf chunks each and gather the sf*sf blocks
    along a new trailing dim (column-chunk index varies slowest).
    a: tensor NxCxWxHx2
    sf: scale factor
    out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
    '''
    per_column = [
        torch.stack(torch.chunk(column, sf, dim=2), dim=5)
        for column in torch.chunk(a, sf, dim=3)
    ]
    return torch.cat(per_column, dim=5)
362
+
363
+
364
def c2c(x):
    """Convert a complex numpy value to a float32 tensor with a trailing (real, imag) dim."""
    real_part = np.float32(x.real)
    imag_part = np.float32(x.imag)
    pair = np.stack([real_part, imag_part], axis=-1)
    return torch.from_numpy(pair)
366
+
367
+
368
def r2c(x):
    """Append a trailing dim of size 2 holding (x, 0): real tensor to (real, imag) pairs."""
    imaginary = torch.zeros_like(x)
    return torch.stack((x, imaginary), dim=-1)
370
+
371
+
372
def cdiv(x, y):
    """Complex division x / y for tensors whose last dim holds (real, imag)."""
    x_re, x_im = x[..., 0], x[..., 1]
    y_re, y_im = y[..., 0], y[..., 1]
    denom = y_re**2 + y_im**2
    quotient_re = (x_re*y_re + x_im*y_im) / denom
    quotient_im = (x_im*y_re - x_re*y_im) / denom
    return torch.stack([quotient_re, quotient_im], -1)
377
+
378
+
379
def csum(x, y):
    """Add the real quantity y to the real part of x (last dim holds (real, imag))."""
    shifted_real = x[..., 0] + y
    untouched_imag = x[..., 1]
    return torch.stack([shifted_real, untouched_imag], -1)
381
+
382
+
383
def cabs(x):
    """Magnitude of a complex tensor stored with a trailing (real, imag) dim."""
    squared_norm = x[..., 0]**2 + x[..., 1]**2
    return torch.pow(squared_norm, 0.5)
385
+
386
+
387
def cmul(t1, t2):
    '''
    complex multiplication of tensors whose last dim holds (real, imag)
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    a, b = t1[..., 0], t1[..., 1]
    c, d = t2[..., 0], t2[..., 1]
    product_real = a * c - b * d
    product_imag = a * d + b * c
    return torch.stack([product_real, product_imag], dim=-1)
396
+
397
+
398
def cconj(t, inplace=False):
    '''
    # complex's conjugation: negate the imaginary channel
    t: NxCxHxWx2
    inplace: modify and return t itself instead of a clone
    output: NxCxHxWx2
    '''
    if inplace:
        c = t
    else:
        c = t.clone()
    c[..., 1].mul_(-1)
    return c
407
+
408
+
409
def rfft(t):
    """Full (two-sided) 2-D FFT of a real tensor over its last two dims.

    Returns a real tensor with a trailing dim of size 2 holding (real, imag).
    """
    if hasattr(torch, 'rfft'):
        # Legacy function API (removed in PyTorch 1.8).
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t))
411
+
412
+
413
def irfft(t):
    """Inverse 2-D FFT returning a real tensor.

    ``t`` holds a full (two-sided) spectrum with a trailing (real, imag) dim;
    the transform runs over the two dims before it. On the modern API the
    real part of the complex inverse transform is returned.
    """
    if hasattr(torch, 'irfft'):
        # Legacy function API (removed in PyTorch 1.8).
        return torch.irfft(t, 2, onesided=False)
    return torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])).real
415
+
416
+
417
def fft(t):
    """2-D complex FFT of a tensor whose last dim holds (real, imag).

    The transform runs over the two dims before the trailing one; the result
    uses the same (real, imag) layout.
    """
    if callable(torch.fft):
        # Legacy builds expose torch.fft as a function; it is a module
        # (not callable) on current PyTorch.
        return torch.fft(t, 2)
    spectrum = torch.fft.fft2(torch.complex(t[..., 0], t[..., 1]))
    return torch.view_as_real(spectrum)
419
+
420
+
421
def ifft(t):
    """Inverse 2-D complex FFT of a tensor whose last dim holds (real, imag).

    The transform runs over the two dims before the trailing one; the result
    uses the same (real, imag) layout.
    """
    if hasattr(torch, 'ifft'):
        # Legacy function API (removed in PyTorch 1.8).
        return torch.ifft(t, 2)
    signal = torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1]))
    return torch.view_as_real(signal)
423
+
424
+
425
def p2o(psf, shape):
    '''
    Convert point-spread functions to optical transfer functions.

    The PSF is zero-padded to ``shape`` (placed in the top-left corner),
    circularly shifted by half its size along both spatial axes,
    transformed with a full 2-D FFT, and imaginary parts below a round-off
    threshold are set to zero.

    Args:
        psf: NxCxhxw
        shape: [H,W] (list or tuple)

    Returns:
        otf: NxCxHxWx2
    '''
    # tuple(...) lets a list be given for `shape`, as the Args line advertises.
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    if hasattr(torch, 'rfft'):
        # Legacy function API (removed in PyTorch 1.8).
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # Rough FFT operation count, used to scale the round-off threshold.
    dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(dims * torch.log2(dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
442
+
443
+
444
+ '''
445
+ # =================
446
+ PyTorch
447
+ # =================
448
+ '''
449
+
450
def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
    '''
    PyTorch port of the MATLAB routine quoted below; complex tensors carry
    a trailing (real, imag) dim.

    FB: NxCxWxHx2
    F2B: NxCxWxHx2

    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    x1 = cmul(FB, FR)
    # Block averages over the sf*sf sub-blocks produced by splits().
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, tau))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR-FCBinvWBR)/tau
    if hasattr(torch, 'irfft'):
        # Legacy function API (removed in PyTorch 1.8).
        Xest = torch.irfft(FX, 2, onesided=False)
    else:
        Xest = torch.fft.ifft2(torch.complex(FX[..., 0], FX[..., 1])).real
    return Xest
472
+
473
+
474
def real2complex(x):
    """Append a trailing dim of size 2 holding (x, 0)."""
    zero_imag = torch.zeros_like(x)
    return torch.stack((x, zero_imag), dim=-1)
476
+
477
+
478
def modcrop(img, sf):
    '''
    Crop the last two dims of a clone of img to multiples of sf.
    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor
    '''
    dim0, dim1 = img.shape[-2:]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = img.clone()
    return duplicate[..., :keep0, :keep1]
486
+
487
+
488
def upsample(x, sf=3, center=False):
    '''
    Zero-insertion upsampling: x is written to every sf-th position of an
    all-zero tensor sf times larger in the last two dims. The first
    written index is (sf-1)//2 when center is true, otherwise 0.
    x: tensor image, NxCxWxH
    '''
    if center:
        st = (sf-1)//2
    else:
        st = 0
    n, c = x.shape[0], x.shape[1]
    z = torch.zeros((n, c, x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
    z[..., st::sf, st::sf] = x
    return z
496
+
497
+
498
def downsample(x, sf=3, center=False):
    """Keep every sf-th element of the last two dims.

    Sampling starts at index (sf-1)//2 when center is true, otherwise at 0.
    """
    if center:
        st = (sf-1)//2
    else:
        st = 0
    return x[..., st::sf, st::sf]
501
+
502
+
503
def circular_pad(x, pad):
    '''
    # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
    '''
    # First append the leading `pad` entries of dims 2 and 3 ...
    for axis in (2, 3):
        head = [slice(None)] * 4
        head[axis] = slice(0, pad)
        x = torch.cat([x, x[tuple(head)]], dim=axis)
    # ... then prepend what were the trailing `pad` entries before appending.
    for axis in (2, 3):
        tail = [slice(None)] * 4
        tail[axis] = slice(-2 * pad, -pad)
        x = torch.cat([x[tuple(tail)], x], dim=axis)
    return x
512
+
513
+
514
def pad_circular(input, padding):
    # type: (Tensor, List[int]) -> Tensor
    """
    Circularly pad every dim of ``input`` after the first two.

    Arguments
    :param input: tensor of shape (N, C_in, H, [W, D])
    :param padding: sequence with one padding amount per padded dim, in
                    dim order
    Returns
    :return: tensor with each padded dim enlarged by twice its padding
    """
    offset = 3
    padded = input
    n_padded_dims = input.dim() - offset + 1
    for dimension in range(n_padded_dims):
        # dim_pad_circular takes a 1-based dimension number.
        padded = dim_pad_circular(padded, padding[dimension], dimension + offset)
    return padded
528
+
529
+
530
def dim_pad_circular(input, padding, dimension):
    # type: (Tensor, int, int) -> Tensor
    """Circularly pad one dim of ``input`` by ``padding`` on both sides.

    Args:
        input: tensor to pad.
        padding: number of entries added at each end.
        dimension: 1-based number of the dim to pad (dim ``dimension - 1``).

    Returns:
        Tensor whose padded dim is ``2 * padding`` longer.
    """
    axis = dimension - 1
    lead = [slice(None)] * axis
    # Index with explicit tuples; indexing with a list of slices is the
    # legacy non-tuple sequence form.
    head = tuple(lead + [slice(0, padding)])
    input = torch.cat([input, input[head]], dim=axis)
    # After appending, the original trailing entries sit at [-2p, -p).
    tail = tuple(lead + [slice(-2 * padding, -padding)])
    input = torch.cat([input[tail], input], dim=axis)
    return input
537
+
538
+
539
def imfilter(x, k):
    '''
    Circularly pad x by (kernel size - 1)//2 per spatial dim, then apply a
    grouped conv2d with one group per channel.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    '''
    pad_rows = (k.shape[-2]-1)//2
    pad_cols = (k.shape[-1]-1)//2
    padded = pad_circular(x, padding=(pad_rows, pad_cols))
    return torch.nn.functional.conv2d(padded, k, groups=padded.shape[1])
547
+
548
+
549
def G(x, k, sf=3, center=False):
    '''
    Blur with imfilter, then downsample.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: the first one or the middle one

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    blurred = imfilter(x, k)
    return downsample(blurred, sf=sf, center=center)
562
+
563
+
564
def Gt(x, k, sf=3, center=False):
    '''
    Upsample with zero insertion, then blur with imfilter.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: the first one or the middle one

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    expanded = upsample(x, sf=sf, center=center)
    return imfilter(expanded, k)
577
+
578
+
579
def interpolation_down(x, sf, center=False):
    """Sample every sf-th element of the last two dims and build its mask.

    Sampling starts at index (sf-1)//2 when center is true, otherwise at 0.

    Returns:
        LR: the sampled sub-grid of x.
        y: x multiplied by the mask (zeros everywhere except sampled spots).
        mask: tensor like x, 1 at sampled positions and 0 elsewhere.
    """
    start = (sf-1)//2 if center else 0
    mask = torch.zeros_like(x)
    mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
    LR = x[..., start::sf, start::sf]
    y = x.mul(mask)
    return LR, y, mask
591
+
592
+
593
+ '''
594
+ # =================
595
+ Numpy
596
+ # =================
597
+ '''
598
+
599
+
600
def blockproc(im, blocksize, fun):
    """Apply ``fun`` to each block of ``im`` and reassemble the results.

    ``im`` is cut every ``blocksize[0]`` rows and every ``blocksize[1]``
    columns; processed blocks are concatenated along axis 1 within a band
    and the bands along axis 0.
    """
    row_cuts = range(blocksize[0], im.shape[0], blocksize[0])
    col_cuts = range(blocksize[1], im.shape[1], blocksize[1])
    bands = [
        np.concatenate([fun(tile) for tile in np.split(band, col_cuts, axis=1)], axis=1)
        for band in np.split(im, row_cuts, axis=0)
    ]
    return np.concatenate(bands, axis=0)
614
+
615
+
616
def fun_reshape(a):
    """Reshape ``a`` (column-major order) to (-1, 1, a.shape[-1])."""
    target = (-1, 1, a.shape[-1])
    return np.reshape(a, target, order='F')
618
+
619
+
620
def fun_mul(a, b):
    """Return the product ``a * b``."""
    product = a * b
    return product
622
+
623
+
624
def BlockMM(nr, nc, Nb, m, x1):
    '''
    Numpy port of the MATLAB routine below: sum the Nb blocks of size
    nr x nc of x1, keeping the last axis.

    myfun = @(block_struct) reshape(block_struct.data,m,1);
    x1 = blockproc(x1,[nr nc],myfun);
    x1 = reshape(x1,m,Nb);
    x1 = sum(x1,2);
    x = reshape(x1,nr,nc);
    '''
    columns = blockproc(x1, blocksize=(nr, nc), fun=fun_reshape)
    stacked = np.reshape(columns, (m, Nb, columns.shape[-1]), order='F')
    block_sum = np.sum(stacked, 1)
    return np.reshape(block_sum, (nr, nc, block_sum.shape[-1]), order='F')
638
+
639
+
640
def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
    '''
    Numpy port of the MATLAB routine below.

    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    FBR = BlockMM(nr, nc, Nb, m, FB*FR)
    invW = BlockMM(nr, nc, Nb, m, F2B)
    invWBR = FBR/(invW + tau*Nb)

    def _scale_block(block):
        # Multiply each nr x nc block of FBC by invWBR.
        return fun_mul(block, invWBR)

    FCBinvWBR = blockproc(FBC, [nr, nc], _scale_block)
    FX = (FR-FCBinvWBR)/tau
    return np.real(np.fft.ifft2(FX, axes=(0, 1)))
659
+
660
+
661
def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.
    Compute the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in OUTSIZE, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches (1,1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : int
        Output shape of the OTF array
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if shape is None:
        shape = psf.shape
    shape = np.array(shape)
    # An all-zero PSF maps to an all-zero OTF of the requested shape.
    if np.all(psf == 0):
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    original_shape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    # Circularly shift so the PSF centre lands on index 0 of every axis.
    for axis, axis_size in enumerate(original_shape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Rough FFT operation count; used as the tolerance (in machine
    # epsilons) for dropping a negligible imaginary part.
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    return np.real_if_close(otf, tol=n_ops)
708
+
709
+
710
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros
    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image. The input object itself is returned when it
        already has the requested shape.
    Raises
    ------
    ValueError
        If `shape` has a non-positive entry, is smaller than the image, or
        (for 'center') differs from the image shape by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    # np.alltrue was removed in NumPy 2.0; np.all is the supported spelling.
    if np.all(imshape == shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    idx, idy = np.indices(imshape)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[idx + offx, idy + offy] = image
    return pad_img
750
+
751
+
752
def upsample_np(x, sf=3, center=False):
    """Upsample by zero insertion along the first two axes.

    Allocates a float zero array of shape
    ``(x.shape[0]*sf, x.shape[1]*sf, x.shape[2])`` and writes ``x`` onto
    every ``sf``-th position, starting at ``(sf-1)//2`` when ``center`` is
    set and at 0 otherwise.
    """
    offset = (sf - 1) // 2 if center else 0
    rows, cols, channels = x.shape[0], x.shape[1], x.shape[2]
    upsampled = np.zeros((rows * sf, cols * sf, channels))
    upsampled[offset::sf, offset::sf, ...] = x
    return upsampled
757
+
758
+
759
def downsample_np(x, sf=3, center=False):
    """Keep every ``sf``-th sample along the first two axes of ``x``.

    Sampling starts at index ``(sf-1)//2`` when ``center`` is set and at 0
    otherwise; remaining axes are left untouched.
    """
    offset = (sf - 1) // 2 if center else 0
    stride = slice(offset, None, sf)
    return x[stride, stride, ...]
762
+
763
+
764
def imfilter_np(x, k):
    '''
    Circular ('wrap' boundary) convolution of an image with a 2-D kernel.

    x: image; the kernel gets a trailing singleton axis, so x is expected
       to be 3-dimensional (presumably HxWxC - confirm against callers)
    k: kernel, hxw; the same kernel is applied to every slice along axis 2

    Returns the filtered array, same shape as x.
    '''
    # ndimage.convolve is the public entry point; the legacy
    # scipy.ndimage.filters namespace is deprecated.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
771
+
772
+
773
def G_np(x, k, sf=3, center=False):
    '''
    Blur with circular boundary handling, then downsample by sf.

    x: image (the docstring of the original port says NxcxHxW, but
       imfilter_np/downsample_np operate on the first two axes -
       presumably HxWxC, TODO confirm)
    k: 2-D blur kernel

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    blurred = imfilter_np(x, k)
    return downsample_np(blurred, sf=sf, center=center)
784
+
785
+
786
def Gt_np(x, k, sf=3, center=False):
    '''
    Zero-insertion upsample by sf, then blur with circular boundary handling.

    x: image (the docstring of the original port says NxcxHxW, but
       upsample_np reads x.shape[0..2] as a 3-D array - presumably
       HxWxC, TODO confirm)
    k: 2-D blur kernel

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    upsampled = upsample_np(x, sf=sf, center=center)
    return imfilter_np(upsampled, k)
797
+
798
+
799
if __name__ == '__main__':
    # Manual smoke test: loads 'test.bmp' from the working directory, builds
    # a few blur kernels and prints the output shape of each degradation.
    # NOTE(review): the loop extent below is reconstructed from a view that
    # lost indentation - confirm against the upstream file.
    img = util.imread_uint('test.bmp', 3)

    img = util.uint2single(img)
    k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
    util.imshow(k*10)


    for sf in [2, 3, 4]:

        # modcrop (img is re-cropped cumulatively for each scale factor)
        img = modcrop_np(img, sf=sf)

        # 1) bicubic degradation
        img_b = bicubic_degradation(img, sf=sf)
        print(img_b.shape)

        # 2) srmd degradation
        img_s = srmd_degradation(img, k, sf=sf)
        print(img_s.shape)

        # 3) dpsr degradation
        img_d = dpsr_degradation(img, k, sf=sf)
        print(img_d.shape)

        # 4) classical degradation
        img_d = classical_degradation(img, k, sf=sf)
        print(img_d.shape)

    # Near-degenerate kernel (very small l1/l2); result is only assigned.
    k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
    #print(k)
    # util.imshow(k*10)

    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
    # util.imshow(k*10)


    # PCA
    # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
    # print(pca_matrix.shape)
    # show_pca(pca_matrix)
    # run utils/utils_sisr.py
    # run utils_sisr.py
842
+
843
+
844
+
845
+
846
+
847
+
848
+
core/data/deg_kair_utils/utils_video.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import random
6
+ from os import path as osp
7
+ from torch.nn import functional as F
8
+ from abc import ABCMeta, abstractmethod
9
+
10
+
11
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with '.' are skipped. Sub-directories are only
    visited when ``recursive`` is set.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files, with paths relative to
        ``dir_path`` unless ``full_path`` is set.

    Raises:
        TypeError: If ``suffix`` is neither None, a string nor a tuple.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None or return_path.endswith(suffix):
                    yield return_path
            elif recursive and entry.is_dir():
                # Only descend into real directories: hidden files (and
                # other non-directory entries) also reach this branch and
                # must not be handed to os.scandir.
                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
51
+
52
+
53
def _mod_crop(img, scale):
    """Crop ``img`` (H x W [x C]) so height and width are multiples of ``scale``."""
    h, w = img.shape[0], img.shape[1]
    return img[:h - h % scale, :w - w % scale, ...]


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is scanned non-recursively and its files are sorted.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname(bool): Whether return image names. Default False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list (file names without extension);
            only returned when ``return_imgname`` is True.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))
    # cv2.imread yields BGR uint8 data; scale to float32 in [0, 1].
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        # No mod_crop is defined or imported in this module, so the crop is
        # done by the private helper above.
        imgs = [_mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(img_path))[0] for img_path in img_paths]
        return imgs, imgnames
    return imgs
83
+
84
+
85
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW torch tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input images, indexed as
            (height, width, channel).
        bgr2rgb (bool): Swap BGR to RGB for 3-channel images.
        float32 (bool): Cast the resulting tensor to float32.

    Returns:
        list[tensor] | tensor: A list of tensors for list input, otherwise
            a single tensor.
    """

    def _convert(array):
        if array.shape[2] == 3 and bgr2rgb:
            # float64 input is narrowed to float32 before the colour swap.
            if array.dtype == 'float64':
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(array.transpose(2, 0, 1))
        return chw.float() if float32 else chw

    if not isinstance(imgs, list):
        return _convert(imgs)
    return [_convert(item) for item in imgs]
112
+
113
+
114
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].
    The input tensors are not modified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR when ``rgb2bgr`` is set.
        A list is returned only when more than one tensor was given.

    Raises:
        TypeError: If the input is not a tensor / list of tensors, or a
            tensor has an unsupported number of dimensions.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Out-of-place clamp: squeeze/float/detach/cpu may all return views
        # of the caller's tensor, so an in-place clamp_ would silently
        # overwrite the caller's data.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Imported here because this module does not import them at the
            # top; only the batched path needs them.
            import math
            from torchvision.utils import make_grid
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result
171
+
172
+
173
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Randomly flip and/or transpose images (and optional flows).

    One random decision is drawn per transform and applied to every image
    and flow. "Rotation" is realised as a vertical flip and/or a transpose
    of the first two axes. Flips are performed in place through ``cv2.flip``.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. A single
            ndarray is wrapped into a list internally.
        hflip (bool): Allow horizontal flip. Default: True.
        rotation (bool): Allow vertical flip and transpose. Default: True.
        flows (list[ndarray] | ndarray): Flows of dimension (h, w, 2) to be
            augmented consistently; flow components are negated/swapped to
            match the geometric transform. Default: None.
        return_status (bool): Also return the (hflip, vflip, rot90)
            decisions. Ignored when ``flows`` is given. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images, plus flows or the status
            tuple as described above. A one-element result is unwrapped.
    """
    # Order of the random draws matters for reproducibility with a seed.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rotation and random.random() < 0.5
    do_transpose = rotation and random.random() < 0.5

    def _transform_img(image):
        if do_hflip:  # horizontal
            cv2.flip(image, 1, image)
        if do_vflip:  # vertical
            cv2.flip(image, 0, image)
        if do_transpose:
            image = image.transpose(1, 0, 2)
        return image

    def _transform_flow(field):
        if do_hflip:  # horizontal
            cv2.flip(field, 1, field)
            field[:, :, 0] *= -1
        if do_vflip:  # vertical
            cv2.flip(field, 0, field)
            field[:, :, 1] *= -1
        if do_transpose:
            field = field.transpose(1, 0, 2)
            field = field[:, :, [1, 0]]
        return field

    img_list = imgs if isinstance(imgs, list) else [imgs]
    out_imgs = [_transform_img(item) for item in img_list]
    if len(out_imgs) == 1:
        out_imgs = out_imgs[0]

    if flows is None:
        if return_status:
            return out_imgs, (do_hflip, do_vflip, do_transpose)
        return out_imgs

    flow_list = flows if isinstance(flows, list) else [flows]
    out_flows = [_transform_flow(item) for item in flow_list]
    if len(out_flows) == 1:
        out_flows = out_flows[0]
    return out_imgs, out_flows
238
+
239
+
240
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth; only used in the error
            message. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If GT is not exactly ``scale`` times the LQ size, or the
            LQ image is smaller than the LQ patch.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # Adjacent f-strings (no comma) so the exception carries one message.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch (tensors are indexed as 4-D: leading two axes are kept)
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
306
+
307
+
308
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
309
class BaseStorageBackend(metaclass=ABCMeta):
    """Interface every storage backend has to provide.

    Concrete backends implement ``get()`` (read a file as a byte stream)
    and ``get_text()`` (read a file as text).
    """

    @abstractmethod
    def get(self, filepath):
        """Return the content at ``filepath`` as bytes."""

    @abstractmethod
    def get_text(self, filepath):
        """Return the content at ``filepath`` as text."""
324
+
325
+
326
class MemcachedBackend(BaseStorageBackend):
    """Storage backend reading through a memcached client (``mc`` module).

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`
            before importing ``mc``. Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(server_list_cfg, client_cfg)
        # mc.pyvector serves as a pointer to a memory cache buffer
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch ``filepath`` via the client and return the converted buffer."""
        key = str(filepath)
        import mc
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        """Not supported by this backend."""
        raise NotImplementedError
360
+
361
+
362
class HardDiskBackend(BaseStorageBackend):
    """Storage backend that reads directly from the local file system."""

    def get(self, filepath):
        """Return the raw bytes of the file at ``filepath``."""
        with open(str(filepath), 'rb') as stream:
            return stream.read()

    def get_text(self, filepath):
        """Return the file at ``filepath`` read in text mode (default encoding)."""
        with open(str(filepath), 'r') as stream:
            return stream.read()
376
+
377
+
378
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | Path | list | tuple): Lmdb database path(s). Every
            entry is converted with ``str()``.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.
        **kwargs: Forwarded to ``lmdb.open``.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (dict): Maps each client key to its lmdb env.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        # Previously only list/str were handled, so e.g. a pathlib.Path left
        # self.db_paths unset and the assert below raised AttributeError.
        if isinstance(db_paths, (list, tuple)):
            self.db_paths = [str(v) for v in db_paths]
        else:
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key
                (encoded as ASCII for the lookup).
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        """Not supported by this backend."""
        raise NotImplementedError
434
+
435
+
436
class FileClient(object):
    """Front end that dispatches file reads to a named storage backend.

    The backend class is looked up by name in ``_backends`` and constructed
    with the remaining keyword arguments.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        backend_cls = self._backends[backend]
        self.backend = backend
        self.client = backend_cls(**kwargs)

    def get(self, filepath, client_key='default'):
        """Read ``filepath`` as bytes through the backend.

        ``client_key`` is only forwarded to the lmdb backend, where it
        selects one of several lmdb environments.
        """
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Read ``filepath`` as text through the backend."""
        return self.client.get_text(filepath)
472
+
473
+
474
def imfrombytes(content, flag='color', float32=False):
    """Decode an encoded image held in memory.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Colour mode of the decoded image; one of `color`,
            `grayscale` and `unchanged`.
        float32 (bool): If True, convert to float32 and divide by 255.
            Default: False.

    Returns:
        ndarray: Decoded image array as produced by ``cv2.imdecode``.
    """
    decode_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    raw = np.frombuffer(content, np.uint8)
    decoded = cv2.imdecode(raw, decode_flags[flag])
    if not float32:
        return decoded
    return decoded.astype(np.float32) / 255.
493
+
core/data/deg_kair_utils/utils_videoio.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import random
6
+ from os import path as osp
7
+ from torchvision.utils import make_grid
8
+ import sys
9
+ from pathlib import Path
10
+ import six
11
+ from collections import OrderedDict
12
+ import math
13
+ import glob
14
+ import av
15
+ import io
16
+ from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
17
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
18
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
19
+
20
# Interpreters up to 3.3 get IOError as a stand-in; newer ones keep the
# builtin FileNotFoundError.
FileNotFoundError = IOError if sys.version_info <= (3, 3) else FileNotFoundError
24
+
25
+
26
def is_str(x):
    """Whether the input is a string instance.

    This module already relies on Python 3 syntax (f-strings), where
    ``six.string_types`` is just ``(str,)``, so the builtin is checked
    directly.
    """
    return isinstance(x, str)
29
+
30
+
31
def is_filepath(x):
    """Whether the input is a string or a ``pathlib.Path`` instance."""
    if is_str(x):
        return True
    return isinstance(x, Path)
33
+
34
+
35
def fopen(filepath, *args, **kwargs):
    """Open ``filepath`` given either as a string or as a ``Path``.

    Extra positional and keyword arguments are forwarded to ``open`` /
    ``Path.open``. Raises ValueError for any other type of ``filepath``.
    """
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    if isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')
41
+
42
+
43
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise FileNotFoundError unless ``filename`` is an existing regular file.

    ``msg_tmpl`` is formatted with ``filename`` to build the error message.
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
46
+
47
+
48
def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` (with parents) if it does not exist yet.

    An empty string is a no-op; a leading ``~`` is expanded first.
    """
    if dir_name != '':
        target = osp.expanduser(dir_name)
        os.makedirs(target, mode=mode, exist_ok=True)
53
+
54
+
55
def symlink(src, dst, overwrite=True, **kwargs):
    """Create a symbolic link ``dst`` pointing to ``src``.

    When ``overwrite`` is set and ``dst`` already exists (including as a
    dangling link), it is removed first. ``kwargs`` go to ``os.symlink``.
    """
    replace_existing = os.path.lexists(dst) and overwrite
    if replace_existing:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)
59
+
60
+
61
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Lazily list the files below a directory.

    Files whose name starts with '.' are skipped.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): Only yield files whose relative
            path ends with this suffix (or one of them). Default: None.
        recursive (bool, optional): Descend into sub-directories.
            Default: False.
        case_sensitive (bool, optional): If False, the suffix comparison
            ignores case. Default: True.
    Returns:
        A generator yielding the matching files as paths relative to
        ``dir_path``.
    Raises:
        TypeError: For a ``dir_path`` or ``suffix`` of unsupported type.
    """
    if not isinstance(dir_path, (str, Path)):
        raise TypeError('"dir_path" must be a string or Path object')
    dir_path = str(dir_path)

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        if isinstance(suffix, str):
            suffix = suffix.lower()
        else:
            suffix = tuple(item.lower() for item in suffix)

    root = dir_path

    def _walk(current_dir):
        for entry in os.scandir(current_dir):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                # Compare on a lower-cased copy, but yield the real path.
                candidate = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or candidate.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                yield from _walk(entry.path)

    return _walk(dir_path)
101
+
102
+
103
class Cache:
    """Bounded key/value store that evicts the oldest inserted entry.

    ``get`` does not refresh an entry, so eviction is first-in-first-out.
    The class deliberately defines no ``__len__``/``__bool__``: instances
    are always truthy, even when empty.
    """

    def __init__(self, capacity):
        """Create a cache holding at most ``int(capacity)`` entries.

        Raises:
            ValueError: If the capacity, after conversion to int, is not
                positive (e.g. 0, -1 or 0.5).
        """
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        # Validate the truncated value: a fractional capacity such as 0.5
        # would otherwise pass the check and fail on the first put().
        if self._capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        """int: Maximum number of entries."""
        return self._capacity

    @property
    def size(self):
        """int: Current number of entries."""
        return len(self._cache)

    def put(self, key, val):
        """Store ``val`` under ``key``; an existing key is left untouched."""
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            # Drop the oldest inserted entry to make room.
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        val = self._cache[key] if key in self._cache else default
        return val
129
+
130
+
131
+ class VideoReader:
132
+ """Video class with similar usage to a list object.
133
+
134
+ This video warpper class provides convenient apis to access frames.
135
+ There exists an issue of OpenCV's VideoCapture class that jumping to a
136
+ certain frame may be inaccurate. It is fixed in this class by checking
137
+ the position after jumping each time.
138
+ Cache is used when decoding videos. So if the same frame is visited for
139
+ the second time, there is no need to decode again if it is stored in the
140
+ cache.
141
+
142
+ """
143
+
144
+ def __init__(self, filename, cache_capacity=10):
145
+ # Check whether the video path is a url
146
+ if not filename.startswith(('https://', 'http://')):
147
+ check_file_exist(filename, 'Video file not found: ' + filename)
148
+ self._vcap = cv2.VideoCapture(filename)
149
+ assert cache_capacity > 0
150
+ self._cache = Cache(cache_capacity)
151
+ self._position = 0
152
+ # get basic info
153
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
154
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
155
+ self._fps = self._vcap.get(CAP_PROP_FPS)
156
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
157
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
158
+
159
+ @property
160
+ def vcap(self):
161
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
162
+ return self._vcap
163
+
164
+ @property
165
+ def opened(self):
166
+ """bool: Indicate whether the video is opened."""
167
+ return self._vcap.isOpened()
168
+
169
+ @property
170
+ def width(self):
171
+ """int: Width of video frames."""
172
+ return self._width
173
+
174
+ @property
175
+ def height(self):
176
+ """int: Height of video frames."""
177
+ return self._height
178
+
179
+ @property
180
+ def resolution(self):
181
+ """tuple: Video resolution (width, height)."""
182
+ return (self._width, self._height)
183
+
184
+ @property
185
+ def fps(self):
186
+ """float: FPS of the video."""
187
+ return self._fps
188
+
189
+ @property
190
+ def frame_cnt(self):
191
+ """int: Total frames of the video."""
192
+ return self._frame_cnt
193
+
194
+ @property
195
+ def fourcc(self):
196
+ """str: "Four character code" of the video."""
197
+ return self._fourcc
198
+
199
+ @property
200
+ def position(self):
201
+ """int: Current cursor position, indicating frame decoded."""
202
+ return self._position
203
+
204
+ def _get_real_position(self):
205
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
206
+
207
+ def _set_real_position(self, frame_id):
208
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
209
+ pos = self._get_real_position()
210
+ for _ in range(frame_id - pos):
211
+ self._vcap.read()
212
+ self._position = frame_id
213
+
214
def read(self):
    """Return the frame at the cursor and advance the cursor by one.

    A frame already held by the cache is served from there. Otherwise the
    capture is re-synchronised with the cursor if needed, the frame is
    decoded and, when decoding succeeds, stored in the cache.

    Returns:
        ndarray or None: The frame, or None when decoding failed (the
        cursor is left untouched in that case).
    """
    cache = self._cache
    if not cache:
        ok, frame = self._vcap.read()
    else:
        frame = cache.get(self._position)
        ok = frame is not None
        if not ok:
            # cache miss: make sure the capture sits on the wanted frame
            if self._get_real_position() != self._position:
                self._set_real_position(self._position)
            ok, frame = self._vcap.read()
            if ok:
                cache.put(self._position, frame)
    if ok:
        self._position += 1
    return frame
239
+
240
def get_frame(self, frame_id):
    """Fetch one frame by its 0-based index and leave the cursor after it.

    Args:
        frame_id (int): Index of the wanted frame.

    Returns:
        ndarray or None: The frame, or None when decoding failed.

    Raises:
        IndexError: If ``frame_id`` lies outside ``[0, frame count)``.
    """
    total = self._frame_cnt
    if not 0 <= frame_id < total:
        raise IndexError(
            f'"frame_id" must be between 0 and {total - 1}')
    if frame_id == self._position:
        # sequential access: the regular read path already does the work
        return self.read()
    cache = self._cache
    if cache:
        cached = cache.get(frame_id)
        if cached is not None:
            self._position = frame_id + 1
            return cached
    self._set_real_position(frame_id)
    ok, frame = self._vcap.read()
    if not ok:
        return frame
    if cache:
        cache.put(self._position, frame)
    self._position += 1
    return frame
266
+
267
def current_frame(self):
    """Return the most recently visited frame from the cache.

    Returns:
        ndarray or None: None while the cursor is still at 0; otherwise
        whatever the cache holds for the frame just before the cursor.
    """
    return self._cache.get(self._position - 1) if self._position != 0 else None
277
+
278
def cvt2frames(self,
               frame_dir,
               file_start=0,
               filename_tmpl='{:06d}.jpg',
               start=0,
               max_num=0,
               show_progress=False):
    """Convert a video to frame images.

    Args:
        frame_dir (str): Output directory to store all the frame images.
        file_start (int): Filenames will start from the specified number.
        filename_tmpl (str): Filename template with the index as the
            placeholder.
        start (int): The starting frame index.
        max_num (int): Maximum number of frames to be written; 0 means
            "all frames from ``start`` on".
        show_progress (bool): Kept for API compatibility. No progress bar
            is drawn here; frames are written whatever its value.

    Raises:
        ValueError: If ``start`` is not smaller than the frame count.
    """
    mkdir_or_exist(frame_dir)
    available = self.frame_cnt - start
    task_num = available if max_num == 0 else min(available, max_num)
    if task_num <= 0:
        raise ValueError('start must be less than total frame number')
    if start > 0:
        self._set_real_position(start)

    # Previously `show_progress=True` fell into a bare `pass`, so the call
    # returned without writing a single frame. The same loop now runs for
    # both values of the flag.
    for offset in range(task_num):
        img = self.read()
        if img is None:
            # nothing decoded for this slot; skip the file, keep the index
            continue
        filename = osp.join(frame_dir,
                            filename_tmpl.format(file_start + offset))
        cv2.imwrite(filename, img)
319
+
320
def __len__(self):
    """Return the number of frames reported for the video."""
    total = self.frame_cnt
    return total
322
+
323
def __getitem__(self, index):
    """Index or slice the video like a sequence of frames.

    Slices yield a list of frames; integers may be negative and then count
    from the end. Range checking of non-negative indices is left to
    ``get_frame``.

    Raises:
        IndexError: If a negative index reaches before the first frame.
    """
    if isinstance(index, slice):
        picked = []
        for frame_id in range(*index.indices(self.frame_cnt)):
            picked.append(self.get_frame(frame_id))
        return picked
    # support negative indexing
    if index < 0:
        index = index + self.frame_cnt
        if index < 0:
            raise IndexError('index out of range')
    return self.get_frame(index)
335
+
336
def __iter__(self):
    """Rewind to frame 0 and hand back the reader itself as the iterator."""
    first_frame = 0
    self._set_real_position(first_frame)
    return self
339
+
340
def __next__(self):
    """Return the next frame, ending iteration once ``read`` yields None."""
    frame = self.read()
    if frame is None:
        raise StopIteration
    return frame


# alias kept so the reader can also be advanced through a ``next`` method
next = __next__
348
+
349
def __enter__(self):
    """Enter a ``with`` block; the reader itself is the managed object."""
    reader = self
    return reader
351
+
352
def __exit__(self, exc_type, exc_value, traceback):
    """Release the capture on leaving a ``with`` block.

    Nothing is returned, so an exception raised inside the block propagates.
    """
    capture = self._vcap
    capture.release()
354
+
355
+
356
def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=False):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video, this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index (exclusive). 0 means "the number of
            files in ``frame_dir`` with the template's extension".
        show_progress (bool): Kept for API compatibility. No progress bar
            is drawn here; frames are written whatever its value.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = sum(1 for _ in scandir(frame_dir, ext))
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    check_file_exist(first_file, 'The start frame not found: ' + first_file)
    # the first frame fixes the resolution of the whole video
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)
    try:
        # Previously `show_progress=True` fell into a bare `pass` and an
        # empty video was produced. The same loop now runs for both values.
        for file_idx in range(start, end):
            filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
            vwriter.write(cv2.imread(filename))
    finally:
        # finalise the container even when a frame fails mid-way
        vwriter.release()
400
+
401
+
402
def video2images(video_path, output_dir):
    """Dump the frames of a video as zero-padded PNG files.

    Frames are written as ``000000.png``, ``000001.png``, ... The frame rate
    and reported frame count are printed, plus a progress line every 100
    frames.

    Args:
        video_path (str): Video opened with ``cv2.VideoCapture``.
        output_dir (str): Destination directory; created when missing.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        in_fps = vidcap.get(cv2.CAP_PROP_FPS)
        print('video fps:', in_fps)
        os.makedirs(output_dir, exist_ok=True)
        loaded, frame = vidcap.read()
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f'number of total frames is: {total_frames:06}')
        for i_frame in range(total_frames):
            if not loaded:
                # The reported count is only what the container claims. Stop
                # once decoding fails instead of passing an empty frame to
                # cv2.imwrite.
                break
            if i_frame % 100 == 0:
                print(f'{i_frame:06} / {total_frames:06}')
            frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
            cv2.imwrite(frame_name, frame)
            loaded, frame = vidcap.read()
    finally:
        # the capture was never released before
        vidcap.release()
417
+
418
+
419
def images2video(image_dir, video_path, fps=24, image_ext='png'):
    """Join the images of a directory, in sorted filename order, into a video.

    The output is always encoded with the ``MJPG`` fourcc. The first image
    fixes the resolution and every image is resized to it. The number of
    images found is printed.

    Args:
        image_dir (str): Directory searched for ``*.<image_ext>`` files.
        video_path (str): Output video filename.
        fps (float): Frame rate of the output video.
        image_ext (str): Image extension without the leading dot.

    Raises:
        IndexError: If no matching image exists in ``image_dir``.

    Notes:
        Other fourcc codes that were tried or noted for this function:
        ``XVID``, ``AVC1``, ``YUV1``, ``PIM1``, ``MP42``, ``DIV3``, ``DIVX``,
        ``U263``, ``I263``, ``FLV1``, ``H264``, ``AYUV``, ``IUYV``, ``MP4V``.

        Commonly used ones:
            * ``I420`` - YUV encoding with 4:2:0 chroma subsampling; good
              compatibility but produces very large files (``.avi``).
            * ``PIM1`` - MPEG-1 encoding (``.avi``).
            * ``XVID`` - MPEG-4 encoding, moderate file size (``.avi``).
            * ``THEO`` - Ogg Vorbis (``.ogv``).
            * ``FLV1`` - Flash video (``.flv``).
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
    print(len(image_files))
    if not image_files:
        # same exception type as the former bare `image_files[0]`, but it
        # now says what is wrong
        raise IndexError(
            'no *.{} images found in {}'.format(image_ext, image_dir))
    height, width, _ = cv2.imread(image_files[0]).shape
    out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')  # cv2.VideoWriter_fourcc(*'MP4V')
    out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
    try:
        for image_file in image_files:
            img = cv2.imread(image_file)
            # cv2.INTER_AREA is the named form of the former literal 3
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
            out_video.write(img)
    finally:
        # finalise the container even when an image fails to load or resize
        out_video.release()
458
+
459
+
460
def add_video_compression(imgs):
    """Degrade a frame sequence by a round trip through a video codec.

    A codec is drawn uniformly from ``libx264``/``h264``/``mpeg4`` and a bit
    rate uniformly from ``[1e4, 1e5]``. The frames are encoded into an
    in-memory mp4 (rate 1, ``yuv420p``) and decoded again.

    Args:
        imgs: Sequence of frames; values are clipped to ``[0, 1]`` and
            converted to 8-bit RGB before encoding.
            # assumes each frame is an (H, W, 3) float array - TODO confirm

    Returns:
        list: Decoded frames as float32 RGB arrays scaled to ``[0, 1]``.
    """
    codec_names = ['libx264', 'h264', 'mpeg4']
    codec_weights = [1 / 3., 1 / 3., 1 / 3.]
    codec = random.choices(codec_names, codec_weights)[0]
    bitrate_low, bitrate_high = 1e4, 1e5
    bitrate = np.random.randint(bitrate_low, bitrate_high + 1)

    buf = io.BytesIO()
    with av.open(buf, 'w', 'mp4') as container:
        stream = container.add_stream(codec, rate=1)
        first_height, first_width = imgs[0].shape[0], imgs[0].shape[1]
        stream.height = first_height
        stream.width = first_width
        stream.pix_fmt = 'yuv420p'
        stream.bit_rate = bitrate

        for img in imgs:
            img_uint8 = np.uint8((img.clip(0, 1)*255.).round())
            frame = av.VideoFrame.from_ndarray(img_uint8, format='rgb24')
            frame.pict_type = 'NONE'
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)

    outputs = []
    with av.open(buf, 'r', 'mp4') as container:
        if container.streams.video:
            outputs.extend(
                decoded.to_rgb().to_ndarray().astype(np.float32) / 255.
                for decoded in container.decode(video=0))

    return outputs
497
+
498
+
499
if __name__ == '__main__':
    # Every manual smoke test below is commented out. Without a statement the
    # `if` has an empty body and the module cannot even be imported.
    pass
500
+
501
+ # -----------------------------------
502
+ # test VideoReader(filename, cache_capacity=10)
503
+ # -----------------------------------
504
+ # video_reader = VideoReader('utils/test.mp4')
505
+ # from utils import utils_image as util
506
+ # inputs = []
507
+ # for frame in video_reader:
508
+ # print(frame.dtype)
509
+ # util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
510
+ # #util.imshow(np.flip(frame, axis=2))
511
+
512
+ # -----------------------------------
513
+ # test video2images(video_path, output_dir)
514
+ # -----------------------------------
515
+ # video2images('utils/test.mp4', 'frames')
516
+
517
+ # -----------------------------------
518
+ # test images2video(image_dir, video_path, fps=24, image_ext='png')
519
+ # -----------------------------------
520
+ # images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
521
+
522
+
523
+ # -----------------------------------
524
+ # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
525
+ # -----------------------------------
526
+ # frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
527
+
528
+
529
+ # -----------------------------------
530
+ # test add_video_compression(imgs)
531
+ # -----------------------------------
532
+ # imgs = []
533
+ # image_ext = 'png'
534
+ # frames = 'frames'
535
+ # from utils import utils_image as util
536
+ # image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
537
+ # for i, image_file in enumerate(image_files):
538
+ # if i < 7:
539
+ # img = util.imread_uint(image_file, 3)
540
+ # img = util.uint2single(img)
541
+ # imgs.append(img)
542
+ #
543
+ # results = add_video_compression(imgs)
544
+ # for i, img in enumerate(results):
545
+ # util.imshow(util.single2uint(img))
546
+ # util.imsave(util.single2uint(img),f'{i:05}.png')
547
+
548
+ # run utils/utils_video.py
549
+
550
+
551
+
552
+
553
+
554
+
555
+
core/scripts/__init__.py ADDED
File without changes
core/scripts/cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ from .. import WarpCore
4
+ from .. import templates
5
+
6
+
7
def template_init(args):
    """Return the source text of a freshly initialised template.

    The template body is still empty. The previous literal opened with four
    quotes (``''''``), so a stray ``'`` was returned instead of an empty
    string.

    Args:
        args: Accepted but not used.

    Returns:
        str: The (currently empty) template text.
    """
    return ''
12
+
13
+
14
def init_template(args):
    """Resolve the template named on the command line and print it.

    ``-t/--template`` (default ``WarpCore``) is resolved in this order: the
    literal name ``WarpCore``, then an importable module of that name, then
    an attribute of the ``templates`` package.

    Args:
        args (list[str]): Command-line arguments after the ``init`` command.

    Raises:
        AttributeError: If the name is neither importable nor an attribute
            of ``templates``.
    """
    # local import keeps this block self-contained
    import importlib

    parser = argparse.ArgumentParser(description='WarpCore template init tool')
    parser.add_argument('-t', '--template', type=str, default='WarpCore')
    options = parser.parse_args(args)

    template_name = options.template
    if template_name == 'WarpCore':
        template_cls = WarpCore
    else:
        try:
            # __import__("a.b") returns the top-level package "a";
            # import_module returns the module that was actually named.
            template_cls = importlib.import_module(template_name)
        except ModuleNotFoundError:
            template_cls = getattr(templates, template_name)
    print(template_cls)
27
+
28
+
29
def main():
    """Dispatch ``core <command> [args...]`` to the matching handler.

    Only ``init`` is known. A missing or unknown command prints a short
    message and exits with status 1.
    """
    argv = sys.argv
    if len(argv) < 2:
        print('Usage: core <command>')
        sys.exit(1)
    command, remaining = argv[1], argv[2:]
    if command != 'init':
        print('Unknown command')
        sys.exit(1)
    init_template(remaining)


if __name__ == '__main__':
    main()
core/templates/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .diffusion import DiffusionCore
core/templates/diffusion.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import WarpCore
2
+ from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ import torch
6
+ from torch import nn
7
+ from torch.utils.data import DataLoader
8
+ from gdf import GDF
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import wandb
12
+
13
+ import webdataset as wds
14
+ from webdataset.handlers import warn_and_continue
15
+ from torch.distributed import barrier
16
+ from enum import Enum
17
+
18
class TargetReparametrization(Enum):
    """Quantity the training loss is computed against.

    ``EPSILON`` compares against the noise, ``X0`` against the clean latents.
    """

    EPSILON = 'epsilon'
    X0 = 'x0'
21
+
22
class DiffusionCore(WarpCore):
    """Training template for diffusion models on top of ``WarpCore``.

    Subclasses provide the data and conditioning hooks (the abstract methods
    below). This class wires them into a WebDataset pipeline, a GDF-based
    forward pass and a gradient-accumulating training loop with optional EMA
    weights, wandb logging and periodic checkpoints.
    """

    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        updates: int = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500
        backup_every: int = 20000
        use_fsdp: bool = True

        # EMA UPDATE
        ema_start_iters: int = None
        ema_iters: int = None
        ema_beta: float = None

        # GDF setting
        gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0

    @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
    class Info(WarpCore.Info):
        ema_loss: float = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator : nn.Module = EXPECTED
        generator_ema : nn.Module = None # optional

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator : any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED

    # --------------------------------------------
    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Turn a data batch into the latents that get diffused."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Map latents back to the data domain."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        """Build the keyword arguments the generator is conditioned on."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        """Return the dataset location handed to ``wds.WebDataset``."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        """Return the sample predicate handed to ``.select``."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        """Return ``(source key, transform, output key)`` triples."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        """Produce samples at checkpoint time."""
        raise NotImplementedError("This method needs to be overriden")
    # -------------

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        """Build the WebDataset pipeline and its DataLoader.

        Each preprocessor triple selects a sample key (index 0), applies a
        transform to it (index 1) and stores the result under an output key
        (index 2) of the per-sample dict.
        """
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

        handler = warn_and_continue # None
        # handler = None
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        # config.batch_size is the global batch per optimizer update; each
        # process sees its share of it on every accumulation step
        real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))

    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        """Run one diffusion training forward pass on the next batch.

        Returns:
            tuple: ``(loss, loss_adjusted)``. ``loss`` is the unweighted
            per-sample MSE; ``loss_adjusted`` is the loss-weighted mean
            divided by ``grad_accum_steps`` and is what gets back-propagated.
        """
        batch = next(data.iterator)

        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        # FORWARD PASS
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
                target = noise
            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
                target = latents
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted

    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        """Run the training loop from ``info.iter + 1`` to the last iteration.

        One iteration is one forward/backward pass. An optimizer update
        happens every ``grad_accum_steps`` iterations and on the final one.
        """
        start_iter = self.info.iter+1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
        models.generator.train()
        # Gradient norm of the latest optimizer update. It stays None until
        # the first update, so accumulation-only iterations no longer hit an
        # unbound name.
        grad_norm = None
        for i in pbar:
            # FORWARD PASS
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # BACKWARD PASS
            if i % self.config.grad_accum_steps == 0 or i == max_iters:
                loss_adjusted.backward()
                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
                # to_dict() may hold entries that cannot be stepped (e.g.
                # Schedulers.generator defaults to None); skip those
                for optimizer in optimizers.to_dict().values():
                    if callable(getattr(optimizer, 'step', None)):
                        optimizer.step()
                for scheduler in schedulers.to_dict().values():
                    if callable(getattr(scheduler, 'step', None)):
                        scheduler.step()
                models.generator.zero_grad(set_to_none=True)
                self.info.total_steps += 1
            else:
                with models.generator.no_sync():
                    loss_adjusted.backward()
            self.info.iter = i

            # UPDATE EMA
            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )

            # UPDATE LOSS METRICS
            raw_loss = loss.mean().item()
            self.info.ema_loss = raw_loss if self.info.ema_loss is None else self.info.ema_loss * 0.99 + raw_loss * 0.01
            grad_norm_value = None if grad_norm is None else grad_norm.item()

            # The NaN test is grouped explicitly. Written as
            # `a and b and c or d`, a NaN gradient alone triggered the alert
            # on every node and with wandb disabled.
            nan_detected = np.isnan(raw_loss) or (grad_norm_value is not None and np.isnan(grad_norm_value))
            if self.is_main_node and self.config.wandb_project is not None and nan_detected:
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
                    text=f"Loss {raw_loss} - Grad Norm {grad_norm_value}. Run {self.info.wandb_run_id}",
                    wait_duration=60*30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss,
                    'raw_loss': raw_loss,
                    'grad_norm': grad_norm_value,
                    'lr': optimizers.generator.param_groups[0]['lr'],
                    'total_steps': self.info.total_steps,
                }

                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)

            if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
                # SAVE AND CHECKPOINT STUFF
                if np.isnan(raw_loss):
                    if self.is_main_node and self.config.wandb_project is not None:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
                else:
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    # NOTE(review): assumed to run on every rank, not only
                    # the main node - confirm against the intended nesting
                    self.sample(models, data, extras)

    def models_to_save(self):
        """Return the ``Models`` field names that get checkpointed."""
        return ['generator', 'generator_ema']

    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        """Save run info, the listed models and all optimizers.

        When called without a suffix on a step that is a multiple of
        ``backup_every``, a second copy is saved with a ``_<steps//1000>k``
        suffix.
        """
        # every process has to reach this point before anything is written
        barrier()
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
        torch.cuda.empty_cache()
core/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
2
+ from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
3
+
4
+ # MOVE IT SOMEWHERE ELSE
5
def update_weights_ema(tgt_model, src_model, beta=0.999):
    """Blend ``src_model`` into ``tgt_model`` as an exponential moving average.

    Parameters and floating-point buffers become
    ``tgt * beta + src * (1 - beta)``. Non-floating buffers (integer counters,
    boolean masks) are copied from the source instead: averaging them would
    silently turn them into float tensors.

    Args:
        tgt_model: Model holding the averaged weights; updated in place.
        src_model: Model providing the current weights.
        beta (float): Weight kept from the target; ``0`` copies the source.
    """
    for tgt_param, src_param in zip(tgt_model.parameters(), src_model.parameters()):
        src_data = src_param.data.to(tgt_param.device)
        tgt_param.data = tgt_param.data * beta + src_data * (1 - beta)
    for tgt_buffer, src_buffer in zip(tgt_model.buffers(), src_model.buffers()):
        src_data = src_buffer.data.to(tgt_buffer.device)
        if tgt_buffer.data.is_floating_point() or tgt_buffer.data.is_complex():
            tgt_buffer.data = tgt_buffer.data * beta + src_data * (1 - beta)
        else:
            # clone so the target never aliases the source's storage
            tgt_buffer.data = src_data.clone()
core/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (763 Bytes). View file
 
core/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (804 Bytes). View file
 
core/utils/__pycache__/base_dto.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
core/utils/__pycache__/base_dto.cpython-39.pyc ADDED
Binary file (3.11 kB). View file
 
core/utils/__pycache__/save_and_load.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
core/utils/__pycache__/save_and_load.cpython-39.pyc ADDED
Binary file (2.2 kB). View file
 
core/utils/base_dto.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from dataclasses import dataclass, _MISSING_TYPE
3
+ from munch import Munch
4
+
5
+ EXPECTED = "___REQUIRED___"
6
+ EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
7
+
8
+ # pylint: disable=invalid-field-call
9
def nested_dto(x, raw=False):
    """Build a dataclass field whose default is produced from ``x``.

    Args:
        x: Default value for the field.
        raw (bool): If true the factory returns ``x`` itself; otherwise it
            returns ``Munch.fromDict(x)``.

    Returns:
        dataclasses.Field: A field using a ``default_factory``.
    """
    def make_default():
        if raw:
            return x
        return Munch.fromDict(x)

    return dataclasses.field(default_factory=make_default)
11
+
12
@dataclass(frozen=True)
class Base:
    """Frozen DTO base class that validates keyword arguments on creation.

    A field can be set from outside only if its default is ``None``,
    ``EXPECTED``, ``EXPECTED_TRAIN`` or missing. Fields defaulting to
    ``EXPECTED`` are mandatory; fields defaulting to ``EXPECTED_TRAIN`` are
    mandatory unless ``training=False`` is passed.
    """

    training: bool = None

    def __new__(cls, **kwargs):
        """Check ``kwargs`` against the declared fields, then allocate.

        Raises:
            AssertionError: If a keyword is not setteable, still carries a
                placeholder value, or a mandatory field is missing.
        """
        training = kwargs.get('training', True)
        setteable_fields = cls.setteable_fields(**kwargs)
        mandatory_fields = cls.mandatory_fields(**kwargs)
        invalid_kwargs = [
            {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
        ]
        # a leftover debug print of `mandatory_fields` used to run here on
        # every instantiation
        assert (
            len(invalid_kwargs) == 0
        ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
        missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
        assert (
            len(missing_kwargs) == 0
        ), f"Required fields missing initializing this DTO: {missing_kwargs}."
        return object.__new__(cls)

    @classmethod
    def setteable_fields(cls, **kwargs):
        """Return the names of the fields callers are allowed to pass."""
        return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]

    @classmethod
    def mandatory_fields(cls, **kwargs):
        """Return the names of the fields callers must pass.

        ``EXPECTED_TRAIN`` fields count as mandatory unless ``kwargs``
        contains ``training=False``.
        """
        training = kwargs.get('training', True)
        return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]

    @classmethod
    def from_dict(cls, kwargs):
        """Create an instance from a plain dict.

        ``dict``/``list``/``tuple`` values are converted with
        ``Munch.fromDict``. The conversion is done on a copy, so the caller's
        dict is left untouched.
        """
        converted = {
            key: Munch.fromDict(value) if isinstance(value, (dict, list, tuple)) else value
            for key, value in kwargs.items()
        }
        return cls(**converted)

    def to_dict(self):
        """Return the fields as a shallow dict, unwrapping ``Munch`` values."""
        # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
        selfdict = {}
        for k in dataclasses.fields(self):
            selfdict[k.name] = getattr(self, k.name)
            if isinstance(selfdict[k.name], Munch):
                selfdict[k.name] = selfdict[k.name].toDict()
        return selfdict
core/utils/save_and_load.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ from pathlib import Path
5
+ import safetensors
6
+ import wandb
7
+
8
+
9
def create_folder_if_necessary(path):
    """Create the directory part of ``path`` (everything before the last "/").

    Missing parents are created as well; an existing directory is fine. A
    path that ends in "/" therefore names the directory to create.
    """
    parent, _, _ = path.rpartition("/")
    Path(parent).mkdir(parents=True, exist_ok=True)
12
+
13
+
14
def safe_save(ckpt, path):
    """Write ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The serialiser is picked from the extension: ``.pt``/``.ckpt`` use
    ``torch.save``, ``.json`` uses ``json.dump`` and ``.safetensors`` uses
    ``safetensors.torch.save_file``.

    Args:
        ckpt: Object to serialise.
        path (str): Destination file.

    Raises:
        ValueError: If the extension is not supported. This is now checked
            before anything on disk is touched; previously the existing file
            had already been renamed to ``.bak`` when the error was raised.
    """
    if not path.endswith((".pt", ".ckpt", ".json", ".safetensors")):
        raise ValueError(f"File extension not supported: {path}")

    # rotate: drop the old backup, then turn the current file into the backup
    backup_path = f"{path}.bak"
    try:
        os.remove(backup_path)
    except OSError:
        pass
    try:
        os.rename(path, backup_path)
    except OSError:
        pass

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # `import safetensors` by itself does not load the `torch` submodule,
        # so it is imported explicitly where it is needed
        from safetensors.torch import save_file
        save_file(ckpt, path)
32
+
33
+
34
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint by extension, or return None if the file is missing.

    ``.pt``/``.ckpt`` go through ``torch.load`` (mapped to CPU), ``.json``
    through ``json.load`` and ``.safetensors`` are read tensor by tensor
    into a dict.

    Args:
        path (str): File to load.
        wandb_run_id: When given, a wandb alert is sent before any failure
            is re-raised.

    Raises:
        AssertionError: If the extension is not one of the supported ones.
        Exception: Whatever the underlying loader raised.
    """
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    try:
        assert path.endswith(accepted_extensions), \
            f"Automatic loading not supported for this extension: {path}"
        if not os.path.exists(path):
            checkpoint = None
        elif path.endswith((".pt", ".ckpt")):
            checkpoint = torch.load(path, map_location="cpu")
        elif path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as handle:
                checkpoint = json.load(handle)
        elif path.endswith(".safetensors"):
            with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
                checkpoint = {name: handle.get_tensor(name) for name in handle.keys()}
        return checkpoint
    except Exception as e:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        raise e
figures/California_000490.jpg ADDED

Git LFS Details

  • SHA256: 84fbae1a942a233e619fcc3ddc89b6038f909df6003bbbb1b10a15309e0ecd2e
  • Pointer size: 132 Bytes
  • Size of remote file: 6.06 MB
figures/example_dataset/000008.jpg ADDED

Git LFS Details

  • SHA256: 763beb950f9794497a2fbdd62a48c8fd91f9543e9b9e6b5cf6c38bfcfdd34a02
  • Pointer size: 132 Bytes
  • Size of remote file: 2.88 MB
figures/example_dataset/000008.json ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ { "caption": "The image captures the iconic Shard, a modern skyscraper that stands as the tallest building in the United Kingdom. The Shard, with its glass and steel structure, pierces the sky, its pointed top reaching towards the heavens. The photograph is taken from a low angle, which emphasizes the height and grandeur of the building. The sky forms a beautiful backdrop, painted in hues of pinkish-orange, suggesting that the photo was taken at sunset. The Shard is nestled between two other buildings, their presence subtly hinted at in the foreground. The image does not contain any discernible text or countable objects, and there are no visible actions taking place. The relative positions of the objects confirm that the Shard is the central focus of the image, with the other buildings and the sky providing context to its location. The image is devoid of any aesthetic descriptions, focusing solely on the factual representation of the scene."
2
+ }
figures/example_dataset/000012.jpg ADDED

Git LFS Details

  • SHA256: 7fea5f3176ed289556acf32c7dd3635db8944cb64be53f109b84554eb4da5bf3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.67 MB
figures/example_dataset/000012.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"caption": "cars in a road during daytime"}
figures/teaser.jpg ADDED

Git LFS Details

  • SHA256: def5700a069d5f754b45ec02802e258c1c1473ad82fd10d2e62cc87e75a8a5e1
  • Pointer size: 132 Bytes
  • Size of remote file: 7.95 MB
gdf/__init__.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .scalers import *
3
+ from .targets import *
4
+ from .schedulers import *
5
+ from .noise_conditions import *
6
+ from .loss_weights import *
7
+ from .samplers import *
8
+ import torch.nn.functional as F
9
+ import math
10
class GDF():
    """Generic Diffusion Framework: wires together the pluggable diffusion parts.

    The instance only stores the components it is given:

    * ``schedule``     -- called to produce logSNR values.
    * ``input_scaler`` -- maps a logSNR to the ``(a, b)`` mixing coefficients.
    * ``target``       -- builds the training target and recovers x0 / epsilon
      from a model prediction.
    * ``noise_cond``   -- maps a logSNR to the conditioning value given to the model.
    * ``loss_weight``  -- maps a logSNR to a loss weight.
    * ``offset_noise`` -- scale of the extra per-channel noise added in
      :meth:`diffuse` (disabled when ``<= 0``).
    """

    def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
        self.schedule = schedule
        self.input_scaler = input_scaler
        self.target = target
        self.noise_cond = noise_cond
        self.loss_weight = loss_weight
        self.offset_noise = offset_noise

    def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
        """Delegate to ``input_scaler.setup_limits`` and return its result."""
        stretched_limits = self.input_scaler.setup_limits(
            self.schedule, self.input_scaler, stretch_max, stretch_min, shift
        )
        return stretched_limits

    def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
        """Noise ``x0`` and return everything a training step needs.

        Args:
            x0: clean input tensor; its first dimension is used as the batch size.
            epsilon: noise to mix in; sampled with ``torch.randn_like(x0)`` if omitted.
            t: forwarded to the schedule instead of the batch size when given.
            shift: forwarded to the schedule.
            loss_shift: forwarded to the loss weight as ``shift``.
            offset: optional offset-noise tensor, only used when
                ``self.offset_noise > 0``.

        Returns:
            ``(noised, epsilon, target, logSNR, noise_cond, loss_weight)``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        if self.offset_noise > 0:
            if offset is None:
                # One random value per (sample, channel), broadcast over the remaining dims.
                offset_shape = [x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)
                offset = torch.randn(offset_shape).to(x0.device)
            epsilon = epsilon + offset * self.offset_noise
        logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            # Per-sample coefficients: append singleton dims so they broadcast against x0.
            trailing = [1] * (len(x0.shape) - 1)
            a, b = a.view(-1, *trailing), b.view(-1, *trailing)
        target = self.target(x0, epsilon, logSNR, a, b)

        noised = x0 * a + epsilon * b
        return noised, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)

    def undiffuse(self, x, logSNR, pred):
        """Return ``(x0, epsilon)`` recovered by the target from prediction ``pred``."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            trailing = [1] * (len(x.shape) - 1)
            a, b = a.view(-1, *trailing), b.view(-1, *trailing)
        return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)

    @classmethod
    def _zeros_like_inputs(cls, value):
        """Build the default unconditional counterpart of one model input.

        Tensors become zero tensors, lists are handled element-wise, dicts
        become empty dicts (the merge step fills missing entries with zeros)
        and anything else becomes ``None``.
        """
        if isinstance(value, torch.Tensor):
            return torch.zeros_like(value)
        if isinstance(value, list):
            return [cls._zeros_like_inputs(item) for item in value]
        if isinstance(value, dict):
            return {}
        return None

    @classmethod
    def _merge_cfg_inputs(cls, model_inputs, unconditional_inputs):
        """Concatenate conditional and unconditional inputs along dim 0.

        Entries are paired by key rather than by dict position. A key missing
        from ``unconditional_inputs`` (or ``unconditional_inputs=None``) falls
        back to zeros, mirroring the fallback used for nested dict entries.
        """
        merged = {}
        for key, cond in model_inputs.items():
            if unconditional_inputs is not None and key in unconditional_inputs:
                uncond = unconditional_inputs[key]
            else:
                uncond = cls._zeros_like_inputs(cond)

            if isinstance(cond, torch.Tensor):
                merged[key] = torch.cat([cond, uncond], dim=0)
            elif isinstance(cond, list):
                merged[key] = [
                    torch.cat([c, u], dim=0)
                    if isinstance(c, torch.Tensor) and isinstance(u, torch.Tensor) else None
                    for c, u in zip(cond, uncond)
                ]
            elif isinstance(cond, dict):
                merged[key] = {
                    name: torch.cat([cond[name], uncond.get(name, torch.zeros_like(cond[name]))], dim=0)
                    for name in cond
                }
            else:
                merged[key] = None
        return merged

    def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
        """Generator that runs ``timesteps`` sampling steps.

        After every step it yields ``(x0, x, pred)``. A dict passed back through
        ``generator.send(...)`` may override ``cfg``, ``cfg_rho``, ``sampler``,
        ``model_inputs``, ``x`` and ``x_init`` for the following steps.

        When ``cfg`` is not None the conditional and unconditional inputs are
        concatenated along dim 0 and the model is called once per step on a
        doubled batch. ``cfg`` may also be a 2-element list/tuple, in which case
        the scale is interpolated between its two values using the current
        position in ``[t_start, t_end]``.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps+1)
        schedule = self.schedule if schedule is None else schedule
        batch_size = shape[0] if x_init is None else x_init.size(0)
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, batch_size).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()

        if cfg is not None:
            # NOTE(review): model_inputs now holds a doubled batch. The non-CFG
            # branch below (reachable through cfg_t_stop / cfg_t_start) passes
            # it together with the un-doubled x -- confirm the model copes.
            model_inputs = self._merge_cfg_inputs(model_inputs, unconditional_inputs)

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            r_now = r_range[i].item()
            if cfg is not None and (cfg_t_stop is None or r_now >= cfg_t_stop) and (cfg_t_start is None or r_now <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r_now + cfg_val[1] * (1-r_now)

                pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)

                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    # Blend in a copy of the guided prediction rescaled to the
                    # standard deviation of the conditional prediction.
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                pred = model(x, noise_cond, **model_inputs)
            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
            altered_vars = yield (x0, x, pred)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get('cfg', cfg)
                cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
                sampler = altered_vars.get('sampler', sampler)
                model_inputs = altered_vars.get('model_inputs', model_inputs)
                x = altered_vars.get('x', x)
                x_init = altered_vars.get('x_init', x_init)
100
+
101
class GDF_dual_fixlrt(GDF):
    """GDF variant whose ``sample`` runs two passes.

    A full low-resolution pass is sampled first; features returned by the model
    on its last step are then passed as guidance to every step of the
    high-resolution pass.
    """

    def ref_noise(self, noised, x0, logSNR):
        """Return ``target.noise_givenx0_noised(x0, noised, logSNR, a, b)``."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            # Per-sample coefficients: append singleton dims so they broadcast against x0.
            trailing = [1] * (len(x0.shape) - 1)
            a, b = a.view(-1, *trailing), b.view(-1, *trailing)
        return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b)

    @classmethod
    def _zeros_like_inputs(cls, value):
        """Build the default unconditional counterpart of one model input.

        Tensors become zero tensors, lists are handled element-wise, dicts
        become empty dicts (the merge step fills missing entries with zeros)
        and anything else becomes ``None``.
        """
        if isinstance(value, torch.Tensor):
            return torch.zeros_like(value)
        if isinstance(value, list):
            return [cls._zeros_like_inputs(item) for item in value]
        if isinstance(value, dict):
            return {}
        return None

    @classmethod
    def _merge_cfg_inputs(cls, model_inputs, unconditional_inputs):
        """Concatenate conditional and unconditional inputs along dim 0.

        Entries are paired by key rather than by dict position. A key missing
        from ``unconditional_inputs`` (or ``unconditional_inputs=None``) falls
        back to zeros, mirroring the fallback used for nested dict entries.
        """
        merged = {}
        for key, cond in model_inputs.items():
            if unconditional_inputs is not None and key in unconditional_inputs:
                uncond = unconditional_inputs[key]
            else:
                uncond = cls._zeros_like_inputs(cond)

            if isinstance(cond, torch.Tensor):
                merged[key] = torch.cat([cond, uncond], dim=0)
            elif isinstance(cond, list):
                merged[key] = [
                    torch.cat([c, u], dim=0)
                    if isinstance(c, torch.Tensor) and isinstance(u, torch.Tensor) else None
                    for c, u in zip(cond, uncond)
                ]
            elif isinstance(cond, dict):
                merged[key] = {
                    name: torch.cat([cond[name], uncond.get(name, torch.zeros_like(cond[name]))], dim=0)
                    for name in cond
                }
            else:
                merged[key] = None
        return merged

    def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None,
               schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None,
               cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
        """Generator: low-resolution pass, then a guided high-resolution pass.

        The whole low-resolution loop runs before the first value is yielded.
        Each high-resolution step then yields ``(x0, x, pred, x_lr)``, where
        ``x_lr`` is the final low-resolution sample. A dict passed back through
        ``generator.send(...)`` may override ``cfg``, ``cfg_rho``, ``sampler``,
        ``model_inputs``, ``x`` and ``x_init`` for the following steps.

        ``shape`` / ``shape_lr`` are handed to ``sampler.init_x`` for the
        high- / low-resolution starting points when ``x_init`` is None.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps+1)
        schedule = self.schedule if schedule is None else schedule
        batch_size = shape[0] if x_init is None else x_init.size(0)
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, batch_size).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
        x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone()
        if cfg is not None:
            model_inputs = self._merge_cfg_inputs(model_inputs, unconditional_inputs)

        # ---------------------------------------------------------------- LR pass
        # Guidance for the HR pass; only filled on the last LR step, and only
        # when that step runs through the CFG branch.
        lr_guide_feats = None

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            r_now = r_range[i].item()
            if cfg is not None and (cfg_t_stop is None or r_now >= cfg_t_stop) and (cfg_t_start is None or r_now <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r_now + cfg_val[1] * (1-r_now)

                # The keyword is spelled `reuire_f` in the original call and is
                # kept as-is: it has to match the model's signature, which is
                # not visible from this file.
                output, enc_feats, dec_feats = model(
                    torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs
                )
                if i == timesteps - 1:
                    guide_lr_enc, guide_lr_dec = enc_feats, dec_feats
                    # Keep the first half of each feature and duplicate it so
                    # both halves of the doubled HR batch get the same guidance.
                    lr_guide_feats = (
                        [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc],
                        [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec],
                    )

                pred, pred_unconditional = output.chunk(2)

                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                pred = model(x_lr, noise_cond, **model_inputs)
            x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred)
            x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params)

        # ---------------------------------------------------------------- HR pass
        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            r_now = r_range[i].item()
            if cfg is not None and (cfg_t_stop is None or r_now >= cfg_t_stop) and (cfg_t_start is None or r_now <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r_now + cfg_val[1] * (1-r_now)

                # NOTE(review): guidance is only passed for step indices 0..19,
                # which covers every step at the default timesteps=20 --
                # confirm whether larger step counts should be guided too.
                out_pred, _t_emb = model(
                    torch.cat([x, x], dim=0), noise_cond.repeat(2),
                    lr_guide=lr_guide_feats if i <= 19 else None, **model_inputs,
                    require_t=True, guide_weight=1 - i/timesteps
                )
                pred, pred_unconditional = out_pred.chunk(2)
                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                # NOTE(review): this branch uses the keyword `guide_lr` while
                # the CFG branch uses `lr_guide`, and guide_lr_enc /
                # guide_lr_dec are only bound if the last LR step went through
                # the CFG branch -- verify against the model signature.
                pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs)
            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)

            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
            altered_vars = yield (x0, x, pred, x_lr)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get('cfg', cfg)
                cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
                sampler = altered_vars.get('sampler', sampler)
                model_inputs = altered_vars.get('model_inputs', model_inputs)
                x = altered_vars.get('x', x)
                x_init = altered_vars.get('x_init', x_init)
202
+
203
+
204
+
205
+
gdf/loss_weights.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ # --- Loss Weighting
5
class BaseLossWeight():
    """Base class for loss weights expressed as a function of logSNR.

    Subclasses implement :meth:`weight`; calling the instance applies the
    optional logSNR shift, evaluates the weight and clamps the result.
    """

    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        """Return ``weight(logSNR)`` clamped to ``clamp_range``.

        When ``shift`` differs from 1, ``2 * log(shift)`` is added to a copy of
        ``logSNR`` first. ``clamp_range`` defaults to ``[-1e9, 1e9]``.
        """
        bounds = clamp_range if clamp_range is not None else [-1e9, 1e9]
        shifted = logSNR if shift == 1 else logSNR.clone() + 2 * np.log(shift)
        raw_weight = self.weight(shifted, *args, **kwargs)
        return raw_weight.clamp(*bounds)
14
+
15
class ComposedLossWeight(BaseLossWeight):
    """Product of the ``mul`` weights divided by the product of the ``div`` weights."""

    def __init__(self, div, mul):
        # A single weight object is accepted as shorthand for a one-element list.
        self.mul = mul if not isinstance(mul, BaseLossWeight) else [mul]
        self.div = div if not isinstance(div, BaseLossWeight) else [div]

    def weight(self, logSNR):
        numerator = 1
        for factor in self.mul:
            numerator *= factor.weight(logSNR)
        denominator = 1
        for factor in self.div:
            denominator *= factor.weight(logSNR)
        return numerator / denominator
27
+
28
class ConstantLossWeight(BaseLossWeight):
    """Weight equal to the constant ``v`` for every logSNR."""

    def __init__(self, v=1):
        self.v = v

    def weight(self, logSNR):
        ones = torch.ones_like(logSNR)
        return ones * self.v
34
+
35
class SNRLossWeight(BaseLossWeight):
    """Weight equal to ``exp(logSNR)``."""

    def weight(self, logSNR):
        return torch.exp(logSNR)
38
+
39
class P2LossWeight(BaseLossWeight):
    """Weight ``(k + exp(logSNR * s)) ** -gamma``."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        scaled_snr = torch.exp(logSNR * self.s)
        return (self.k + scaled_snr) ** -self.gamma
45
+
46
class SNRPlusOneLossWeight(BaseLossWeight):
    """Weight equal to ``exp(logSNR) + 1``."""

    def weight(self, logSNR):
        snr = torch.exp(logSNR)
        return snr + 1
49
+
50
class MinSNRLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR)`` capped from above at ``max_snr``."""

    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        snr = torch.exp(logSNR)
        return snr.clamp(max=self.max_snr)
56
+
57
class MinSNRPlusOneLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR) + 1`` capped from above at ``max_snr``."""

    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        snr_plus_one = torch.exp(logSNR) + 1
        return snr_plus_one.clamp(max=self.max_snr)
63
+
64
class TruncatedSNRLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR)`` bounded from below by ``min_snr``."""

    def __init__(self, min_snr=1):
        self.min_snr = min_snr

    def weight(self, logSNR):
        snr = torch.exp(logSNR)
        return snr.clamp(min=self.min_snr)
70
+
71
class SechLossWeight(BaseLossWeight):
    """Weight ``1 / cosh(logSNR / div)``."""

    def __init__(self, div=2):
        self.div = div

    def weight(self, logSNR):
        scaled = logSNR / self.div
        return 1 / torch.cosh(scaled)
77
+
78
class DebiasedLossWeight(BaseLossWeight):
    """Weight ``1 / sqrt(exp(logSNR))``."""

    def weight(self, logSNR):
        snr_root = torch.sqrt(torch.exp(logSNR))
        return 1 / snr_root
81
+
82
class SigmoidLossWeight(BaseLossWeight):
    """Weight ``sigmoid(logSNR * s)``."""

    def __init__(self, s=1):
        self.s = s

    def weight(self, logSNR):
        return torch.sigmoid(logSNR * self.s)
88
+
89
class AdaptiveLossWeight(BaseLossWeight):
    """Loss weight equal to the inverse of a running per-bucket loss average.

    The logSNR axis is split into ``buckets`` buckets by ``buckets - 1``
    evenly spaced boundaries over ``logsnr_range``. Every bucket starts with a
    running loss of 1 and is updated through :meth:`update_buckets`.
    """

    def __init__(self, logsnr_range=(-10, 10), buckets=300, weight_range=(1e-7, 1e7)):
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1)
        self.bucket_losses = torch.ones(buckets)
        # Store a private copy so instances never share (or mutate) the
        # caller's / default sequence.
        self.weight_range = list(weight_range)

    def weight(self, logSNR):
        """Return ``1 / running_loss`` of each logSNR's bucket, clamped to ``weight_range``."""
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Blend ``loss`` into the running average of the matching buckets.

        The update is ``old * beta + loss * (1 - beta)``; the bucket state is
        kept on the CPU.
        """
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
gdf/noise_conditions.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
class BaseNoiseCond():
    """Base class mapping a logSNR to the conditioning value given to the model.

    Subclasses implement :meth:`cond` and may override :meth:`setup`, which
    receives every constructor argument other than ``shift`` and
    ``clamp_range``.
    """

    def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
        self.shift = shift
        self.clamp_range = clamp_range if clamp_range is not None else [-1e9, 1e9]
        self.setup(*args, **kwargs)

    def setup(self, *args, **kwargs):
        pass  # this method is optional, override it if required

    def cond(self, logSNR):
        raise NotImplementedError("this method needs to be overriden")

    def __call__(self, logSNR):
        """Return ``cond(logSNR)`` clamped to ``self.clamp_range``.

        When ``self.shift`` differs from 1, ``2 * log(shift)`` is added to a
        copy of ``logSNR`` first.
        """
        shifted = logSNR if self.shift == 1 else logSNR.clone() + 2 * np.log(self.shift)
        conditioning = self.cond(shifted)
        return conditioning.clamp(*self.clamp_range)
21
+
22
class CosineTNoiseCond(BaseNoiseCond):
    """Conditioning value ``t`` derived from logSNR through an inverse-cosine mapping."""

    def setup(self, s=0.008, clamp_range=(0, 1)):  # [0.0001, 0.9999]
        """Store the offset ``s`` and the clamp range.

        NOTE(review): a ``clamp_range=`` keyword given to the constructor is
        consumed by ``BaseNoiseCond.__init__`` and never reaches this method,
        so the value assigned here (by default ``[0, 1]``) always replaces it.
        Only a positional second argument lands in this parameter.
        """
        self.s = torch.tensor([s])
        # Copy so instances never share (or mutate) the caller's / default sequence.
        self.clamp_range = list(clamp_range)
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def cond(self, logSNR):
        var = logSNR.sigmoid()
        var = var.clamp(*self.clamp_range)
        s, min_var = self.s.to(var.device), self.min_var.to(var.device)
        t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return t
34
+
35
class EDMNoiseCond(BaseNoiseCond):
    """Conditioning value ``-logSNR / 8``."""

    def cond(self, logSNR):
        negated = -logSNR
        return negated / 8
38
+
39
class SigmoidNoiseCond(BaseNoiseCond):
    """Conditioning value ``sigmoid(-logSNR)``."""

    def cond(self, logSNR):
        return torch.sigmoid(-logSNR)
42
+
43
class LogSNRNoiseCond(BaseNoiseCond):
    """Conditioning value equal to the (shifted) logSNR itself."""

    def cond(self, logSNR):
        # Identity mapping: the model is conditioned on logSNR directly.
        return logSNR
46
+
47
class EDMSigmaNoiseCond(BaseNoiseCond):
    """Conditioning value ``exp(-logSNR / 2) * sigma_data``."""

    def setup(self, sigma_data=1):
        self.sigma_data = sigma_data

    def cond(self, logSNR):
        scale = (-logSNR / 2).exp()
        return scale * self.sigma_data
53
+
54
class RectifiedFlowsNoiseCond(BaseNoiseCond):
    """Conditioning value computed in closed form from ``exp(logSNR) - 1``."""

    def cond(self, logSNR):
        snr_minus_one = logSNR.exp() - 1
        # Avoid division by zero
        snr_minus_one[snr_minus_one == 0] = 1e-3
        return 1 + (2 - (2**2 + 4*snr_minus_one)**0.5) / (2*snr_minus_one)
60
+
61
+ # Any NoiseCond that cannot be described easily as a continuous function of t
62
+ # It needs to define self.x and self.y in the setup() method
63
class PiecewiseLinearNoiseCond(BaseNoiseCond):
    """Conditioning read off a piecewise-linear curve given by ``self.x`` / ``self.y``.

    Subclasses are expected to assign both tensors in ``setup()``; this base
    implementation leaves them as ``None``.
    """

    def setup(self):
        self.x = None
        self.y = None

    def piecewise_linear(self, y, xs, ys):
        """Linearly interpolate the ``x`` value for ``y`` between neighbouring knots.

        ``ys`` is flipped before the search, i.e. it is searched in reverse order.
        """
        reversed_knots = ys.flip(dims=(-1,))[:-2]
        indices = (len(xs) - 2) - torch.searchsorted(reversed_knots, y)
        x_lo, x_hi = xs[indices], xs[indices + 1]
        y_lo, y_hi = ys[indices], ys[indices + 1]
        return x_lo + (x_hi - x_lo) * (y - y_lo) / (y_hi - y_lo)

    def cond(self, logSNR):
        var = logSNR.sigmoid()
        knots_x = self.x.to(var.device)
        knots_y = self.y.to(var.device)
        return self.piecewise_linear(var, knots_x, knots_y)  # .mul(1000).round().clamp(min=0)
79
+
80
class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond):
    """Piecewise-linear conditioning whose knots come from a cumulative product.

    ``self.x`` holds ``total_steps + 1`` evenly spaced points in ``[0, 1]`` and
    ``self.y`` the cumulative product of ``1 - beta``, where ``sqrt(beta)`` is
    interpolated linearly between the square roots of ``linear_range``.
    """

    def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
        self.total_steps = total_steps
        sqrt_start = linear_range[0] ** 0.5
        sqrt_end = linear_range[1] ** 0.5
        self.x = torch.linspace(0, 1, total_steps+1)

        alphas = 1 - (sqrt_start*(1-self.x) + sqrt_end*self.x)**2
        self.y = torch.cumprod(alphas, dim=-1)

    def cond(self, logSNR):
        t = super().cond(logSNR)
        return t.clamp(0, 1)
91
+
92
class DiscreteNoiseCond(BaseNoiseCond):
    """Wrap another noise condition and quantise its output to integer steps."""

    def setup(self, noise_cond, steps=1000, continuous_range=(0, 1)):
        """Store the wrapped condition, the step count and its continuous output range."""
        self.noise_cond = noise_cond
        self.steps = steps
        # Copy so instances never share (or mutate) the caller's / default sequence.
        self.continuous_range = list(continuous_range)

    def cond(self, logSNR):
        """Rescale the wrapped condition from ``continuous_range`` to ``[0, 1]``,
        multiply by ``steps`` and cast to ``long``."""
        cond = self.noise_cond(logSNR)
        low, high = self.continuous_range[0], self.continuous_range[1]
        cond = (cond - low) / (high - low)
        return cond.mul(self.steps).long()
102
+
gdf/readme.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generic Diffusion Framework (GDF)
2
+
3
+ # Basic usage
4
+ GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM
5
+ , EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different
6
+ frameworks.
7
+
8
+ Using GDF is very straightforward. First, define an instance of the GDF class:
9
+
10
+ ```python
11
+ from gdf import GDF
12
+ from gdf import CosineSchedule
13
+ from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight
14
+
15
+ gdf = GDF(
16
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
17
+ input_scaler=VPScaler(), target=EpsilonTarget(),
18
+ noise_cond=CosineTNoiseCond(),
19
+ loss_weight=P2LossWeight(),
20
+ )
21
+ ```
22
+
23
+ You need to define the following components:
24
+ * **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution.
25
+ * **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule.
26
+ * **Input Scaler**: How the input and the noise are scaled when they are combined, e.g. Variance Preserving or LERP (rectified flows)
27
+ * **Target**: What the target is during training, usually: epsilon, x0 or v
28
+ * **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8`
29
+ * **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use
30
+
31
+ All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just:
32
+ ```python
33
+ class VPScaler():
34
+ def __call__(self, logSNR):
35
+ a_squared = logSNR.sigmoid()
36
+ a = a_squared.sqrt()
37
+ b = (1-a_squared).sqrt()
38
+ return a, b
39
+
40
+ ```
41
+
42
+ So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc...
43
+
44
+ ### Training
45
+
46
+ When you define your training loop you can get all you need by just doing:
47
+ ```python
48
+ shift, loss_shift = 1, 1 # this can be set to higher values, as the Simple Diffusion paper suggests for high resolution
49
+ for inputs, extra_conditions in dataloader_iterator:
50
+ noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
51
+ pred = diffusion_model(noised, noise_cond, extra_conditions)
52
+
53
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
54
+ loss_adjusted = (loss * loss_weight).mean()
55
+
56
+ loss_adjusted.backward()
57
+ optimizer.step()
58
+ optimizer.zero_grad(set_to_none=True)
59
+ ```
60
+
61
+ And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the
62
+ training from the GDF class.
63
+
64
+ ### Sampling
65
+
66
+ The other important part is sampling, when you want to use this framework to sample you can just do the following:
67
+
68
+ ```python
69
+ from gdf import DDPMSampler
70
+
71
+ shift = 1
72
+ sampling_configs = {
73
+ "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
74
+ "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
75
+ }
76
+
77
+ *_, (sampled, _, _) = gdf.sample(
78
+ diffusion_model, {"cond": extra_conditions}, latents.shape,
79
+ unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)},
80
+ device=device, **sampling_configs
81
+ )
82
+ ```
83
+
84
+ # Available modules
85
+
86
+ TODO