roubaofeipi
committed on
Upload 100 files
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- core/__init__.py +372 -0
- core/data/__init__.py +69 -0
- core/data/bucketeer.py +88 -0
- core/data/bucketeer_deg.py +91 -0
- core/data/deg_kair_utils/test.bmp +0 -0
- core/data/deg_kair_utils/test.png +3 -0
- core/data/deg_kair_utils/utils_alignfaces.py +263 -0
- core/data/deg_kair_utils/utils_blindsr.py +631 -0
- core/data/deg_kair_utils/utils_bnorm.py +91 -0
- core/data/deg_kair_utils/utils_deblur.py +655 -0
- core/data/deg_kair_utils/utils_dist.py +201 -0
- core/data/deg_kair_utils/utils_googledownload.py +93 -0
- core/data/deg_kair_utils/utils_image.py +1016 -0
- core/data/deg_kair_utils/utils_lmdb.py +205 -0
- core/data/deg_kair_utils/utils_logger.py +66 -0
- core/data/deg_kair_utils/utils_mat.py +88 -0
- core/data/deg_kair_utils/utils_matconvnet.py +197 -0
- core/data/deg_kair_utils/utils_model.py +330 -0
- core/data/deg_kair_utils/utils_modelsummary.py +485 -0
- core/data/deg_kair_utils/utils_option.py +255 -0
- core/data/deg_kair_utils/utils_params.py +135 -0
- core/data/deg_kair_utils/utils_receptivefield.py +62 -0
- core/data/deg_kair_utils/utils_regularizers.py +104 -0
- core/data/deg_kair_utils/utils_sisr.py +848 -0
- core/data/deg_kair_utils/utils_video.py +493 -0
- core/data/deg_kair_utils/utils_videoio.py +555 -0
- core/scripts/__init__.py +0 -0
- core/scripts/cli.py +41 -0
- core/templates/__init__.py +1 -0
- core/templates/diffusion.py +236 -0
- core/utils/__init__.py +9 -0
- core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
- core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
- core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
- core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
- core/utils/base_dto.py +56 -0
- core/utils/save_and_load.py +59 -0
- figures/California_000490.jpg +3 -0
- figures/example_dataset/000008.jpg +3 -0
- figures/example_dataset/000008.json +2 -0
- figures/example_dataset/000012.jpg +3 -0
- figures/example_dataset/000012.json +1 -0
- figures/teaser.jpg +3 -0
- gdf/__init__.py +205 -0
- gdf/loss_weights.py +101 -0
- gdf/noise_conditions.py +102 -0
- gdf/readme.md +86 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
core/data/deg_kair_utils/test.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
figures/California_000490.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
figures/example_dataset/000008.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
figures/example_dataset/000012.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
figures/teaser.jpg filter=lfs diff=lfs merge=lfs -text
|
core/__init__.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import wandb
|
6 |
+
import json
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
|
11 |
+
from torch.distributed import init_process_group, destroy_process_group, barrier
|
12 |
+
from torch.distributed.fsdp import (
|
13 |
+
FullyShardedDataParallel as FSDP,
|
14 |
+
FullStateDictConfig,
|
15 |
+
MixedPrecision,
|
16 |
+
ShardingStrategy,
|
17 |
+
StateDictType
|
18 |
+
)
|
19 |
+
|
20 |
+
from .utils import Base, EXPECTED, EXPECTED_TRAIN
|
21 |
+
from .utils import create_folder_if_necessary, safe_save, load_or_fail
|
22 |
+
|
23 |
+
# pylint: disable=unused-argument
|
24 |
+
class WarpCore(ABC):
    """Abstract base class that drives one training run.

    Construction loads the config and the persisted ``info`` record. Calling
    the instance optionally initialises the distributed process group and
    wandb, runs the ``setup_*`` hooks in a fixed order, and finally invokes
    ``train``. Subclasses must implement ``setup_data``, ``setup_models``,
    ``setup_optimizers`` and ``train``; the remaining hooks have defaults.
    """

    @dataclass(frozen=True)
    class Config(Base):
        # Fields defaulting to EXPECTED_TRAIN are markers imported from .utils;
        # presumably they are required when training=True — verify against Base.
        experiment_id: str = EXPECTED_TRAIN
        checkpoint_path: str = EXPECTED_TRAIN
        output_path: str = EXPECTED_TRAIN
        checkpoint_extension: str = "safetensors"
        dist_file_subfolder: str = ""
        allow_tf32: bool = True

        # wandb logging is skipped entirely while wandb_project is None
        wandb_project: str = None
        wandb_entity: str = None

    @dataclass() # not frozen, means that fields are mutable
    class Info(): # not inheriting from Base, because we don't want to enforce the default fields
        wandb_run_id: str = None
        total_steps: int = 0
        iter: int = 0

    @dataclass(frozen=True)
    class Data(Base):
        dataset: Dataset = EXPECTED
        dataloader: DataLoader = EXPECTED
        iterator: any = EXPECTED

    @dataclass(frozen=True)
    class Models(Base):
        pass

    @dataclass(frozen=True)
    class Optimizers(Base):
        pass

    @dataclass(frozen=True)
    class Schedulers(Base):
        pass

    @dataclass(frozen=True)
    class Extras(Base):
        pass
    # ---------------------------------------
    info: Info
    config: Config

    # FSDP stuff
    fsdp_defaults = {
        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
        "cpu_offload": None,
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        "limit_all_gathers": True,
    }
    fsdp_fullstate_save_policy = FullStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    )
    # ------------

    # OVERRIDEABLE METHODS

    # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
    def setup_extras_pre(self) -> Extras:
        """Return the initial ``Extras`` DTO; the default is an empty one."""
        return self.Extras()

    # setup dataset & dataloader, return a DTO containing dataset, dataloader and/or iterator
    @abstractmethod
    def setup_data(self, extras: Extras) -> Data:
        """Build and return the ``Data`` DTO used by ``train``."""
        raise NotImplementedError("This method needs to be overriden")

    # return a DTO with all models that are going to be used in the training
    @abstractmethod
    def setup_models(self, extras: Extras) -> Models:
        """Build and return the ``Models`` DTO."""
        raise NotImplementedError("This method needs to be overriden")

    # return a DTO with all optimizers that are going to be used in the training
    @abstractmethod
    def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
        """Build and return the ``Optimizers`` DTO."""
        raise NotImplementedError("This method needs to be overriden")

    # [optionally] return a DTO with all schedulers that are going to be used in the training
    def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
        """Return the ``Schedulers`` DTO; the default is an empty one."""
        return self.Schedulers()

    # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
    def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
        """Return extras to merge over the pre-extras; the default is a copy of ``extras``."""
        return self.Extras.from_dict(extras.to_dict())

    # perform the training here
    @abstractmethod
    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        """Run the training loop with everything the ``setup_*`` hooks produced."""
        raise NotImplementedError("This method needs to be overriden")
    # ------------

    def setup_info(self, full_path=None) -> Info:
        """Load the persisted ``Info`` record, or return a fresh one.

        Args:
            full_path: Path of the info JSON. Defaults to
                ``<checkpoint_path>/<experiment_id>/info.json``.

        Returns:
            An ``Info`` built from the loaded dict; a falsy load result yields
            an ``Info`` with default values.
        """
        if full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json"
        info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
        info_dto = self.Info(**info_dict)
        if info_dto.total_steps > 0 and self.is_main_node:
            print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
        return info_dto

    def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
        """Build the ``Config`` from a file, a dict, or defaults (in that priority).

        Args:
            config_file_path: Optional path to a ``.yml``/``.yaml``/``.json`` file.
            config_dict: Optional mapping used when no file path is given.
            training: Forwarded to the config under the ``'training'`` key.

        Raises:
            ValueError: If ``config_file_path`` has an unsupported extension.
        """
        if config_file_path is not None:
            if config_file_path.endswith((".yml", ".yaml")):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    # An empty YAML document loads as None; treat it as "no overrides"
                    # instead of crashing on the dict unpacking below.
                    loaded_config = yaml.safe_load(file) or {}
            elif config_file_path.endswith(".json"):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    loaded_config = json.load(file)
            else:
                raise ValueError("Config file must be either a .yml|.yaml or .json file")
            return self.Config.from_dict({**loaded_config, 'training': training})
        if config_dict is not None:
            return self.Config.from_dict({**config_dict, 'training': training})
        return self.Config(training=training)

    def setup_ddp(self, experiment_id, single_gpu=False):
        """Initialise the NCCL process group from SLURM environment variables.

        When ``single_gpu`` is true nothing is initialised and the values set in
        ``__init__`` (device, process id, world size) stay in effect.

        Args:
            experiment_id: Used to name the rendezvous file.
            single_gpu: Skip distributed setup entirely.
        """
        if not single_gpu:
            # Reads SLURM_LOCALID / SLURM_PROCID / SLURM_NNODES; int(None) raises
            # TypeError if any of them is missing from the environment.
            local_rank = int(os.environ.get("SLURM_LOCALID"))
            process_id = int(os.environ.get("SLURM_PROCID"))
            world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()

            self.process_id = process_id
            self.is_main_node = process_id == 0
            self.device = torch.device(local_rank)
            self.world_size = world_size

            # File-based rendezvous located under the current working directory
            dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
            # if os.path.exists(dist_file_path) and self.is_main_node:
            #     os.remove(dist_file_path)

            torch.cuda.set_device(local_rank)
            init_process_group(
                backend="nccl",
                rank=process_id,
                world_size=world_size,
                init_method=f"file://{dist_file_path}",
            )
            print(f"[GPU {process_id}] READY")
        else:
            print("Running in single thread, DDP not enabled.")

    def setup_wandb(self):
        """Start (or resume) the wandb run on the main node when a project is configured."""
        if self.is_main_node and self.config.wandb_project is not None:
            # Reuse the persisted run id so a resumed job continues the same run
            self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
            wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())

            if self.info.total_steps > 0:
                wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
            else:
                wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")

    # LOAD UTILITIES ----------
    def load_model(self, model, model_id=None, full_path=None, strict=True):
        """Load a state dict into ``model`` if a checkpoint is available.

        Args:
            model: Object exposing ``load_state_dict``.
            model_id: Checkpoint name under ``<checkpoint_path>/<experiment_id>/``;
                used only when ``full_path`` is not given.
            full_path: Explicit checkpoint path; takes precedence over ``model_id``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``model`` (unchanged when ``load_or_fail`` returns ``None``).

        Raises:
            ValueError: If neither ``model_id`` nor ``full_path`` is given.
        """
        if model_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
        elif full_path is None and model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )

        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            model.load_state_dict(checkpoint, strict=strict)
            del checkpoint

        return model

    def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Load optimizer state if a checkpoint is available (best effort).

        Args:
            optim: Optimizer exposing ``load_state_dict``.
            optim_id: Checkpoint name (``<optim_id>.pt``) under the experiment
                folder; used only when ``full_path`` is not given.
            full_path: Explicit checkpoint path; takes precedence over ``optim_id``.
            fsdp_model: When given, the full state dict is scattered through
                FSDP before being loaded.

        Returns:
            ``optim``. Failures while loading are printed and skipped on purpose.

        Raises:
            ValueError: If neither ``optim_id`` nor ``full_path`` is given.
        """
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )

        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            try:
                if fsdp_model is not None:
                    sharded_optimizer_state_dict = (
                        FSDP.scatter_full_optim_state_dict(  # <---- FSDP
                            # Only the main node (or every rank under NO_SHARD)
                            # passes the full state dict; other ranks pass None.
                            checkpoint
                            if (
                                self.is_main_node
                                or self.fsdp_defaults["sharding_strategy"]
                                == ShardingStrategy.NO_SHARD
                            )
                            else None,
                            fsdp_model,
                        )
                    )
                    optim.load_state_dict(sharded_optimizer_state_dict)
                    del checkpoint, sharded_optimizer_state_dict
                else:
                    optim.load_state_dict(checkpoint)
            # pylint: disable=broad-except
            except Exception as e:
                print("!!! Failed loading optimizer, skipping... Exception:", e)

        return optim

    # SAVE UTILITIES ----------
    def save_info(self, info, suffix=""):
        """Write ``info`` to ``info<suffix>.json`` in the experiment folder (main node only).

        Args:
            info: The ``Info`` record to persist; falls back to ``self.info``
                when ``None``.
            suffix: Inserted between ``info`` and ``.json`` in the file name.
        """
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
        create_folder_if_necessary(full_path)
        if self.is_main_node:
            # Persist the record that was actually passed in; previously the
            # argument was ignored and self.info was always written.
            safe_save(vars(self.info if info is None else info), full_path)

    def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
        """Save ``model.state_dict()`` from the main node.

        Args:
            model: Module to save.
            model_id: Checkpoint name under the experiment folder; used only
                when ``full_path`` is not given.
            full_path: Explicit destination; takes precedence over ``model_id``.
            is_fsdp: Gather the full state dict through FSDP before saving.

        Raises:
            ValueError: If neither ``model_id`` nor ``full_path`` is given.
        """
        if model_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
        elif full_path is None and model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if is_fsdp:
            # NOTE(review): the empty summon_full_params block is kept as in the
            # original; its purpose is not visible from this file.
            with FSDP.summon_full_params(model):
                pass
            # Every rank must enter the state_dict call; only the main node writes.
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
            ):
                checkpoint = model.state_dict()
            if self.is_main_node:
                safe_save(checkpoint, full_path)
            del checkpoint
        else:
            if self.is_main_node:
                checkpoint = model.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint

    def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Save the optimizer state from the main node.

        Args:
            optim: Optimizer to save.
            optim_id: Checkpoint name (``<optim_id>.pt``) under the experiment
                folder; used only when ``full_path`` is not given.
            full_path: Explicit destination; takes precedence over ``optim_id``.
            fsdp_model: When given, the full optimizer state is gathered via FSDP.

        Raises:
            ValueError: If neither ``optim_id`` nor ``full_path`` is given.
        """
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if fsdp_model is not None:
            # Called on every rank; only the main node writes the result.
            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
            if self.is_main_node:
                safe_save(optim_statedict, full_path)
            del optim_statedict
        else:
            if self.is_main_node:
                checkpoint = optim.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint
    # -----

    def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
        """Load config and info; distributed state is set later by ``setup_ddp``.

        Args:
            config_file_path: Optional config file, see ``setup_config``.
            config_dict: Optional config mapping, see ``setup_config``.
            device: Device used until ``setup_ddp`` replaces it.
            training: Forwarded to ``setup_config``.
        """
        # Temporary setup, will be overriden by setup_ddp if required
        self.device = device
        self.process_id = 0
        self.is_main_node = True
        self.world_size = 1
        # ----

        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
        self.info: self.Info = self.setup_info()

    def __call__(self, single_gpu=False):
        """Run the full job: distributed/wandb setup, all ``setup_*`` hooks, then ``train``.

        Args:
            single_gpu: When true, skip process-group setup and teardown.
        """
        self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu)  # this will change the device to the CUDA rank
        self.setup_wandb()
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        if self.is_main_node:
            print()
            print("**STARTING JOB WITH CONFIG:**")
            print(yaml.dump(self.config.to_dict(), default_flow_style=False))
            print("------------------------------------")
            print()
            print("**INFO:**")
            print(yaml.dump(vars(self.info), default_flow_style=False))
            print("------------------------------------")
            print()

        # SETUP STUFF
        extras = self.setup_extras_pre()
        assert extras is not None, "setup_extras_pre() must return a DTO"

        data = self.setup_data(extras)
        assert data is not None, "setup_data() must return a DTO"
        if self.is_main_node:
            print("**DATA:**")
            print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        models = self.setup_models(extras)
        assert models is not None, "setup_models() must return a DTO"
        if self.is_main_node:
            print("**MODELS:**")
            print(yaml.dump({
                k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
            }, default_flow_style=False))
            print("------------------------------------")
            print()

        optimizers = self.setup_optimizers(extras, models)
        assert optimizers is not None, "setup_optimizers() must return a DTO"
        if self.is_main_node:
            print("**OPTIMIZERS:**")
            print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        schedulers = self.setup_schedulers(extras, models, optimizers)
        assert schedulers is not None, "setup_schedulers() must return a DTO"
        if self.is_main_node:
            print("**SCHEDULERS:**")
            print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
        assert post_extras is not None, "setup_extras_post() must return a DTO"
        # Post-extras win over pre-extras on key collisions
        extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
        if self.is_main_node:
            print("**EXTRAS:**")
            print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()
        # -------

        # TRAIN
        if self.is_main_node:
            print("**TRAINING STARTING...**")
        self.train(data, extras, models, optimizers, schedulers)

        if single_gpu is False:
            # Wait for every rank to finish before tearing the group down
            barrier()
            destroy_process_group()
        if self.is_main_node:
            print()
            print("------------------------------------")
            print()
            print("**TRAINING COMPLETE**")
            if self.config.wandb_project is not None:
                wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
|
core/data/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import subprocess
|
3 |
+
import yaml
|
4 |
+
import os
|
5 |
+
from .bucketeer import Bucketeer
|
6 |
+
|
7 |
+
class MultiFilter():
    """Callable predicate that checks a sample's JSON metadata against rules.

    ``rules`` maps a key (or a tuple of keys) to a callable. For a single key
    the callable receives ``x_json[key]``; for a tuple it receives the values
    of all listed keys as positional arguments. A sample passes only when
    every rule returns a truthy value.
    """

    def __init__(self, rules, default=False):
        """Store the rules and the fallback verdict.

        Args:
            rules: Mapping of key or tuple-of-keys to a validation callable.
            default: Verdict returned when a sample cannot be evaluated
                (missing ``'json'`` entry, missing key, or a rule raising).
        """
        self.rules = rules
        self.default = default

    def __call__(self, x):
        """Return whether sample ``x`` satisfies all rules.

        Args:
            x: Mapping with a ``'json'`` entry holding either a dict or JSON
                encoded as ``bytes``.

        Returns:
            ``True``/``False`` from the rules, or ``self.default`` when the
            evaluation raised.
        """
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            # Evaluate every rule before combining, so a failing lookup in a
            # later rule is still noticed even if an earlier rule said no.
            validations = [
                r(*[x_json[kv] for kv in k]) if isinstance(k, tuple) else r(x_json[k])
                for k, r in self.rules.items()
            ]
            return all(validations)
        except Exception:
            # Deliberately best-effort: unevaluable samples get the configured
            # default (previously a hard-coded False, leaving `default` unused).
            return self.default
|
27 |
+
|
28 |
+
class MultiGetter():
    """Callable that extracts values from JSON metadata according to rules.

    ``rules`` maps a key (or a tuple of keys) to a callable. For a single key
    the callable receives ``x_json[key]``; for a tuple it receives the values
    of all listed keys as positional arguments.
    """

    def __init__(self, rules):
        """Store the key -> callable mapping."""
        self.rules = rules

    def __call__(self, x_json):
        """Apply every rule to ``x_json`` and return the results.

        Args:
            x_json: A dict, or JSON encoded as ``bytes`` (decoded first).

        Returns:
            The single result when exactly one rule is configured, otherwise a
            list with one result per rule in rule order.
        """
        record = json.loads(x_json) if isinstance(x_json, bytes) else x_json
        results = [
            fn(*[record[name] for name in key]) if isinstance(key, tuple) else fn(record[key])
            for key, fn in self.rules.items()
        ]
        # A lone rule is unwrapped so callers get the bare value
        return results[0] if len(results) == 1 else results
|
45 |
+
|
46 |
+
def setup_webdataset_path(paths, cache_path=None):
    """Build a ``pipe:aws s3 cp { ... } -`` spec covering all tar shards.

    Paths that already end in ``.tar`` are used as-is; every other path is
    listed recursively with ``aws s3 ls`` and its ``.tar`` entries are
    collected. The resolved list is cached as YAML when ``cache_path`` is given.

    Args:
        paths: A single path or an iterable of paths.
        cache_path: Optional YAML file. If it exists it is loaded instead of
            resolving ``paths``; otherwise the resolved list is written to it.

    Returns:
        The pipe command string with all tar paths joined by commas.

    Raises:
        subprocess.CalledProcessError: If an ``aws s3 ls`` call fails.
    """
    if cache_path is None or not os.path.exists(cache_path):
        tar_paths = []
        if isinstance(paths, str):
            paths = [paths]
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            # First three '/'-separated parts, e.g. "s3://bucket"
            bucket = "/".join(path.split("/")[:3])
            # Argument list instead of a shell string: the path can no longer be
            # interpreted by a shell. The 4th whitespace-separated column of each
            # listing line is extracted here (was: piping through awk).
            result = subprocess.run(
                ["aws", "s3", "ls", path.strip(), "--recursive"],
                stdout=subprocess.PIPE,
                check=True,
            )
            listing = result.stdout.decode('utf-8').splitlines()
            files = [cols[3] for cols in (line.split() for line in listing) if len(cols) >= 4]
            tar_paths += [f"{bucket}/{f}" for f in files if f.endswith(".tar")]

        # Only write a cache when one was requested; open(None) would raise.
        if cache_path is not None:
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)
    else:
        with open(cache_path, 'r', encoding='utf-8') as file:
            tar_paths = yaml.safe_load(file)

    tar_paths_str = ",".join([f"{p}" for p in tar_paths])
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
|
core/data/bucketeer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
from torchtools.transforms import SmartCrop
|
5 |
+
import math
|
6 |
+
|
7 |
+
class Bucketeer():
    """Iterator that regroups dataloader samples into same-size batches.

    Every incoming image is assigned the target size whose aspect ratio is
    closest to its own, resized and cropped to that size, and queued in the
    bucket for that size. A batch is emitted as soon as any bucket holds
    ``dataloader.batch_size`` items.
    """

    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        """Precompute target sizes and create one empty bucket per size.

        Args:
            dataloader: Iterable with a ``batch_size`` attribute; each element
                it yields is iterated as a collection of sample dicts.
            density: Target pixel count, or an iterable of pixel counts.
            factor: Both sides of every target size are floored to a multiple
                of this value.
            ratios: Aspect ratios (compared against width/height). The list is
                copied, never modified.
            reverse_list: Also add the reciprocal of every ratio.
            randomize_p: Forwarded to ``SmartCrop`` (``crop_mode='smart'`` only).
            randomize_q: Forwarded to ``SmartCrop`` (``crop_mode='smart'`` only).
            crop_mode: One of ``'center'``, ``'random'``, ``'smart'``.
            p_random_ratio: Stored; not used by any code in this class.
            interpolate_nearest: Resize with NEAREST instead of antialiased
                BILINEAR interpolation.
        """
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        # Copy so neither the caller's list nor the shared default is mutated
        self.ratios = list(ratios)
        if reverse_list:
            for r in list(self.ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        # Accept a single density (the declared default) as well as several
        densities = [density] if isinstance(density, (int, float)) else list(density)
        # (sqrt(dd/r), sqrt(dd*r)) per ratio, i.e. second/first == r before flooring
        self.sizes = {
            dd: [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]
            for dd in densities
        }

        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {s: [] for sizes in self.sizes.values() for s in sizes}
        # The initial SmartCrop size is only a placeholder: __next__ overwrites
        # output_size before each crop. The largest density is used for it.
        self.smartcrop = SmartCrop(int(max(densities)**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def get_available_batch(self):
        """Pop and return ``batch_size`` items from the first full bucket, or ``None``."""
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                batch = self.buckets[b][:self.batch_size]
                self.buckets[b] = self.buckets[b][self.batch_size:]
                return batch
        return None

    def get_closest_size(self, x):
        """Return the target size best matching tensor ``x``.

        The ratio closest to width/height (last / second-to-last dimension) is
        chosen first; among the densities, the size whose area is closest to
        the image area wins (first one on ties).
        """
        w, h = x.size(-1), x.size(-2)
        best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])

        def area_gap(dd):
            size = self.sizes[dd][best_size_idx]
            return abs(w*h - size[0]*size[1])

        return self.sizes[min(self.sizes, key=area_gap)][best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        """Return the single-int resize value passed to ``resize`` before cropping.

        Args:
            orig_size: Current size as a 2-element sequence.
            tgt_size: Target size as a 2-element sequence.
        """
        # Product >= 0: both sizes are elongated along the same axis (or one is square)
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))

        return resize_size

    def __next__(self):
        """Return the next same-size batch as a dict of stacked tensors / lists.

        Pulls from the wrapped dataloader until some bucket is full; the
        dataloader's ``StopIteration`` propagates when it is exhausted.
        """
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)

                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)

                # Keep all other sample fields, replacing only the image
                self.buckets[size].append({'images': img, **{k: v for k, v in dct.items() if k != 'images'}})
            batch = self.get_available_batch()

        out = {k: [item[k] for item in batch] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
|
core/data/bucketeer_deg.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
from torchtools.transforms import SmartCrop
|
5 |
+
import math
|
6 |
+
|
7 |
+
class Bucketeer():
    """Iterator that regroups dataloader samples into same-size batches.

    Every incoming image is assigned the target size whose aspect ratio is
    closest to its own, resized and cropped to that size, and queued in the
    bucket for that size. A batch is emitted as soon as any bucket holds
    ``dataloader.batch_size`` items.
    """

    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        """Precompute target sizes and create one empty bucket per size.

        Args:
            dataloader: Iterable with a ``batch_size`` attribute; each element
                it yields is iterated as a collection of sample dicts.
            density: Target pixel count, or an iterable of pixel counts.
            factor: Both sides of every target size are floored to a multiple
                of this value.
            ratios: Aspect ratios (compared against width/height). The list is
                copied, never modified.
            reverse_list: Also add the reciprocal of every ratio.
            randomize_p: Forwarded to ``SmartCrop`` (``crop_mode='smart'`` only).
            randomize_q: Forwarded to ``SmartCrop`` (``crop_mode='smart'`` only).
            crop_mode: One of ``'center'``, ``'random'``, ``'smart'``.
            p_random_ratio: Stored; not used by any active code in this class.
            interpolate_nearest: Resize with NEAREST instead of antialiased
                BILINEAR interpolation.
        """
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        # Copy so neither the caller's list nor the shared default is mutated
        self.ratios = list(ratios)
        if reverse_list:
            for r in list(self.ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        # Accept a single density (the declared default) as well as several
        densities = [density] if isinstance(density, (int, float)) else list(density)
        # (sqrt(dd/r), sqrt(dd*r)) per ratio, i.e. second/first == r before flooring
        self.sizes = {
            dd: [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]
            for dd in densities
        }

        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {s: [] for sizes in self.sizes.values() for s in sizes}
        # The initial SmartCrop size is only a placeholder: __next__ overwrites
        # output_size before each crop. The largest density is used for it.
        self.smartcrop = SmartCrop(int(max(densities)**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def get_available_batch(self):
        """Pop and return ``batch_size`` items from the first full bucket, or ``None``."""
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                batch = self.buckets[b][:self.batch_size]
                self.buckets[b] = self.buckets[b][self.batch_size:]
                return batch
        return None

    def get_closest_size(self, x):
        """Return the target size best matching tensor ``x``.

        The ratio closest to width/height (last / second-to-last dimension) is
        chosen first; among the densities, the size whose area is closest to
        the image area wins (first one on ties).
        """
        w, h = x.size(-1), x.size(-2)
        best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])

        def area_gap(dd):
            size = self.sizes[dd][best_size_idx]
            return abs(w*h - size[0]*size[1])

        return self.sizes[min(self.sizes, key=area_gap)][best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        """Return the single-int resize value passed to ``resize`` before cropping.

        Args:
            orig_size: Current size as a 2-element sequence.
            tgt_size: Target size as a 2-element sequence.
        """
        # Product >= 0: both sizes are elongated along the same axis (or one is square)
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))

        return resize_size

    def __next__(self):
        """Return the next same-size batch as a dict of stacked tensors / lists.

        Pulls from the wrapped dataloader until some bucket is full; the
        dataloader's ``StopIteration`` propagates when it is exhausted.
        """
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)

                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)

                # Keep all other sample fields, replacing only the image
                self.buckets[size].append({'images': img, **{k: v for k, v in dct.items() if k != 'images'}})
            batch = self.get_available_batch()

        out = {k: [item[k] for item in batch] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
|
core/data/deg_kair_utils/test.bmp
ADDED
core/data/deg_kair_utils/test.png
ADDED
Git LFS Details
|
core/data/deg_kair_utils/utils_alignfaces.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon Apr 24 15:43:29 2017
|
4 |
+
@author: zhaoy
|
5 |
+
"""
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from skimage import transform as trans
|
9 |
+
|
10 |
+
# Reference positions of five facial landmarks, one (x, y) pair per point,
# laid out for a crop of DEFAULT_CROP_SIZE.
# NOTE(review): the landmark order (eyes / nose / mouth corners?) is not stated
# anywhere in this file -- confirm before relying on it.
REFERENCE_FACIAL_POINTS = [
    [30.29459953, 51.69630051],
    [65.53179932, 51.50139999],
    [48.02519989, 71.73660278],
    [33.54930115, 92.3655014],
    [62.72990036, 92.20410156]
]

# Crop size the reference points above are defined for.
# NOTE(review): presumably (width, height), since warp_and_crop_face passes its
# equally-shaped crop_size to cv2.warpAffine as dsize -- confirm.
DEFAULT_CROP_SIZE = (96, 112)
|
20 |
+
|
21 |
+
|
22 |
+
def _umeyama(src, dst, estimate_scale=True, scale=1.0):
    """Estimate N-D similarity transformation with or without scaling.

    Parameters
    ----------
    src : (M, N) array
        Source coordinates.
    dst : (M, N) array
        Destination coordinates.
    estimate_scale : bool
        Whether to estimate scaling factor.
    scale : float
        Scale to apply when ``estimate_scale`` is False.

    Returns
    -------
    T : (N + 1, N + 1)
        The homogeneous similarity transformation matrix. The matrix contains
        NaN values only if the problem is not well-conditioned.
    scale : float
        The scale that was applied (estimated or passed in); NaN when the
        problem is degenerate (covariance of rank 0).

    References
    ----------
    .. [1] "Least-squares estimation of transformation parameters between two
            point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
    """
    num = src.shape[0]
    dim = src.shape[1]

    # Compute mean of src and dst.
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)

    # Subtract mean from src and dst.
    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    # Eq. (38).
    A = dst_demean.T @ src_demean / num

    # Eq. (39).
    d = np.ones((dim,), dtype=np.double)
    if np.linalg.det(A) < 0:
        d[dim - 1] = -1

    T = np.eye(dim + 1, dtype=np.double)

    U, S, V = np.linalg.svd(A)

    # Eq. (40) and (43).
    rank = np.linalg.matrix_rank(A)
    if rank == 0:
        # Keep the (T, scale) return shape on the degenerate path too, so
        # callers that unpack two values do not fail.
        return np.nan * T, np.nan
    elif rank == dim - 1:
        if np.linalg.det(U) * np.linalg.det(V) > 0:
            T[:dim, :dim] = U @ V
        else:
            s = d[dim - 1]
            d[dim - 1] = -1
            T[:dim, :dim] = U @ np.diag(d) @ V
            d[dim - 1] = s
    else:
        T[:dim, :dim] = U @ np.diag(d) @ V

    if estimate_scale:
        # Eq. (41) and (42).
        scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)

    T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
    T[:dim, :dim] *= scale

    return T, scale
|
91 |
+
|
92 |
+
|
93 |
+
class FaceWarpException(Exception):
    """Error raised by the face-warping helpers; its text names this file."""

    def __str__(self):
        # ``super().__str__()`` yields the exception message. The previous
        # ``super.__str__(self)`` resolved to ``object.__str__`` and produced
        # the repr of the exception instead.
        return 'In File {}:{}'.format(
            __file__, super().__str__())
|
97 |
+
|
98 |
+
|
99 |
+
def get_reference_facial_points(output_size=None,
                                inner_padding_factor=0.0,
                                outer_padding=(0, 0),
                                default_square=False):
    """Return the five reference landmark positions for a padded/resized crop.

    Starts from ``REFERENCE_FACIAL_POINTS`` / ``DEFAULT_CROP_SIZE`` and then
    optionally (0) squares the crop, (1) pads the inner region by
    ``inner_padding_factor`` on every side, (2) rescales it to
    ``output_size - 2 * outer_padding`` and (3) offsets the points by
    ``outer_padding``.

    Args:
        output_size: two-element target size, or ``None`` to deduce it from
            the paddings.
        inner_padding_factor: fraction in ``[0, 1]`` of the crop size added on
            each side of the inner region.
        outer_padding: two-element padding added around the resized region.
        default_square: when true, the default crop is first made square.

    Returns:
        A ``(5, 2)`` array of landmark coordinates.

    Raises:
        FaceWarpException: when the arguments are inconsistent (padding factor
            out of range, padding not smaller than the output size, aspect
            ratio mismatch, or an ``output_size`` given without any padding).
    """
    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    # ``is not None`` instead of truthiness so array-like sizes are accepted.
    if (output_size is not None and
            output_size[0] == tmp_crop_size[0] and
            output_size[1] == tmp_crop_size[1]):
        print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
        return tmp_5pts

    if (inner_padding_factor == 0 and
            tuple(outer_padding) == (0, 0)):
        if output_size is None:
            print('No paddings to do: return default reference points')
            return tmp_5pts
        else:
            raise FaceWarpException(
                'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
            and output_size is None):
        # Deduce the size exactly the way steps 1-3 below build it: inner crop
        # plus rounded inner padding, plus the outer padding on BOTH sides.
        # (The previous expression called ``.astype`` on a Python float and
        # added the outer padding only once, so this path could never succeed.)
        output_size = tmp_crop_size + np.round(
            tmp_crop_size * inner_padding_factor * 2).astype(np.int32)
        output_size = output_size + np.array(outer_padding) * 2
        print('              deduced from paddings, output_size = ', output_size)

    if not (outer_padding[0] < output_size[0]
            and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
                                'and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    # The resize must be uniform, i.e. both regions share one aspect ratio.
    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding)'
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)

    return reference_5point
|
182 |
+
|
183 |
+
|
184 |
+
def get_affine_transform_matrix(src_pts, dst_pts):
    """Least-squares 2x3 affine transform mapping ``src_pts`` onto ``dst_pts``.

    Both point sets are extended with a column of ones and solved with
    ``np.linalg.lstsq``.

    Args:
        src_pts: ``(K, 2)`` array of source points.
        dst_pts: ``(K, 2)`` array of destination points.

    Returns:
        A ``(2, 3)`` float32 matrix: the full solution for a rank-3 system, the
        linear part with zero translation for rank 2, the identity otherwise.
    """
    ones_col = np.ones((src_pts.shape[0], 1), src_pts.dtype)
    src_h = np.hstack([src_pts, ones_col])
    dst_h = np.hstack([dst_pts, ones_col])

    solution, _residuals, rank, _singular = np.linalg.lstsq(src_h, dst_h)

    if rank == 3:
        return np.ascontiguousarray(solution[:, :2].T, dtype=np.float32)

    if rank == 2:
        tfm = np.zeros((2, 3), dtype=np.float32)
        tfm[:, :2] = solution[:2, :2].T
        return tfm

    return np.float32([[1, 0, 0], [0, 1, 0]])
|
205 |
+
|
206 |
+
|
207 |
+
def warp_and_crop_face(src_img,
                       facial_pts,
                       reference_pts=None,
                       crop_size=(96, 112),
                       align_type='smilarity'):  # smilarity cv2_affine affine
    """Warp ``src_img`` so its landmarks land on the reference landmarks.

    Args:
        src_img: image passed to ``cv2.warpAffine``.
        facial_pts: ``(K, 2)`` or ``(2, K)`` landmark coordinates, ``K > 2``.
        reference_pts: target landmarks with the same layout; when ``None``
            the defaults for ``crop_size`` are used.
        crop_size: output size handed to ``cv2.warpAffine`` as ``dsize``.
        align_type: ``'cv2_affine'`` (exact affine from the first three
            points), ``'affine'`` (least-squares affine) or anything else for
            the similarity transform.

    Returns:
        ``(face_img, tfm_inv)``: the warped image and the 2x3 matrix mapping
        reference coordinates back to the source.

    Raises:
        FaceWarpException: on malformed or mismatching landmark arrays.
    """
    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            # NOTE(review): with no paddings get_reference_facial_points raises
            # for any non-default output size -- confirm that is intended.
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(output_size,
                                                        inner_padding_factor,
                                                        outer_padding,
                                                        default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    # Compare string values, not identities: ``is`` on a str literal only works
    # by interning accident and triggers a SyntaxWarning.
    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        params, scale = _umeyama(src_pts, ref_pts)
        tfm = params[:2, :]

        # Inverse direction reuses the reciprocal of the forward scale.
        params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
        tfm_inv = params[:2, :]

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
|
core/data/deg_kair_utils/utils_blindsr.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from core.data.deg_kair_utils import utils_image as util
|
7 |
+
|
8 |
+
import random
|
9 |
+
from scipy import ndimage
|
10 |
+
import scipy
|
11 |
+
import scipy.stats as ss
|
12 |
+
from scipy.interpolate import interp2d
|
13 |
+
from scipy.linalg import orth
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
"""
|
19 |
+
# --------------------------------------------
|
20 |
+
# Super-Resolution
|
21 |
+
# --------------------------------------------
|
22 |
+
#
|
23 |
+
# Kai Zhang ([email protected])
|
24 |
+
# https://github.com/cszn
|
25 |
+
# From 2019/03--2021/08
|
26 |
+
# --------------------------------------------
|
27 |
+
"""
|
28 |
+
|
29 |
+
def modcrop_np(img, sf):
    """Copy ``img`` and trim its first two axes to multiples of ``sf``.

    Args:
        img: numpy image with at least two dimensions.
        sf: scale factor.

    Return:
        cropped copy of the image
    """
    rows, cols = img.shape[:2]
    keep_rows = rows - rows % sf
    keep_cols = cols - cols % sf
    return np.copy(img)[:keep_rows, :keep_cols, ...]
|
41 |
+
|
42 |
+
|
43 |
+
"""
|
44 |
+
# --------------------------------------------
|
45 |
+
# anisotropic Gaussian kernels
|
46 |
+
# --------------------------------------------
|
47 |
+
"""
|
48 |
+
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    n = k.shape[0]
    # The combined kernel spans 3n - 2 taps per axis.
    expanded = np.zeros((3 * n - 2, 3 * n - 2))
    # Accumulate one weighted copy of k per tap of k, on a stride-2 grid.
    for row, col in np.ndindex(n, n):
        expanded[2 * row:2 * row + n, 2 * col:2 * col + n] += k[row, col] * k
    # Drop the border, where values are very small, then renormalise to sum 1.
    margin = n // 2
    trimmed = expanded[margin:-margin, margin:-margin]
    return trimmed / trimmed.sum()
|
62 |
+
|
63 |
+
|
64 |
+
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.

    Returns:
        k     : kernel
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    # Principal direction: the unit x-axis rotated by theta.
    axis = np.dot(rotation, np.array([1., 0.]))
    eigvecs = np.array([[axis[0], axis[1]], [axis[1], -axis[0]]])
    eigvals = np.array([[l1, 0], [0, l2]])
    # Covariance assembled from its eigen-decomposition.
    sigma = np.dot(np.dot(eigvecs, eigvals), np.linalg.inv(eigvecs))
    return gm_blur_kernel(mean=[0, 0], cov=sigma, size=ksize)
|
84 |
+
|
85 |
+
|
86 |
+
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian density on a ``size x size`` grid, normalised to 1.

    Args:
        mean: length-2 mean ``[x, y]`` of the Gaussian.
        cov: 2x2 covariance matrix.
        size: number of taps per axis.

    Returns:
        ``(size, size)`` array ``k`` with ``k[y, x]`` proportional to the
        density at grid point ``(x, y)``; the entries sum to 1.
    """
    center = size / 2.0 + 0.5
    # Same coordinates as the former per-pixel loop: index - center + 1.
    offsets = np.arange(size) - center + 1
    grid_x, grid_y = np.meshgrid(offsets, offsets)
    points = np.stack([grid_x, grid_y], axis=-1)

    # One vectorised pdf call instead of size**2 separate SciPy calls.
    density = ss.multivariate_normal.pdf(points, mean=mean, cov=cov)
    k = np.reshape(density, (size, size))

    return k / np.sum(k)
|
97 |
+
|
98 |
+
|
99 |
+
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The image is resampled with bilinear interpolation at coordinates offset
    by ``(sf - 1) / 2`` pixels along both axes (clipped to the image extent).

    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction

    Returns:
        The resampled image. A 3-D input is overwritten channel by channel and
        returned; a 2-D input yields a new array.
    """
    # Local import: ``scipy.interpolate.interp2d`` (used previously) is
    # deprecated and was removed in SciPy 1.14.
    from scipy.interpolate import RegularGridInterpolator

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # Full (row, col) grid of sampling positions, shape (h, w, 2).
    sample_y, sample_x = np.meshgrid(y1, x1, indexing='ij')
    samples = np.stack([sample_y, sample_x], axis=-1)

    def _resample(plane):
        # Bilinear interpolation of one 2-D plane at the shifted positions.
        return RegularGridInterpolator((yv, xv), plane)(samples)

    if x.ndim == 2:
        x = _resample(x)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])

    return x
|
126 |
+
|
127 |
+
|
128 |
+
def blur(x, k):
    '''Blur every image of a batch with its own kernel.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw

    Returns the blurred batch, NxcxHxW for odd kernel sizes (replicate padding
    keeps the spatial size).
    '''
    n, c = x.shape[:2]
    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad takes (left, right, top, bottom): width padding first. The former
    # (p1, p2, p1, p2) order was only right for square kernels.
    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
    # One kernel per (image, channel) pair, applied as a grouped convolution
    # over a single fake batch entry with n*c channels.
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Draw a random anisotropic Gaussian blur kernel.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Both eigenvalues of the covariance are drawn uniformly from
    ``[min_var, max_var)``, the rotation angle from ``[0, pi)``, and an
    optional multiplicative noise of amplitude ``noise_level`` is applied.
    The result is normalised to sum to 1.
    """
    var_span = max_var - min_var
    # Random draws, in this order: two eigenvalues, the angle, the noise map.
    eig_1 = min_var + np.random.rand() * var_span
    eig_2 = min_var + np.random.rand() * var_span
    angle = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigenvalues and rotation, then its inverse.
    eigvals = np.diag([eig_1, eig_2])
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    cov = rot @ eigvals @ rot.T
    inv_cov = np.linalg.inv(cov)[None, None, :, :]

    # Expected position of the kernel (shifted for an aligned image).
    mu = k_size // 2 - 0.5 * (scale_factor - 1)
    mu = mu[None, None, :, None]

    # Pixel coordinates as column vectors, one per kernel tap.
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Gaussian evaluated at every tap, with the multiplicative noise.
    delta = coords - mu
    delta_t = delta.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(delta_t @ inv_cov @ delta)) * (1 + noise)

    return raw_kernel / np.sum(raw_kernel)
|
186 |
+
|
187 |
+
|
188 |
+
def fspecial_gaussian(hsize, sigma):
    """Return a normalised ``hsize x hsize`` isotropic Gaussian filter.

    Args:
        hsize: number of taps per axis.
        sigma: standard deviation of the Gaussian.

    Returns:
        2-D array summing to 1 (entries below machine epsilon relative to the
        peak are zeroed before normalising).
    """
    half = (hsize - 1.0) / 2.0
    coords = np.arange(-half, half + 1)
    x, y = np.meshgrid(coords, coords)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # ``np.finfo`` instead of the deprecated/removed ``scipy.finfo`` alias.
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
|
200 |
+
|
201 |
+
|
202 |
+
def fspecial_laplacian(alpha):
    """Return the 3x3 Laplacian filter for shape parameter ``alpha``.

    ``alpha`` is clamped to ``[0, 1]`` before use.
    """
    alpha = max(0, min(alpha, 1))
    denom = alpha + 1
    corner = alpha / denom
    edge = (1 - alpha) / denom
    center = -4 / denom
    return np.array([
        [corner, edge, corner],
        [edge, center, edge],
        [corner, edge, corner],
    ])
|
209 |
+
|
210 |
+
|
211 |
+
def fspecial(filter_type, *args, **kwargs):
    '''
    Build a filter by name; extra arguments go to the matching builder.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Returns None for any name other than 'gaussian' or 'laplacian'.
    '''
    builders = (
        ('gaussian', fspecial_gaussian),
        ('laplacian', fspecial_laplacian),
    )
    for name, builder in builders:
        if filter_type == name:
            return builder(*args, **kwargs)
    return None
|
220 |
+
|
221 |
+
"""
|
222 |
+
# --------------------------------------------
|
223 |
+
# degradation models
|
224 |
+
# --------------------------------------------
|
225 |
+
"""
|
226 |
+
|
227 |
+
|
228 |
+
def bicubic_degradation(x, sf=3):
    '''
    Downscale an image by a factor of ``sf`` via ``util.imresize_np``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor

    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
|
239 |
+
|
240 |
+
|
241 |
+
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ``ndimage.convolve`` is the same function as the deprecated
    # ``ndimage.filters.convolve``; the kernel gets a trailing axis so every
    # channel is filtered independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
|
264 |
+
|
265 |
+
|
266 |
+
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` path;
    # the trailing kernel axis filters every channel independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
|
290 |
+
|
291 |
+
|
292 |
+
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image
    '''
    # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` path;
    # the trailing kernel axis filters every channel independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # Keep every sf-th pixel, starting from the top-left one.
    st = 0
    return x[st::sf, st::sf, ...]
|
307 |
+
|
308 |
+
|
309 |
+
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN

    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual magnitude, on a 0-255 scale, above which a
            pixel is sharpened.
    """
    # cv2.GaussianBlur needs an odd kernel size.
    if radius % 2 == 0:
        radius += 1
    ksize = (radius, radius)

    blurred = cv2.GaussianBlur(img, ksize, 0)
    residual = img - blurred

    # Hard mask of strong edges, softened by the same blur.
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, ksize, 0)

    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
|
333 |
+
|
334 |
+
|
335 |
+
def add_blur(img, sf=4):
    """Blur ``img`` with a randomly drawn Gaussian kernel.

    With probability 0.5 the kernel is anisotropic (random eigenvalues up to
    ``4 + sf`` and a random angle); otherwise it is isotropic with a sigma up
    to ``2 + 0.2 * sf``. The kernel size is a random odd number in [7, 25].

    Args:
        img: image with a channel axis (the kernel is expanded along axis 2).
        sf: scale factor controlling the blur strength range.

    Returns:
        The blurred image.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` path.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
|
347 |
+
|
348 |
+
|
349 |
+
def add_resize(img, sf=4):
    """Randomly rescale ``img`` and clip the result to [0, 1].

    With probability 0.2 the image is enlarged by a factor in [1, 2], with
    probability 0.7 shrunk by a factor in [0.5 / sf, 1], and otherwise kept at
    its size. The cv2 interpolation flag is drawn from {1, 2, 3}.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0

    new_w = int(factor * img.shape[1])
    new_h = int(factor * img.shape[0])
    resized = cv2.resize(img, (new_w, new_h), interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
|
361 |
+
|
362 |
+
|
363 |
+
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise to ``img`` and clip to [0, 1].

    One of three noise types is drawn: independent per-channel noise
    (probability 0.4), one noise map shared by all channels (0.4), or
    channel-correlated noise from a random 3x3 covariance (0.2). The noise is
    added to ``img`` in place before clipping.

    Args:
        img: float image, HxWxC (3 channels for the correlated case).
        noise_level1: lower bound of the noise level, on a 0-255 scale.
        noise_level2: upper bound of the noise level, on a 0-255 scale.

    Returns:
        The clipped noisy image.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    draw = np.random.rand()
    if draw > 0.6:  # add color Gaussian noise
        noise = np.random.normal(0, sigma, img.shape)
    elif draw < 0.4:  # add grayscale Gaussian noise
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:  # add noise with a random channel covariance
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2])
    img += noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
|
378 |
+
|
379 |
+
|
380 |
+
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add random multiplicative (speckle) noise and clip to [0, 1].

    The image is clipped first, then ``img * noise`` is added, where the noise
    is per-channel (probability 0.4), shared across channels (0.4) or
    channel-correlated through a random 3x3 covariance (0.2).

    Args:
        img: float image, HxWxC (3 channels for the correlated case).
        noise_level1: lower bound of the noise level, on a 0-255 scale.
        noise_level2: upper bound of the noise level, on a 0-255 scale.

    Returns:
        The clipped noisy image.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    img = np.clip(img, 0.0, 1.0)
    draw = random.random()
    if draw > 0.6:
        gain = np.random.normal(0, sigma, img.shape)
    elif draw < 0.4:
        gain = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        gain = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2])
    img += img * gain.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
|
396 |
+
|
397 |
+
|
398 |
+
def add_Poisson_noise(img):
    """Add Poisson (shot) noise to ``img`` and clip to [0, 1].

    The image is first quantised to 8-bit levels. A random scale ``vals`` in
    [1e2, 1e4) sets the noise strength. With probability 0.5 each channel gets
    its own Poisson noise; otherwise a single noise map derived from the
    luminance is added to all channels.
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # exponent in [2, 4]
    per_channel = random.random() < 0.5
    if per_channel:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        # Luminance from the first three channels, quantised like the image.
        gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        gray = np.clip((gray * 255.0).round(), 0, 255) / 255.
        gray_noise = np.random.poisson(gray * vals).astype(np.float32) / vals - gray
        img += gray_noise[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
|
410 |
+
|
411 |
+
|
412 |
+
def add_JPEG_noise(img):
    """Round-trip ``img`` through JPEG at a random quality in [30, 95].

    The image is converted to uint8 and RGB->BGR for OpenCV, encoded and
    decoded as JPEG, then converted back via ``util.uint2single`` and BGR->RGB.
    """
    quality = random.randint(30, 95)
    bgr_uint = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _ok, encoded = cv2.imencode('.jpg', bgr_uint, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
|
419 |
+
|
420 |
+
|
421 |
+
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random patch from ``lq`` and the matching patch from ``hq``.

    Args:
        lq: low-quality image, HxWxC.
        hq: high-quality image, indexed at ``sf`` times the ``lq`` coordinates.
        sf: scale factor between ``lq`` and ``hq``.
        lq_patchsize: side length of the ``lq`` patch.

    Returns:
        ``(lq_patch, hq_patch)`` with side lengths ``lq_patchsize`` and
        ``lq_patchsize * sf``.
    """
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    # Same window in high-quality coordinates.
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq_side = lq_patchsize * sf
    hq_patch = hq[top_hq:top_hq + hq_side, left_hq:left_hq + hq_side, :]
    return lq_patch, hq_patch
|
430 |
+
|
431 |
+
|
432 |
+
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]

    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf on either side
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize*sf or w < lq_patchsize*sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
        else:
            img = util.imresize_np(img, 1/2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # step 3 reads a, b set by step 2, so step 2 must come first
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            # Remember the pre-downsample size (width, height) for downsample3.
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1,2*sf)
                img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted/k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve is the supported name; the ndimage.filters
                # namespace is deprecated in SciPy.
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to exactly 1/sf of the size recorded in step 2
            img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop (uses the original scale factor, not the possibly halved one)
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
|
520 |
+
|
521 |
+
|
522 |
+
|
523 |
+
|
524 |
+
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of fully shuffling the 13 degradation steps
    use_sharp: sharpen the img (via add_sharpening) before degrading it
    lq_patchsize: side length of the returned low-quality patch
    isp_model: optional camera ISP model; its forward(img, hq) is called with probability isp_prob

    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]

    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf on either side
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize*sf or w < lq_patchsize*sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    # The (optionally sharpened) image is the high-quality reference.
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise: only steps 2-5 and 9-12 are permuted among
        # themselves; blur/resize/JPEG (0, 1, 6, 7, 8) keep their positions
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    # Steps 7-12 repeat steps 0-5 (second degradation round); step 6 is JPEG.
    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            # Not reachable for indices drawn from range(13); kept as a guard.
            print('check the shuffle!')

    # resize to desired size (1/sf of the high-quality reference)
    img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
|
609 |
+
|
610 |
+
|
611 |
+
|
612 |
+
if __name__ == '__main__':
    # Manual smoke test: degrade the bundled test image 20 times and save each
    # LQ/HQ pair side by side as '<i>.png' in the working directory.
    img = util.imread_uint('utils/test.png', 3)
    img = util.uint2single(img)
    sf = 4

    for i in range(20):
        img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
        print(i)
        # Upscale the LQ patch with nearest-neighbour (interpolation=0) so it
        # has the same size as the HQ patch for the side-by-side image.
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
        img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i)+'.png')
|
623 |
+
|
624 |
+
# for i in range(10):
|
625 |
+
# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
|
626 |
+
# print(i)
|
627 |
+
# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
|
628 |
+
# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
|
629 |
+
# util.imsave(img_concat, str(i)+'.png')
|
630 |
+
|
631 |
+
# run utils/utils_blindsr.py
|
core/data/deg_kair_utils/utils_bnorm.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
"""
|
6 |
+
# --------------------------------------------
|
7 |
+
# Batch Normalization
|
8 |
+
# --------------------------------------------
|
9 |
+
|
10 |
+
# Kai Zhang ([email protected])
|
11 |
+
# https://github.com/cszn
|
12 |
+
# 01/Jan/2019
|
13 |
+
# --------------------------------------------
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
# --------------------------------------------
|
18 |
+
# remove/delete specified layer
|
19 |
+
# --------------------------------------------
|
20 |
+
def deleteLayer(model, layer_type=nn.BatchNorm2d):
    ''' Kai Zhang, 11/Jan/2019.
    Recursively remove, in place, every child module of `model` that is an
    instance of `layer_type`.
    '''
    # Snapshot the children first: model._modules is mutated while looping.
    children = list(model.named_children())
    for name, child in children:
        matches = isinstance(child, layer_type)
        if matches:
            del model._modules[name]
        # Descend unconditionally, also into a child that was just removed.
        deleteLayer(child, layer_type)
|
27 |
+
|
28 |
+
|
29 |
+
# --------------------------------------------
|
30 |
+
# merge bn, "conv+bn" --> "conv"
|
31 |
+
# --------------------------------------------
|
32 |
+
def merge_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    merge all 'Conv+BN' (or 'TConv+BN', 'Linear+BN') into 'Conv' (or 'TConv', 'Linear')
    based on https://github.com/pytorch/pytorch/pull/901

    Works in place: the BN running statistics (and affine parameters, if any)
    are folded into the weight and bias of the directly preceding layer, and
    the BN module is removed from `model`. Applied recursively to all children.
    '''
    prev_m = None
    for k, m in list(model.named_children()):
        is_bn = isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d))
        if is_bn and isinstance(prev_m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):

            w = prev_m.weight.data

            if prev_m.bias is None:
                # num_features of the BN equals the number of output channels /
                # features of the previous layer (nn.Linear has no out_channels).
                zeros = torch.zeros(m.num_features, dtype=w.dtype, device=w.device)
                prev_m.bias = nn.Parameter(zeros)
            b = prev_m.bias.data

            # Output axis of the weight: 1 for transposed conv, 0 otherwise.
            # Building the broadcast shape from w.dim() covers both the 4-D
            # conv weights and the 2-D linear weight.
            out_axis = 1 if isinstance(prev_m, nn.ConvTranspose2d) else 0
            scale_shape = [1] * w.dim()
            scale_shape[out_axis] = -1

            invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
            w.mul_(invstd.view(scale_shape))
            b.add_(-m.running_mean).mul_(invstd)
            if m.affine:
                w.mul_(m.weight.data.view(scale_shape))
                b.mul_(m.weight.data).add_(m.bias.data)

            del model._modules[k]
        prev_m = m
        merge_bn(m)
|
64 |
+
|
65 |
+
|
66 |
+
# --------------------------------------------
|
67 |
+
# add bn, "conv" --> "conv+bn"
|
68 |
+
# --------------------------------------------
|
69 |
+
def add_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    Replace, in place, every Conv2d / ConvTranspose2d / Linear child with
    Sequential(child, BatchNorm). Convolutions get a BatchNorm2d sized by
    out_channels; Linear layers get a BatchNorm1d sized by out_features.
    Applied recursively to all children.
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
        elif isinstance(m, nn.Linear):
            # nn.Linear exposes out_features (not out_channels) and produces
            # 2-D activations, which BatchNorm2d cannot normalise.
            b = nn.BatchNorm1d(m.out_features, momentum=0.1, affine=True)
        else:
            b = None
        if b is not None:
            b.weight.data.fill_(1)
            model._modules[k] = nn.Sequential(m, b)
        add_bn(m)
|
79 |
+
|
80 |
+
|
81 |
+
# --------------------------------------------
|
82 |
+
# tidy model after removing bn
|
83 |
+
# --------------------------------------------
|
84 |
+
def tidy_sequential(model):
    ''' Kai Zhang, 11/Jan/2019.
    Unwrap, in place, every nn.Sequential child that holds exactly one module,
    replacing the wrapper with the module it contains. Applied recursively.
    '''
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Sequential) and len(child) == 1:
            model._modules[name] = child[0]
        # Recurse into the original child, whether or not it was unwrapped.
        tidy_sequential(child)
|
core/data/deg_kair_utils/utils_deblur.py
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import scipy
|
4 |
+
from scipy import fftpack
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from math import cos, sin
|
8 |
+
from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
|
9 |
+
from numpy.random import randn, rand
|
10 |
+
from scipy.signal import convolve2d
|
11 |
+
import cv2
|
12 |
+
import random
|
13 |
+
# import utils_image as util
|
14 |
+
|
15 |
+
'''
|
16 |
+
modified by Kai Zhang (github: https://github.com/cszn)
|
17 |
+
03/03/2019
|
18 |
+
'''
|
19 |
+
|
20 |
+
|
21 |
+
def get_uperleft_denominator(img, kernel):
    '''
    Build the two frequency-domain terms derived from a blur kernel.

    img: HxWxC
    kernel: hxw
    denominator: HxWx1, |OTF|^2
    upperleft: HxWxC, conj(OTF) * FFT2(img)
    '''
    otf = psf2otf(kernel, img.shape[:2])
    img_spectrum = np.fft.fft2(img, axes=[0, 1])
    upperleft = np.expand_dims(np.conj(otf), axis=2) * img_spectrum
    denominator = np.expand_dims(np.abs(otf)**2, axis=2)
    return upperleft, denominator
|
32 |
+
|
33 |
+
|
34 |
+
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    PyTorch counterpart of get_uperleft_denominator; complex values are stored
    as a trailing dimension of size 2 (real, imag).

    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW
    upperleft: NxCxHxWx2
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    otf_re, otf_im = otf[..., 0], otf[..., 1]
    denominator = otf_re**2 + otf_im**2  # Nx1xHxW
    upperleft = cmul(cconj(otf), rfft(img))  # Nx1xHxWx2 * NxCxHxWx2
    return upperleft, denominator
|
45 |
+
|
46 |
+
|
47 |
+
def c2c(x):
    '''Convert a complex numpy array to a float32 tensor with a trailing (real, imag) axis.'''
    parts = [np.float32(x.real), np.float32(x.imag)]
    return torch.from_numpy(np.stack(parts, axis=-1))
|
49 |
+
|
50 |
+
|
51 |
+
def r2c(x):
    '''Append a trailing (real, imag) axis to a real tensor, with a zero imaginary part.'''
    imag = torch.zeros_like(x)
    return torch.stack((x, imag), dim=-1)
|
53 |
+
|
54 |
+
|
55 |
+
def cdiv(x, y):
    '''Complex division x / y for tensors whose last axis holds (real, imag).'''
    x_re, x_im = x[..., 0], x[..., 1]
    y_re, y_im = y[..., 0], y[..., 1]
    norm_sq = y_re**2 + y_im**2
    real = (x_re*y_re + x_im*y_im) / norm_sq
    imag = (x_im*y_re - x_re*y_im) / norm_sq
    return torch.stack([real, imag], -1)
|
60 |
+
|
61 |
+
|
62 |
+
def cabs(x):
    '''Magnitude of complex values stored as a trailing (real, imag) axis.'''
    squared = x[..., 0]**2 + x[..., 1]**2
    return torch.pow(squared, 0.5)
|
64 |
+
|
65 |
+
|
66 |
+
def cmul(t1, t2):
    '''
    complex multiplication of tensors whose last axis holds (real, imag)
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    a_re, a_im = t1[..., 0], t1[..., 1]
    b_re, b_im = t2[..., 0], t2[..., 1]
    real = a_re * b_re - a_im * b_im
    imag = a_re * b_im + a_im * b_re
    return torch.stack([real, imag], dim=-1)
|
75 |
+
|
76 |
+
|
77 |
+
def cconj(t, inplace=False):
    '''
    # complex's conjugation: negate the imaginary component
    t: NxCxHxWx2
    inplace: when True, modify and return `t` itself instead of a clone
    output: NxCxHxWx2
    '''
    c = t if inplace else t.clone()
    c[..., 1] *= -1
    return c
|
86 |
+
|
87 |
+
|
88 |
+
def rfft(t):
    '''
    Full (two-sided) 2-D FFT of a real tensor over its last two dimensions.
    The result has a trailing axis of size 2 holding (real, imag).
    '''
    if hasattr(torch, 'rfft'):
        # Legacy API (removed in PyTorch 1.8).
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t, dim=(-2, -1)))
|
90 |
+
|
91 |
+
|
92 |
+
def irfft(t):
    '''
    Inverse of rfft: takes a full 2-D spectrum stored with a trailing
    (real, imag) axis and returns the real signal over the last two dimensions.
    '''
    if hasattr(torch, 'irfft'):
        # Legacy API (removed in PyTorch 1.8).
        return torch.irfft(t, 2, onesided=False)
    spectrum = torch.view_as_complex(t.contiguous())
    height, width = spectrum.shape[-2:]
    # A complex-to-real transform only needs the one-sided half of the last axis.
    half = spectrum[..., :width // 2 + 1]
    return torch.fft.irfft2(half, s=(height, width), dim=(-2, -1))
|
94 |
+
|
95 |
+
|
96 |
+
def fft(t):
    '''
    2-D complex-to-complex FFT; input and output store complex values as a
    trailing axis of size 2 (real, imag).
    '''
    if callable(getattr(torch, 'fft', None)):
        # Legacy API: torch.fft was a function before it became a module.
        return torch.fft(t, 2)
    spectrum = torch.fft.fft2(torch.view_as_complex(t.contiguous()), dim=(-2, -1))
    return torch.view_as_real(spectrum)
|
98 |
+
|
99 |
+
|
100 |
+
def ifft(t):
    '''
    2-D complex-to-complex inverse FFT; input and output store complex values
    as a trailing axis of size 2 (real, imag).
    '''
    if hasattr(torch, 'ifft'):
        # Legacy API (removed in PyTorch 1.8).
        return torch.ifft(t, 2)
    signal = torch.fft.ifft2(torch.view_as_complex(t.contiguous()), dim=(-2, -1))
    return torch.view_as_real(signal)
|
102 |
+
|
103 |
+
|
104 |
+
def p2o(psf, shape):
    '''
    Convert point-spread functions to optical transfer functions.

    # psf: NxCxhxw
    # shape: [H,W] (any sequence of two ints, e.g. list, tuple or torch.Size)
    # otf: NxCxHxWx2, trailing axis holds (real, imag)
    '''
    # tuple() on both sides so the concatenation also works for a list `shape`.
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
    # Circularly shift so the kernel centre lands on index (0, 0).
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    if hasattr(torch, 'rfft'):
        # Legacy API (removed in PyTorch 1.8).
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf, dim=(-2, -1)))
    # Zero imaginary parts that are within the estimated round-off error.
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
|
122 |
+
def otf2psf(otf, outsize=None):
    """
    Convert an optical transfer function back to a point-spread function.

    otf: numpy array; the inverse FFT is taken over axes 0 and 1
    outsize: optional output size; when given, the PSF is centre-cropped to it

    Returns the PSF, with the imaginary part dropped when np.real_if_close
    considers it negligible.
    """
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    # Circularly shift by floor(size/2) along every axis.
    for axis, axis_size in enumerate(insize):
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if type(outsize) != type(None):
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        # Turn both sizes into column vectors and zero-pad them to length n.
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")

        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            # Only reported, not raised: execution continues with the crop below.
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        # dims_start is an (n, 1) column vector, so len(dims_start.shape) is 2:
        # the loop crops axes 0 and 1 only.
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    # Rough operation count, used as the tolerance for discarding the imaginary part.
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf
|
150 |
+
|
151 |
+
|
152 |
+
# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
|
153 |
+
def psf2otf(psf, shape=None):
    """
    Convert a point-spread function (PSF) to an optical transfer function (OTF).

    The PSF is zero-padded at the bottom/right up to `shape`, circularly
    shifted so that its central element lands on index (0, 0), and then
    Fourier transformed over axes 0 and 1. The shift keeps the OTF from being
    affected by the PSF being off-centre.

    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array (a 1-D array is treated as a single row)
    shape : sequence of int, optional
        Output shape of the OTF array; defaults to the PSF shape

    Returns
    -------
    otf : `numpy.ndarray`
        OTF array; real-valued when the imaginary part is within round-off.
        An all-zero PSF yields an all-zero array of the requested shape.

    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    target_shape = np.array(psf.shape if shape is None else shape)
    if np.all(psf == 0):
        return np.zeros(target_shape)
    if psf.ndim == 1:
        psf = psf.reshape((1, psf.shape[0]))
    original_shape = psf.shape
    shifted = zero_pad(psf, target_shape, position='corner')
    for axis, axis_size in enumerate(original_shape):
        shifted = np.roll(shifted, -int(axis_size / 2), axis=axis)
    otf = np.fft.fft2(shifted, axes=(0, 1))
    # Rough operation count of the FFT, used as the tolerance (in machine
    # epsilons) for dropping a negligible imaginary part.
    n_ops = np.sum(shifted.size * np.log2(shifted.shape))
    return np.real_if_close(otf, tol=n_ops)
|
200 |
+
|
201 |
+
|
202 |
+
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros

    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered

    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image (the input itself when no padding is needed)

    Raises
    ------
    ValueError
        If `shape` has a non-positive entry, is smaller than the image, or
        (for 'center') differs from the image shape by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    # np.alltrue was removed in NumPy 2.0; array_equal is the direct equivalent here.
    if np.array_equal(imshape, shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    rows, cols = imshape
    pad_img[offx:offx + rows, offy:offy + cols] = image
    return pad_img
|
242 |
+
|
243 |
+
|
244 |
+
'''
|
245 |
+
Reducing boundary artifacts
|
246 |
+
'''
|
247 |
+
|
248 |
+
|
249 |
+
def opt_fft_size(n):
    '''
    Kai Zhang (github: https://github.com/cszn)
    03/03/2019
    # opt_fft_size.m
    # compute an optimal data length for Fourier transforms
    # written by Sunghyun Cho ([email protected])

    n: sequence of integer lengths
    Returns a float array: for every entry the smallest length >= that entry
    of the form 2^a * 3^b * 5^c * 7^d, optionally times 11 or 13; -1 for
    entries larger than the 2048-entry lookup table.
    '''
    LUT_size = 2048
    lut = np.zeros(LUT_size)

    # Mark every admissible length (lut[size-1] = size).
    p2 = 1
    while p2 <= LUT_size:
        p3 = p2
        while p3 <= LUT_size:
            p5 = p3
            while p5 <= LUT_size:
                p7 = p5
                while p7 <= LUT_size:
                    for factor in (1, 11, 13):
                        size = p7 * factor
                        if size <= LUT_size:
                            lut[size - 1] = size
                    p7 *= 7
                p5 *= 5
            p3 *= 3
        p2 *= 2

    # Sweep from the top and fill the gaps with the next admissible length.
    next_idx = 0
    for idx in range(LUT_size - 1, -1, -1):
        if lut[idx] != 0:
            next_idx = idx
        else:
            lut[idx] = next_idx + 1

    m = np.zeros(len(n))
    for c, length in enumerate(n):
        m[c] = lut[length - 1] if length <= LUT_size else -1
    return m
|
297 |
+
|
298 |
+
|
299 |
+
def wrap_boundary_liu(img, img_size):

    """
    Reducing boundary artifacts in image deconvolution
    Renting Liu, Jiaya Jia
    ICIP 2008

    img: HxW or HxWxC array
    img_size: target (height, width) passed on to wrap_boundary

    Returns the wrapped image; for a 3-D input every channel is wrapped
    independently and the results are stacked along axis 2.

    Raises ValueError when img is neither 2-D nor 3-D.
    """
    if img.ndim == 2:
        return wrap_boundary(img, img_size)
    if img.ndim == 3:
        # Iterate over the actual channel count instead of a hard-coded 3.
        channels = [wrap_boundary(img[:, :, i], img_size) for i in range(img.shape[2])]
        return np.stack(channels, 2)
    raise ValueError('wrap_boundary_liu expects a 2-D or 3-D image, got ndim=%d' % img.ndim)
|
312 |
+
|
313 |
+
|
314 |
+
def wrap_boundary(img, img_size):

    """
    python code from:
    https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
    Reducing boundary artifacts in image deconvolution
    Renting Liu, Jiaya Jia
    ICIP 2008

    img: 2-D array (H x W)
    img_size: target (height, width); the image is extended by
              H_w = img_size[0] - H rows and W_w = img_size[1] - W columns

    Returns an array laid out as [[img, B], [A, C]], where A (below), B (right)
    and C (corner) are filler regions computed with solve_min_laplacian.
    """
    (H, W) = np.shape(img)
    H_w = int(img_size[0]) - H
    W_w = int(img_size[1]) - W

    # ret = np.zeros((img_size[0], img_size[1]));
    alpha = 1
    HG = img[:, :]

    # Region A (below the image): first/last rows are copied from the image's
    # last/first rows; the left and right columns are linearly interpolated.
    r_A = np.zeros((alpha*2+H_w, W))
    r_A[:alpha, :] = HG[-alpha:, :]
    r_A[-alpha:, :] = HG[:alpha, :]
    a = np.arange(H_w)/(H_w-1)
    # r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
    r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
    # r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
    r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]

    # Region B (right of the image): same construction, transposed.
    r_B = np.zeros((H, alpha*2+W_w))
    r_B[:, :alpha] = HG[:, -alpha:]
    r_B[:, -alpha:] = HG[:, :alpha]
    a = np.arange(W_w)/(W_w-1)
    r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
    r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]

    # Fill the interiors of A and B from their boundaries.
    if alpha == 1:
        A2 = solve_min_laplacian(r_A[alpha-1:, :])
        B2 = solve_min_laplacian(r_B[:, alpha-1:])
        r_A[alpha-1:, :] = A2
        r_B[:, alpha-1:] = B2
    else:
        A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
        r_A[alpha-1:-alpha+1, :] = A2
        B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
        r_B[:, alpha-1:-alpha+1] = B2
    A = r_A
    B = r_B

    # Region C (corner): its boundary is taken from the edges of A and B.
    r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
    r_C[:alpha, :] = B[-alpha:, :]
    r_C[-alpha:, :] = B[:alpha, :]
    r_C[:, :alpha] = A[:, -alpha:]
    r_C[:, -alpha:] = A[:, :alpha]

    if alpha == 1:
        C2 = C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
        r_C[alpha-1:, alpha-1:] = C2
    else:
        C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
        r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
    C = r_C
    # return C
    # Trim the regions to H_w rows / W_w columns before assembling the result.
    A = A[alpha-1:-alpha-1, :]
    B = B[:, alpha:-alpha]
    C = C[alpha:-alpha, alpha:-alpha]
    ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
    return ret
|
379 |
+
|
380 |
+
|
381 |
+
def solve_min_laplacian(boundary_image):
    """
    Fill the interior of a 2-D array from its border values using type-1
    discrete sine transforms.

    boundary_image: 2-D array; only its outermost rows/columns are used.
        NOTE: the array is modified in place (interior zeroed, then
        overwritten with the solution) and the same object is returned.
    """
    (H, W) = np.shape(boundary_image)

    # Laplacian
    f = np.zeros((H, W))
    # boundary image contains image intensities at boundaries
    boundary_image[1:-1, 1:-1] = 0
    # Interior row/column indices 1..H-2 and 1..W-2.
    j = np.arange(2, H)-1
    k = np.arange(2, W)-1
    f_bp = np.zeros((H, W))
    # 5-point Laplacian stencil applied to the (interior-zeroed) boundary image.
    f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]

    del(j, k)
    f1 = f - f_bp  # subtract boundary points contribution
    del(f_bp, f)

    # DST Sine Transform algo starts here
    f2 = f1[1:-1,1:-1]
    del(f1)

    # compute sine transform (the single-column / single-row cases pass an
    # explicit axis)
    if f2.shape[1] == 1:
        tt = fftpack.dst(f2, type=1, axis=0)/2
    else:
        tt = fftpack.dst(f2, type=1)/2

    if tt.shape[0] == 1:
        f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
    else:
        f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2)
    del(f2)

    # compute Eigen Values
    [x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
    denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)

    # divide
    f3 = f2sin/denom
    del(f2sin, x, y)

    # compute Inverse Sine Transform
    if f3.shape[0] == 1:
        tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
    else:
        tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
    del(f3)
    if tt.shape[1] == 1:
        img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
    else:
        img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
    del(tt)

    # put solution in inner points; outer points obtained from boundary image
    img_direct = boundary_image
    img_direct[1:-1, 1:-1] = 0
    img_direct[1:-1, 1:-1] = img_tt
    return img_direct
|
438 |
+
|
439 |
+
|
440 |
+
"""
|
441 |
+
Created on Thu Jan 18 15:36:32 2018
|
442 |
+
@author: italo
|
443 |
+
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
444 |
+
"""
|
445 |
+
|
446 |
+
"""
|
447 |
+
Syntax
|
448 |
+
h = fspecial(type)
|
449 |
+
h = fspecial('average',hsize)
|
450 |
+
h = fspecial('disk',radius)
|
451 |
+
h = fspecial('gaussian',hsize,sigma)
|
452 |
+
h = fspecial('laplacian',alpha)
|
453 |
+
h = fspecial('log',hsize,sigma)
|
454 |
+
h = fspecial('motion',len,theta)
|
455 |
+
h = fspecial('prewitt')
|
456 |
+
h = fspecial('sobel')
|
457 |
+
"""
|
458 |
+
|
459 |
+
|
460 |
+
def fspecial_average(hsize=3):
    """Smoothing filter: an hsize x hsize kernel with equal weights 1/hsize**2."""
    kernel = np.ones((hsize, hsize))
    return kernel/hsize**2
|
463 |
+
|
464 |
+
|
465 |
+
def fspecial_disk(radius):
    """Disk filter (not implemented).

    Always raises NotImplementedError. The previous body raised the
    `NotImplemented` constant, which is not an exception and surfaced as a
    TypeError; the unfinished draft after the raise was unreachable and has
    been dropped.
    """
    raise NotImplementedError("fspecial_disk is not implemented")
|
485 |
+
|
486 |
+
|
487 |
+
def fspecial_gaussian(hsize, sigma):
    """Return a normalized square 2-D Gaussian kernel.

    Args:
        hsize: Side length of the (square) kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        An ``hsize`` x ``hsize`` array. Entries sum to 1 unless every entry
        was zeroed by the threshold below.
    """
    half = (hsize - 1.0) / 2.0
    coords = np.arange(-half, half + 1)
    x, y = np.meshgrid(coords, coords)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # Zero out entries that are negligible relative to the peak.
    # np.finfo replaces the deprecated scipy.finfo alias used previously.
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
|
499 |
+
|
500 |
+
|
501 |
+
def fspecial_laplacian(alpha):
    """Return the 3x3 Laplacian filter for shape parameter ``alpha``.

    ``alpha`` is clamped to the range [0, 1] before the kernel is built.
    """
    alpha = max(0, min(alpha, 1))
    denom = alpha + 1
    corner = alpha / denom
    edge = (1 - alpha) / denom
    center = -4 / denom
    rows = [
        [corner, edge, corner],
        [edge, center, edge],
        [corner, edge, corner],
    ]
    return np.array(rows)
|
508 |
+
|
509 |
+
|
510 |
+
def fspecial_log(hsize, sigma):
    """Laplacian-of-Gaussian filter -- not implemented.

    Previously raised the ``NotImplemented`` constant, which is not an
    exception and produced a ``TypeError`` instead.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("fspecial_log is not implemented")
|
512 |
+
|
513 |
+
|
514 |
+
def fspecial_motion(motion_len, theta):
    """Linear motion filter -- not implemented.

    Previously raised the ``NotImplemented`` constant, which is not an
    exception and produced a ``TypeError`` instead.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("fspecial_motion is not implemented")
|
516 |
+
|
517 |
+
|
518 |
+
def fspecial_prewitt():
    """Return the 3x3 Prewitt kernel (rows: +1, 0, -1; uniform columns)."""
    return np.outer([1, 0, -1], [1, 1, 1])
|
520 |
+
|
521 |
+
|
522 |
+
def fspecial_sobel():
    """Return the 3x3 Sobel kernel (rows: +1, 0, -1; columns weighted 1, 2, 1)."""
    return np.outer([1, 0, -1], [1, 2, 1])
|
524 |
+
|
525 |
+
|
526 |
+
def fspecial(filter_type, *args, **kwargs):
    """Build a filter kernel by name, mirroring MATLAB's ``fspecial``.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Extra positional and keyword arguments are forwarded to the matching
    ``fspecial_<type>`` builder. An unrecognized ``filter_type`` returns None.
    """
    if filter_type == 'average':
        builder = fspecial_average
    elif filter_type == 'disk':
        builder = fspecial_disk
    elif filter_type == 'gaussian':
        builder = fspecial_gaussian
    elif filter_type == 'laplacian':
        builder = fspecial_laplacian
    elif filter_type == 'log':
        builder = fspecial_log
    elif filter_type == 'motion':
        builder = fspecial_motion
    elif filter_type == 'prewitt':
        builder = fspecial_prewitt
    elif filter_type == 'sobel':
        builder = fspecial_sobel
    else:
        return None
    return builder(*args, **kwargs)
|
547 |
+
|
548 |
+
|
549 |
+
def fspecial_gauss(size, sigma):
    """Return a ``size`` x ``size`` Gaussian kernel normalized to sum to 1.

    Uses the module-level ``mgrid`` and ``exp`` names (presumably imported
    from NumPy at the top of the file, which is outside this view).
    """
    lo = -size // 2 + 1
    hi = size // 2 + 1
    rows, cols = mgrid[lo:hi, lo:hi]
    weights = exp(-((rows ** 2 + cols ** 2) / (2.0 * sigma ** 2)))
    return weights / weights.sum()
|
553 |
+
|
554 |
+
|
555 |
+
def blurkernel_synthesis(h=37, w=None):
    """Synthesize a random motion-blur kernel from a random camera trajectory.

    Source:
    https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py

    NOTE(review): ``sum`` and ``pad`` are assumed to be the NumPy functions
    imported at the top of the file (imports are outside this view); the
    builtin ``sum`` would not yield a scalar for a 2-D kernel.

    Args:
        h: Target kernel height.
        w: Target kernel width; defaults to ``h``.

    Returns:
        A 2-D kernel normalized to sum to 1.
    """
    w = h if w is None else w
    kdims = [h, w]
    x = randomTrajectory(250)
    k = None
    while k is None:
        # kernelFromTrajectory returns None for an empty rasterization and
        # draws a fresh random size on each call, so retrying can succeed.
        k = kernelFromTrajectory(x)

    # center pad to kdims
    pad_rows = (kdims[0] - k.shape[0]) // 2
    pad_cols = (kdims[1] - k.shape[1]) // 2

    if pad_rows < 0 or pad_cols < 0:
        # Kernel is larger than requested: crop rows to h and columns to w
        # (the column crop previously used h, which ignored a distinct w).
        k = k[0:h, 0:w]
    else:
        k = pad(k, [(pad_rows,), (pad_cols,)], "constant")
    x1, x2 = k.shape
    if np.random.randint(0, 4) == 1:
        # Occasionally stretch the kernel, then crop the center back to size.
        k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
        y1, y2 = k.shape
        k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]

    if sum(k) < 0.1:
        # Degenerate kernel: fall back to a random-width Gaussian.
        k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
    k = k / sum(k)
    return k
|
585 |
+
|
586 |
+
|
587 |
+
def kernelFromTrajectory(x):
    """Rasterize a trajectory into a small, smoothed, normalized blur kernel.

    A random odd kernel size (capped at 27) is drawn, rows 0 and 1 of ``x``
    are binned onto that grid, and the histogram is smoothed with a 3x3
    Gaussian. Returns None when no trajectory sample falls inside any bin.

    NOTE(review): ``log``, ``rand``, ``round``, ``min``, ``sum``, ``zeros``
    and ``arange`` are assumed to be the NumPy versions imported at the top
    of the file (outside this view) -- ``.astype`` below requires it.
    """
    side = 5 - log(rand()) / 0.15
    side = round(min([side, 27])).astype(int)
    side = side + 1 - side % 2  # bump even sizes up to the next odd size
    h = side
    w = side
    k = zeros((h, w))

    xs = x[0, :]
    ys = x[1, :]
    xmin = min(x[0])
    xmax = max(x[0])
    ymin = min(x[1])
    ymax = max(x[1])
    xthr = arange(xmin, xmax, (xmax - xmin) / w)
    ythr = arange(ymin, ymax, (ymax - ymin) / h)

    # Count trajectory samples falling in each [lower, upper) grid cell.
    for i in range(1, xthr.size):
        in_x = (xs >= xthr[i - 1]) & (xs < xthr[i])
        for j in range(1, ythr.size):
            in_y = (ys >= ythr[j - 1]) & (ys < ythr[j])
            k[i - 1, j - 1] = sum(in_x & in_y)

    if sum(k) == 0:
        return None
    k = k / sum(k)
    k = convolve2d(k, fspecial_gauss(3, 1), "same")
    k = k / sum(k)
    return k
|
616 |
+
|
617 |
+
|
618 |
+
def randomTrajectory(T):
    """Generate a random 3-D trajectory of ``T`` steps as a (3, T) array.

    At each step the rotation state and velocity receive random impulses
    that shrink as 1/(t+1); the velocity is rotated by the current angles
    and accumulated into the position.
    """
    positions = zeros((3, T))
    velocities = randn(3, T)
    angles = zeros((3, T))
    trv = 1 / 1
    trr = 2 * pi / T
    for t in range(1, T):
        prev = t - 1
        # Draw order (rotation first, then translation) is kept so the
        # random stream is consumed exactly as before.
        F_rot = randn(3) / (t + 1) + angles[:, prev]
        F_trans = randn(3) / (t + 1)
        angles[:, t] = angles[:, prev] + trr * F_rot
        velocities[:, t] = velocities[:, prev] + trv * F_trans
        step = rot3D(velocities[:, t], angles[:, t])
        positions[:, t] = positions[:, prev] + step
    return positions
|
633 |
+
|
634 |
+
|
635 |
+
def rot3D(x, r):
    """Rotate the 3-vector ``x`` by the angles ``r = (rx, ry, rz)``.

    The rotation about X is applied first, then Y, then Z.
    """
    cx, sx = cos(r[0]), sin(r[0])
    cy, sy = cos(r[1]), sin(r[1])
    cz, sz = cos(r[2]), sin(r[2])
    Rx = array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    rotation = Rz @ Ry @ Rx
    return rotation @ x
|
642 |
+
|
643 |
+
|
644 |
+
if __name__ == '__main__':
    # Ad-hoc smoke checks, run only when this module is executed as a script.
    a = opt_fft_size([111])
    print(a)

    gaussian_kernel = fspecial('gaussian', 5, 1)
    print(gaussian_kernel)

    otf = p2o(torch.zeros(1, 1, 4, 4).float(), (14, 14))
    print(otf.shape)

    # Visualize one synthesized blur kernel.
    k = blurkernel_synthesis(11)
    import matplotlib.pyplot as plt
    plt.imshow(k, interpolation="nearest", cmap="gray")
    plt.show()
|
core/data/deg_kair_utils/utils_dist.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import torch.multiprocessing as mp
|
8 |
+
|
9 |
+
|
10 |
+
# ----------------------------------
|
11 |
+
# init
|
12 |
+
# ----------------------------------
|
13 |
+
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher.

    Sets the multiprocessing start method to ``'spawn'`` when none has been
    chosen yet, then dispatches to the launcher-specific initializer.

    Raises:
        ValueError: If ``launcher`` is neither ``'pytorch'`` nor ``'slurm'``.
    """
    start_method = mp.get_start_method(allow_none=True)
    if start_method is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
|
22 |
+
|
23 |
+
|
24 |
+
def _init_dist_pytorch(backend, **kwargs):
|
25 |
+
rank = int(os.environ['RANK'])
|
26 |
+
num_gpus = torch.cuda.device_count()
|
27 |
+
torch.cuda.set_device(rank % num_gpus)
|
28 |
+
dist.init_process_group(backend=backend, **kwargs)
|
29 |
+
|
30 |
+
|
31 |
+
def _init_dist_slurm(backend, port=None):
|
32 |
+
"""Initialize slurm distributed training environment.
|
33 |
+
If argument ``port`` is not specified, then the master port will be system
|
34 |
+
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
35 |
+
environment variable, then a default port ``29500`` will be used.
|
36 |
+
Args:
|
37 |
+
backend (str): Backend of torch.distributed.
|
38 |
+
port (int, optional): Master port. Defaults to None.
|
39 |
+
"""
|
40 |
+
proc_id = int(os.environ['SLURM_PROCID'])
|
41 |
+
ntasks = int(os.environ['SLURM_NTASKS'])
|
42 |
+
node_list = os.environ['SLURM_NODELIST']
|
43 |
+
num_gpus = torch.cuda.device_count()
|
44 |
+
torch.cuda.set_device(proc_id % num_gpus)
|
45 |
+
addr = subprocess.getoutput(
|
46 |
+
f'scontrol show hostname {node_list} | head -n1')
|
47 |
+
# specify master port
|
48 |
+
if port is not None:
|
49 |
+
os.environ['MASTER_PORT'] = str(port)
|
50 |
+
elif 'MASTER_PORT' in os.environ:
|
51 |
+
pass # use MASTER_PORT in the environment variable
|
52 |
+
else:
|
53 |
+
# 29500 is torch.distributed default port
|
54 |
+
os.environ['MASTER_PORT'] = '29500'
|
55 |
+
os.environ['MASTER_ADDR'] = addr
|
56 |
+
os.environ['WORLD_SIZE'] = str(ntasks)
|
57 |
+
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
58 |
+
os.environ['RANK'] = str(proc_id)
|
59 |
+
dist.init_process_group(backend=backend)
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
# ----------------------------------
|
64 |
+
# get rank and world_size
|
65 |
+
# ----------------------------------
|
66 |
+
def get_dist_info():
    """Return ``(rank, world_size)``.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or the
    process group has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
|
78 |
+
|
79 |
+
|
80 |
+
def get_rank():
    """Return this process's rank, or 0 when not running distributed."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if ready else 0
|
88 |
+
|
89 |
+
|
90 |
+
def get_world_size():
    """Return the number of processes, or 1 when not running distributed."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
|
98 |
+
|
99 |
+
|
100 |
+
def master_only(func):
    """Decorator that runs ``func`` on rank 0 only.

    On every other rank the wrapped call does nothing and returns None.
    """

    @functools.wraps(func)
    def rank_zero_call(*args, **kwargs):
        rank = get_dist_info()[0]
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return rank_zero_call
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
# ----------------------------------
|
116 |
+
# operation across ranks
|
117 |
+
# ----------------------------------
|
118 |
+
def reduce_sum(tensor):
    """Return the element-wise sum of ``tensor`` across all ranks.

    When not running distributed the input tensor is returned as-is;
    otherwise a clone is all-reduced so the argument is left untouched.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed
|
129 |
+
|
130 |
+
|
131 |
+
def gather_grad(params):
    """Average gradients of ``params`` across ranks, in place.

    Does nothing when the world size is 1. Parameters without a gradient
    are skipped.
    """
    world_size = get_world_size()
    if world_size == 1:
        return

    for param in params:
        grad = param.grad
        if grad is None:
            continue
        dist.all_reduce(grad.data, op=dist.ReduceOp.SUM)
        grad.data.div_(world_size)
|
141 |
+
|
142 |
+
|
143 |
+
def all_gather(data):
    """Gather an arbitrary picklable object from every rank.

    Each rank pickles ``data`` into a CUDA byte tensor, the tensors are
    padded to a common length and exchanged with ``dist.all_gather``, and
    every payload is unpickled.

    Args:
        data: Any picklable object.

    Returns:
        list: One entry per rank (``[data]`` when the world size is 1).
    """
    # ``pickle`` is not among the module-level imports, so it is imported
    # here; previously the multi-process path failed with a NameError.
    import pickle

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Keep the local length as a plain int for the comparison and padding
    # size; only the exchanged copy needs to be a tensor.
    local_numel = tensor.numel()
    local_size = torch.IntTensor([local_numel]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # dist.all_gather needs equally sized tensors, so pad up to max_size.
    tensor_list = [torch.ByteTensor(size=(max_size,)).to('cuda') for _ in size_list]
    if local_numel != max_size:
        padding = torch.ByteTensor(size=(max_size - local_numel,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        payload = gathered.cpu().numpy().tobytes()[:size]
        # NOTE(review): pickle.loads can execute arbitrary code; this is
        # only safe because the payloads come from peer training ranks.
        data_list.append(pickle.loads(payload))

    return data_list
|
176 |
+
|
177 |
+
|
178 |
+
def reduce_loss_dict(loss_dict):
    """Reduce a dict of loss tensors onto rank 0.

    With fewer than two processes the input dict is returned unchanged.
    Otherwise the losses are stacked in sorted-key order and reduced to
    rank 0, where they are divided by the world size; other ranks do not
    apply that division.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        # Sorted keys give every rank the same stacking order.
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[key] for key in keys], 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = dict(zip(keys, losses))

    return reduced_losses
|
201 |
+
|
core/data/deg_kair_utils/utils_googledownload.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
'''
|
7 |
+
borrowed from
|
8 |
+
https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
|
9 |
+
'''
|
10 |
+
|
11 |
+
|
12 |
+
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
    index = 0
    # Divide by 1024 until the value drops below 1024 or prefixes run out.
    while index < len(prefixes) and not abs(size) < 1024.0:
        size /= 1024.0
        index += 1
    unit = prefixes[index] if index < len(prefixes) else 'Y'
    return f'{size:3.1f} {unit}{suffix}'
|
25 |
+
|
26 |
+
|
27 |
+
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    session = requests.Session()
    url = 'https://docs.google.com/uc?export=download'
    query = {'id': file_id}

    response = session.get(url, params=query, stream=True)
    token = get_confirm_token(response)
    if token:
        # A confirm token was issued; repeat the request with it.
        query['confirm'] = token
        response = session.get(url, params=query, stream=True)

    # get file size
    size_probe = session.get(
        url, params=query, stream=True, headers={'Range': 'bytes=0-2'})
    probe_headers = size_probe.headers
    file_size = (
        int(probe_headers['Content-Range'].split('/')[1])
        if 'Content-Range' in probe_headers
        else None
    )

    save_response_content(response, save_path, file_size)
|
56 |
+
|
57 |
+
|
58 |
+
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, or None."""
    matching_values = (
        value
        for key, value in response.cookies.items()
        if key.startswith('download_warning')
    )
    return next(matching_values, None)
|
63 |
+
|
64 |
+
|
65 |
+
def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream a response body to ``destination``.

    Args:
        response: Object exposing ``iter_content(chunk_size)``.
        destination (str): Path of the file to write.
        file_size (int, optional): Total size in bytes; when given, a tqdm
            progress bar is shown. Defaults to None.
        chunk_size (int): Size of each streamed chunk. Defaults to 32768.
    """
    pbar = None
    readable_file_size = None
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                # Count the bytes actually received; the final chunk is
                # usually shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                         f'/ {readable_file_size}')
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    finally:
        # Close the progress bar even if the download or write fails.
        if pbar is not None:
            pbar.close()
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
    # Example run: fetch the BSRGAN checkpoint into the working directory.
    file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
    save_path = 'BSRGAN.pth'
    download_file_from_google_drive(file_id=file_id, save_path=save_path)
|
core/data/deg_kair_utils/utils_image.py
ADDED
@@ -0,0 +1,1016 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
from torchvision.utils import make_grid
|
8 |
+
from datetime import datetime
|
9 |
+
# import torchvision.transforms as transforms
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from mpl_toolkits.mplot3d import Axes3D
|
12 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
13 |
+
|
14 |
+
|
15 |
+
'''
|
16 |
+
# --------------------------------------------
|
17 |
+
# Kai Zhang (github: https://github.com/cszn)
|
18 |
+
# 03/Mar/2019
|
19 |
+
# --------------------------------------------
|
20 |
+
# https://github.com/twhui/SRGAN-pyTorch
|
21 |
+
# https://github.com/xinntao/BasicSR
|
22 |
+
# --------------------------------------------
|
23 |
+
'''
|
24 |
+
|
25 |
+
|
26 |
+
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The match is case-sensitive: only the spellings listed above count.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
31 |
+
|
32 |
+
|
33 |
+
def get_timestamp():
    """Return the current local time formatted as ``yymmdd-HHMMSS``."""
    return f'{datetime.now():%y%m%d-%H%M%S}'
|
35 |
+
|
36 |
+
|
37 |
+
def imshow(x, title=None, cbar=False, figsize=None):
    """Show ``x`` as a grayscale image in a new matplotlib figure.

    Singleton dimensions are squeezed before display. A title and colorbar
    are added when requested.
    """
    plt.figure(figsize=figsize)
    image = np.squeeze(x)
    plt.imshow(image, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
|
45 |
+
|
46 |
+
|
47 |
+
def surf(Z, cmap='rainbow', figsize=None):
    """Plot the 2-D array ``Z`` as a 3-D surface.

    Args:
        Z: 2-D array of heights, indexed ``[row, col]``.
        cmap: Matplotlib colormap name.
        figsize: Optional figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    rows, cols = Z.shape[:2]
    # np.meshgrid (default 'xy' indexing) returns arrays of shape
    # (len(yy), len(xx)), so x must span the columns and y the rows for the
    # grids to match Z. The previous code swapped them, which only worked
    # for square inputs.
    xx = np.arange(0, cols, 1)
    yy = np.arange(0, rows, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    plt.show()
|
58 |
+
|
59 |
+
|
60 |
+
'''
|
61 |
+
# --------------------------------------------
|
62 |
+
# get image paths
|
63 |
+
# --------------------------------------------
|
64 |
+
'''
|
65 |
+
|
66 |
+
|
67 |
+
def get_image_paths(dataroot):
    """Collect sorted image paths under one directory or a list of them.

    Returns None when ``dataroot`` is neither a string nor a list. For a
    list, each directory's paths are sorted and appended in list order.
    """
    if isinstance(dataroot, str):
        return sorted(_get_paths_from_images(dataroot))
    if isinstance(dataroot, list):
        collected = []
        for root in dataroot:
            collected.extend(sorted(_get_paths_from_images(root)))
        return collected
    return None  # return None if dataroot is None
|
76 |
+
|
77 |
+
|
78 |
+
def _get_paths_from_images(path):
|
79 |
+
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
80 |
+
images = []
|
81 |
+
for dirpath, _, fnames in sorted(os.walk(path)):
|
82 |
+
for fname in sorted(fnames):
|
83 |
+
if is_image_file(fname):
|
84 |
+
img_path = os.path.join(dirpath, fname)
|
85 |
+
images.append(img_path)
|
86 |
+
assert images, '{:s} has no valid image file'.format(path)
|
87 |
+
return images
|
88 |
+
|
89 |
+
|
90 |
+
'''
|
91 |
+
# --------------------------------------------
|
92 |
+
# split large images into small images
|
93 |
+
# --------------------------------------------
|
94 |
+
'''
|
95 |
+
|
96 |
+
|
97 |
+
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Args:
        img: Array whose first two axes are the spatial dimensions. The
            split path indexes a third axis, so it needs a 3-D array.
        p_size: Side length of each patch.
        p_overlap: Overlap between neighbouring patches.
        p_max: Only images larger than ``p_max`` in both spatial dimensions
            are split; others are returned whole.

    Returns:
        list: The patches, or ``[img]`` when the image is not split.
    """
    w, h = img.shape[:2]
    if not (w > p_max and h > p_max):
        return [img]

    stride = p_size - p_overlap
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # equivalent dtype.
    w1 = list(np.arange(0, w - p_size, stride, dtype=int))
    h1 = list(np.arange(0, h - p_size, stride, dtype=int))
    # Always add a final patch flush with the far edge.
    w1.append(w - p_size)
    h1.append(h - p_size)
    return [img[i:i + p_size, j:j + p_size, :] for i in w1 for j in h1]
|
114 |
+
|
115 |
+
|
116 |
+
def imssave(imgs, img_path):
    """Save a list of images next to ``img_path`` as numbered PNG files.

    imgs: list, N images of size WxHxC

    Output names are ``<stem>_<index:04d>.png`` in the directory of
    ``img_path``. The first three channels of 3-D images are reversed
    before writing.
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    out_dir = os.path.dirname(img_path)
    for index, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(out_dir, '{}_{:04d}.png'.format(img_name, index))
        cv2.imwrite(new_path, img)
|
126 |
+
|
127 |
+
|
128 |
+
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
    """
    Split the large images from original_dataroot into small overlapping
    images of size (p_size)x(p_size) and save them into taget_dataroot; only
    images larger than (p_max)x(p_max) are split.

    Args:
        original_dataroot: directory (or list of directories) to read from
        taget_dataroot: directory the patches are written to
        n_channels: channel count passed to imread_uint
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    for img_path in get_image_paths(original_dataroot):
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        target = os.path.join(taget_dataroot, os.path.basename(img_path))
        imssave(patches, target)
|
147 |
+
#if original_dataroot == taget_dataroot:
|
148 |
+
#del img_path
|
149 |
+
|
150 |
+
'''
|
151 |
+
# --------------------------------------------
|
152 |
+
# makedir
|
153 |
+
# --------------------------------------------
|
154 |
+
'''
|
155 |
+
|
156 |
+
|
157 |
+
def mkdir(path):
    """Create ``path`` (with parents) unless something already exists there."""
    if os.path.exists(path):
        return
    os.makedirs(path)
|
160 |
+
|
161 |
+
|
162 |
+
def mkdirs(paths):
    """Create one directory (given a string) or each directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
|
168 |
+
|
169 |
+
|
170 |
+
def mkdir_and_rename(path):
    """Create a fresh directory at ``path``, archiving any existing one.

    An existing ``path`` is first renamed to ``<path>_archived_<timestamp>``.
    """
    already_there = os.path.exists(path)
    if already_there:
        archived_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(archived_name))
        os.rename(path, archived_name)
    os.makedirs(path)
|
176 |
+
|
177 |
+
|
178 |
+
'''
|
179 |
+
# --------------------------------------------
|
180 |
+
# read image from path
|
181 |
+
# opencv is fast, but read BGR numpy image
|
182 |
+
# --------------------------------------------
|
183 |
+
'''
|
184 |
+
|
185 |
+
|
186 |
+
# --------------------------------------------
|
187 |
+
# get uint8 image of size HxWxn_channels (RGB)
|
188 |
+
# --------------------------------------------
|
189 |
+
def imread_uint(path, n_channels=3):
    """Read an image with OpenCV in its stored bit depth.

    Args:
        path: Image file path.
        n_channels: 1 for grayscale (HxWx1) or 3 for RGB (HxWx3; grayscale
            files are replicated to three channels).

    Returns:
        HxWx3 (RGB or GGG) or HxWx1 (G) array.

    Raises:
        ValueError: If ``n_channels`` is neither 1 nor 3 (previously this
            surfaced as an UnboundLocalError on ``return img``).
        OSError: If OpenCV cannot read the file (``cv2.imread`` returned
            None, which previously crashed later with an unrelated error).
    """
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {}'.format(n_channels))

    flag = cv2.IMREAD_GRAYSCALE if n_channels == 1 else cv2.IMREAD_UNCHANGED
    img = cv2.imread(path, flag)
    if img is None:
        raise OSError('cv2.imread could not read image: {}'.format(path))

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
|
202 |
+
|
203 |
+
|
204 |
+
# --------------------------------------------
|
205 |
+
# matlab's imwrite
|
206 |
+
# --------------------------------------------
|
207 |
+
def imsave(img, img_path):
    """Write ``img`` to ``img_path`` with OpenCV.

    Singleton dimensions are squeezed first; for 3-D arrays the first three
    channels are reversed before writing.
    """
    squeezed = np.squeeze(img)
    is_color = squeezed.ndim == 3
    to_write = squeezed[:, :, [2, 1, 0]] if is_color else squeezed
    cv2.imwrite(img_path, to_write)
|
212 |
+
|
213 |
+
def imwrite(img, img_path):
    """Write ``img`` to ``img_path`` with OpenCV (same behaviour as ``imsave``).

    Singleton dimensions are squeezed first; for 3-D arrays the first three
    channels are reversed before writing.
    """
    squeezed = np.squeeze(img)
    is_color = squeezed.ndim == 3
    to_write = squeezed[:, :, [2, 1, 0]] if is_color else squeezed
    cv2.imwrite(img_path, to_write)
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
# --------------------------------------------
|
222 |
+
# get single image of size HxWxn_channels (BGR)
|
223 |
+
# --------------------------------------------
|
224 |
+
def read_img(path):
    """Read an image with cv2 as float32 scaled by 1/255.

    return: Numpy float32, HWC, BGR, [0,1]

    2-D images gain a trailing channel axis; images with more than three
    channels keep only the first three.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = raw.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    return img[:, :, :3] if img.shape[2] > 3 else img
|
235 |
+
|
236 |
+
|
237 |
+
'''
|
238 |
+
# --------------------------------------------
|
239 |
+
# image format conversion
|
240 |
+
# --------------------------------------------
|
241 |
+
# numpy(single) <---> numpy(uint)
|
242 |
+
# numpy(single) <---> tensor
|
243 |
+
# numpy(uint) <---> tensor
|
244 |
+
# --------------------------------------------
|
245 |
+
'''
|
246 |
+
|
247 |
+
|
248 |
+
# --------------------------------------------
|
249 |
+
# numpy(single) [0, 1] <---> numpy(uint)
|
250 |
+
# --------------------------------------------
|
251 |
+
|
252 |
+
|
253 |
+
def uint2single(img):
    """Scale an 8-bit image to float32 by dividing by 255."""
    scaled = img / 255.
    return np.float32(scaled)
|
256 |
+
|
257 |
+
|
258 |
+
def single2uint(img):
    """Clip to [0, 1], scale by 255, round and cast to uint8."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 255.).round()
    return np.uint8(rounded)
|
261 |
+
|
262 |
+
|
263 |
+
def uint162single(img):
    """Scale a 16-bit image to float32 by dividing by 65535."""
    scaled = img / 65535.
    return np.float32(scaled)
|
266 |
+
|
267 |
+
|
268 |
+
def single2uint16(img):
    """Clip to [0, 1], scale by 65535, round and cast to uint16."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 65535.).round()
    return np.uint16(rounded)
|
271 |
+
|
272 |
+
|
273 |
+
# --------------------------------------------
|
274 |
+
# numpy(uint) (HxWxC or HxW) <---> tensor
|
275 |
+
# --------------------------------------------
|
276 |
+
|
277 |
+
|
278 |
+
# convert uint to 4-dimensional torch tensor
|
279 |
+
def uint2tensor4(img):
    """Convert an HxW or HxWxC 8-bit array to a 1xCxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    chw = hwc.permute(2, 0, 1).float().div(255.)
    return chw.unsqueeze(0)
|
283 |
+
|
284 |
+
|
285 |
+
# convert uint to 3-dimensional torch tensor
|
286 |
+
def uint2tensor3(img):
    """Convert an HxW or HxWxC 8-bit array to a CxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    return hwc.permute(2, 0, 1).float().div(255.)
|
290 |
+
|
291 |
+
|
292 |
+
# convert 2/3/4-dimensional torch tensor to uint
|
293 |
+
def tensor2uint(img):
    """Convert a 2/3/4-dimensional tensor to a uint8 NumPy image.

    The tensor is squeezed, clamped to [0, 1], scaled by 255 and rounded.
    A 3-D result is transposed from CxHxW to HxWxC.

    The clamp is out-of-place. The previous in-place ``clamp_`` ran on a
    view of the input whenever it was already float32, silently modifying
    the caller's tensor.
    """
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img*255.0).round())
|
298 |
+
|
299 |
+
|
300 |
+
# --------------------------------------------
|
301 |
+
# numpy(single) (HxWxC) <---> tensor
|
302 |
+
# --------------------------------------------
|
303 |
+
|
304 |
+
|
305 |
+
# convert single (HxWxC) to 3-dimensional torch tensor
|
306 |
+
def single2tensor3(img):
    """Convert an HxWxC array to a CxHxW float tensor (values unscaled)."""
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    return hwc.permute(2, 0, 1).float()
|
308 |
+
|
309 |
+
|
310 |
+
# convert single (HxWxC) to 4-dimensional torch tensor
|
311 |
+
def single2tensor4(img):
    """Convert an HxWxC array to a 1xCxHxW float tensor (values unscaled)."""
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    chw = hwc.permute(2, 0, 1).float()
    return chw.unsqueeze(0)
|
313 |
+
|
314 |
+
|
315 |
+
# convert torch tensor to single
|
316 |
+
def tensor2single(img):
    """Convert a tensor to a float NumPy array.

    The tensor is squeezed; a 3-D result is transposed from CxHxW to HxWxC,
    other ranks are returned as they are.
    """
    array = img.data.squeeze().float().cpu().numpy()
    if array.ndim != 3:
        return array
    return np.transpose(array, (1, 2, 0))
|
322 |
+
|
323 |
+
# convert torch tensor to single
|
324 |
+
def tensor2single3(img):
    """Turn a torch tensor into a float numpy array that keeps a channel axis.

    After squeezing, a 3-D array is transposed from CxHxW to HxWxC and a 2-D
    array gets a trailing channel axis (HxWx1). Other ranks pass through.
    """
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        return np.transpose(arr, (1, 2, 0))
    if arr.ndim == 2:
        return arr[:, :, np.newaxis]
    return arr
|
331 |
+
|
332 |
+
|
333 |
+
def single2tensor5(img):
    """Turn a 4-D array into a 5-D float tensor.

    Axes (0, 1, 2, 3) are reordered to (2, 0, 1, 3) and a leading size-1
    axis is added.
    """
    src = torch.from_numpy(np.ascontiguousarray(img))
    reordered = src.permute(2, 0, 1, 3).float()
    return reordered.unsqueeze(0)
|
335 |
+
|
336 |
+
|
337 |
+
def single32tensor5(img):
    """Wrap an array as a float tensor with two leading size-1 axes added."""
    base = torch.from_numpy(np.ascontiguousarray(img)).float()
    return base.unsqueeze(0).unsqueeze(0)
|
339 |
+
|
340 |
+
|
341 |
+
def single42tensor4(img):
    """Turn a 4-D array into a float tensor with axes reordered to (2, 0, 1, 3)."""
    src = torch.from_numpy(np.ascontiguousarray(img))
    return src.permute(2, 0, 1, 3).float()
|
343 |
+
|
344 |
+
|
345 |
+
# from skimage.io import imread, imsave
|
346 |
+
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    Args:
        tensor: input tensor; size-1 dimensions are squeezed before the rank
            is inspected.
        out_type: numpy dtype of the result. Only np.uint8 triggers the
            scale-by-255-and-round step; other types receive [0, 1] values.
        min_max: (low, high) clamp range that is mapped linearly onto [0, 1].

    Raises:
        TypeError: if the squeezed tensor is not 2-D, 3-D or 4-D.

    The input tensor is left untouched.
    '''
    # Clamp out of place: squeeze() is a view and float()/cpu() are no-ops for
    # a float32 CPU tensor, so clamp_ would overwrite the caller's data.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # A batch is tiled into one roughly square grid image.
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
|
371 |
+
|
372 |
+
|
373 |
+
'''
|
374 |
+
# --------------------------------------------
|
375 |
+
# Augmentation, flip and/or rotate
|
376 |
+
# --------------------------------------------
|
377 |
+
# The following two are enough.
|
378 |
+
# (1) augment_img: numpy image of WxHxC or WxH
|
379 |
+
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
380 |
+
# --------------------------------------------
|
381 |
+
'''
|
382 |
+
|
383 |
+
|
384 |
+
def augment_img(img, mode=0):
    '''Apply one of eight flip/rotate augmentations to a numpy image.

    Mode 0 returns the input object itself. Modes 1-7 combine np.rot90 with
    an optional np.flipud. Any other mode yields None.
    Kai Zhang (github: https://github.com/cszn)
    '''
    # Per mode: (number of 90-degree turns for np.rot90, flip up/down afterwards)
    recipes = (
        (0, False), (1, True), (0, True), (3, False),
        (2, True), (1, False), (2, False), (3, True),
    )
    for code, (turns, flip) in enumerate(recipes):
        if mode == code:
            if code == 0:
                return img
            out = np.rot90(img, k=turns) if turns else img
            return np.flipud(out) if flip else out
    return None
|
403 |
+
|
404 |
+
|
405 |
+
def augment_img_tensor4(img, mode=0):
    '''Apply one of eight flip/rotate augmentations to a 4-D tensor.

    Rotation uses dims [2, 3] and the flip uses dim 2. Mode 0 returns the
    input object itself; a mode outside 0-7 yields None.
    Kai Zhang (github: https://github.com/cszn)
    '''
    # Per mode: (number of 90-degree turns over dims [2, 3], flip dim 2 afterwards)
    recipes = (
        (0, False), (1, True), (0, True), (3, False),
        (2, True), (1, False), (2, False), (3, True),
    )
    for code, (turns, flip) in enumerate(recipes):
        if mode == code:
            if code == 0:
                return img
            out = img.rot90(turns, [2, 3]) if turns else img
            return out.flip([2]) if flip else out
    return None
|
424 |
+
|
425 |
+
|
426 |
+
def augment_img_tensor(img, mode=0):
    '''Augment a tensor by round-tripping through numpy and augment_img.

    3-D tensors are moved to HxWxC and 4-D tensors to HxWxCxB before the
    augmentation, then moved back. Other ranks are augmented without any
    axis reordering. The result is cast back to the input tensor's type.
    Kai Zhang (github: https://github.com/cszn)
    '''
    rank = len(img.size())
    arr = img.data.cpu().numpy()

    # Put the two spatial axes first so augment_img acts on them.
    if rank == 3:
        arr = np.transpose(arr, (1, 2, 0))
    elif rank == 4:
        arr = np.transpose(arr, (2, 3, 1, 0))

    arr = augment_img(arr, mode=mode)
    result = torch.from_numpy(np.ascontiguousarray(arr))

    # Undo the axis reordering.
    if rank == 3:
        result = result.permute(2, 0, 1)
    elif rank == 4:
        result = result.permute(3, 2, 0, 1)

    return result.type_as(img)
|
443 |
+
|
444 |
+
|
445 |
+
def augment_img_np3(img, mode=0):
    """Apply one of eight flip/transpose augmentations to an HxWxC array.

    The mode value encodes three switches: 4 -> reverse axis 1, 2 -> reverse
    axis 0, 1 -> swap axes 0 and 1 (applied in that order). Mode 0 returns
    the input object itself; a mode outside 0-7 yields None.
    """
    for code in range(8):
        if mode == code:
            if code == 0:
                return img
            if code & 4:
                img = img[:, ::-1, :]
            if code & 2:
                img = img[::-1, :, :]
            if code & 1:
                img = img.transpose(1, 0, 2)
            return img
    return None
|
471 |
+
|
472 |
+
|
473 |
+
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply one shared random flip/transpose augmentation to every image.

    Each of the three operations is drawn once with probability 0.5 and then
    applied identically to all HxWxC arrays in img_list. `hflip` enables the
    axis-1 reversal; `rot` enables both the axis-0 reversal and the swap of
    axes 0 and 1.
    """
    # The draw order (hflip, vflip, transpose) is kept so seeded runs match.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_transpose = rot and random.random() < 0.5

    def _apply(arr):
        if do_hflip:
            arr = arr[:, ::-1, :]
        if do_vflip:
            arr = arr[::-1, :, :]
        if do_transpose:
            arr = arr.transpose(1, 0, 2)
        return arr

    return list(map(_apply, img_list))
|
489 |
+
|
490 |
+
|
491 |
+
'''
|
492 |
+
# --------------------------------------------
|
493 |
+
# modcrop and shave
|
494 |
+
# --------------------------------------------
|
495 |
+
'''
|
496 |
+
|
497 |
+
|
498 |
+
def modcrop(img_in, scale):
    """Crop a copy of an HW or HWC numpy image so H and W are multiples of scale.

    The remainder rows/columns are trimmed from the bottom/right edge.
    Raises ValueError when the array is neither 2-D nor 3-D.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    height, width = img.shape[:2]
    return img[:height - height % scale, :width - width % scale]
|
512 |
+
|
513 |
+
|
514 |
+
def shave(img_in, border=0):
    """Return a copy of an HW or HWC numpy image with `border` pixels cut from every edge."""
    img = np.copy(img_in)
    rows, cols = img.shape[:2]
    keep_rows = slice(border, rows - border)
    keep_cols = slice(border, cols - border)
    return img[keep_rows, keep_cols]
|
520 |
+
|
521 |
+
|
522 |
+
'''
|
523 |
+
# --------------------------------------------
|
524 |
+
# image processing process on numpy image
|
525 |
+
# channel_convert(in_c, tar_type, img_list):
|
526 |
+
# rgb2ycbcr(img, only_y=True):
|
527 |
+
# bgr2ycbcr(img, only_y=True):
|
528 |
+
# ycbcr2rgb(img):
|
529 |
+
# --------------------------------------------
|
530 |
+
'''
|
531 |
+
|
532 |
+
|
533 |
+
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]

    Args:
        img: numpy image whose last axis holds the R, G, B channels.
        only_y: if True return only the luma plane (last axis dropped),
            otherwise all three Y, Cb, Cr channels.

    Returns:
        Array with the same dtype as the input; uint8 results are rounded,
        other dtypes are scaled back by 1/255. The input array is not modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Scale into a fresh array; an in-place `*=` would overwrite the
        # caller's image.
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
555 |
+
|
556 |
+
|
557 |
+
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]

    Args:
        img: numpy image whose last axis holds the Y, Cb, Cr channels.

    Returns:
        RGB array with the same dtype as the input. Values are clipped to
        [0, 255] before uint8 results are rounded or other dtypes are scaled
        back by 1/255. The input array is not modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Scale into a fresh array; an in-place `*=` would overwrite the
        # caller's image.
        img = img * 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    rlt = np.clip(rlt, 0, 255)
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
576 |
+
|
577 |
+
|
578 |
+
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]

    Args:
        img: numpy image whose last axis holds the B, G, R channels.
        only_y: if True return only the luma plane (last axis dropped),
            otherwise all three Y, Cb, Cr channels.

    Returns:
        Array with the same dtype as the input; uint8 results are rounded,
        other dtypes are scaled back by 1/255. The input array is not modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Scale into a fresh array; an in-place `*=` would overwrite the
        # caller's image.
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
600 |
+
|
601 |
+
|
602 |
+
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and Y representations.

    Supported combinations of (in_c, tar_type):
      (3, 'gray') BGR -> gray, (3, 'y') BGR -> Y, (1, 'RGB') gray/Y -> BGR.
    The gray and Y results keep a trailing channel axis. Any other
    combination returns img_list unchanged.
    """
    if in_c == 3 and tar_type in ('gray', 'y'):
        if tar_type == 'gray':  # BGR to gray
            planes = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        else:  # BGR to y
            planes = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(plane, axis=2) for plane in planes]
    if in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    return img_list
|
614 |
+
|
615 |
+
|
616 |
+
'''
|
617 |
+
# --------------------------------------------
|
618 |
+
# metric, PSNR, SSIM and PSNRB
|
619 |
+
# --------------------------------------------
|
620 |
+
'''
|
621 |
+
|
622 |
+
|
623 |
+
# --------------------------------------------
|
624 |
+
# PSNR
|
625 |
+
# --------------------------------------------
|
626 |
+
def calculate_psnr(img1, img2, border=0):
    """Compute PSNR between two images of range [0, 255].

    `border` pixels are cropped from every edge before comparison. Returns
    float('inf') when the cropped images are identical. Raises ValueError
    when the shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    rows, cols = img1.shape[:2]
    region = (slice(border, rows - border), slice(border, cols - border))

    ref = img1[region].astype(np.float64)
    other = img2[region].astype(np.float64)
    mse = np.mean((ref - other)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
|
642 |
+
|
643 |
+
|
644 |
+
# --------------------------------------------
|
645 |
+
# SSIM
|
646 |
+
# --------------------------------------------
|
647 |
+
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    `border` pixels are cropped from every edge first. 2-D inputs and
    single-channel 3-D inputs are scored directly; 3-channel inputs are
    scored per channel and averaged. Raises ValueError on mismatched shapes
    or when the input is neither 2-D nor 3-D.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    rows, cols = img1.shape[:2]
    img1 = img1[border:rows - border, border:cols - border]
    img2 = img2[border:rows - border, border:cols - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim != 3:
        raise ValueError('Wrong input image dimensions.')

    n_channels = img1.shape[2]
    if n_channels == 3:
        per_channel = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
        return np.array(per_channel).mean()
    if n_channels == 1:
        return ssim(np.squeeze(img1), np.squeeze(img2))
    # NOTE(review): any other channel count falls through and returns None.
    return None
|
672 |
+
|
673 |
+
|
674 |
+
def ssim(img1, img2):
    """Mean SSIM of two single-channel images, using constants for a 255 range.

    Local statistics come from an 11x11 Gaussian window (sigma 1.5); a
    5-pixel margin is cut from each filtered map so only fully covered
    positions contribute.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def _local_mean(arr):
        # Gaussian-weighted local mean, "valid" region only.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_a = _local_mean(a)
    mu_b = _local_mean(b)
    mu_a_sq = mu_a**2
    mu_b_sq = mu_b**2
    mu_ab = mu_a * mu_b
    var_a = _local_mean(a**2) - mu_a_sq
    var_b = _local_mean(b**2) - mu_b_sq
    cov_ab = _local_mean(a * b) - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
|
695 |
+
|
696 |
+
|
697 |
+
def _blocking_effect_factor(im):
|
698 |
+
block_size = 8
|
699 |
+
|
700 |
+
block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
|
701 |
+
block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
|
702 |
+
|
703 |
+
horizontal_block_difference = (
|
704 |
+
(im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
|
705 |
+
3).sum(2).sum(1)
|
706 |
+
vertical_block_difference = (
|
707 |
+
(im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
|
708 |
+
2).sum(1)
|
709 |
+
|
710 |
+
nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
|
711 |
+
nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
|
712 |
+
|
713 |
+
horizontal_nonblock_difference = (
|
714 |
+
(im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
|
715 |
+
3).sum(2).sum(1)
|
716 |
+
vertical_nonblock_difference = (
|
717 |
+
(im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
|
718 |
+
3).sum(2).sum(1)
|
719 |
+
|
720 |
+
n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
|
721 |
+
n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
|
722 |
+
boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
|
723 |
+
n_boundary_horiz + n_boundary_vert)
|
724 |
+
|
725 |
+
n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
|
726 |
+
n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
|
727 |
+
nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
|
728 |
+
n_nonboundary_horiz + n_nonboundary_vert)
|
729 |
+
|
730 |
+
scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
|
731 |
+
bef = scaler * (boundary_difference - nonboundary_difference)
|
732 |
+
|
733 |
+
bef[boundary_difference <= nonboundary_difference] = 0
|
734 |
+
return bef
|
735 |
+
|
736 |
+
|
737 |
+
def calculate_psnrb(img1, img2, border=0):
    """Calculate PSNR-B (PSNR with a blocking-effect penalty).

    Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
    # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.

    Returns:
        float: PSNR-B averaged over channels. The blocking-effect factor is
        measured on img1.

    Raises:
        ValueError: if the two images differ in shape.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if img1.ndim == 2:
        img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)

    rows, cols = img1.shape[:2]
    img1 = img1[border:rows - border, border:cols - border].astype(np.float64)
    img2 = img2[border:rows - border, border:cols - border].astype(np.float64)

    # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    # HWC in [0, 255] -> 1xCxHxW in [0, 1]
    img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
    img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.

    n_channels = img1.shape[1]
    total = 0
    for c in range(n_channels):
        plane1 = img1[:, c:c + 1, :, :]
        plane2 = img2[:, c:c + 1, :, :]
        mse = torch.nn.functional.mse_loss(plane1, plane2, reduction='none')
        bef = _blocking_effect_factor(plane1)

        mse = mse.view(mse.shape[0], -1).mean(1)
        total += 10 * torch.log10(1 / (mse + bef))

    return float(total) / n_channels
|
777 |
+
|
778 |
+
'''
|
779 |
+
# --------------------------------------------
|
780 |
+
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
781 |
+
# --------------------------------------------
|
782 |
+
'''
|
783 |
+
|
784 |
+
|
785 |
+
# matlab 'imresize' function, now only support 'bicubic'
|
786 |
+
def cubic(x):
    """Evaluate the piecewise bicubic interpolation kernel element-wise on a tensor.

    Uses one polynomial for |x| <= 1, another for 1 < |x| <= 2, and gives 0
    beyond that.
    """
    ax = torch.abs(x)
    ax2 = ax**2
    ax3 = ax**3
    near = (ax <= 1).type_as(ax)
    far = ((ax > 1) * (ax <= 2)).type_as(ax)
    inner_poly = 1.5*ax3 - 2.5*ax2 + 1
    outer_poly = -0.5*ax3 + 2.5*ax2 - 4*ax + 2
    return inner_poly * near + outer_poly * far
|
792 |
+
|
793 |
+
|
794 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Build the per-output-pixel interpolation weights and input indices for one axis.

    Args:
        in_length: number of input samples along the axis.
        out_length: number of output samples along the axis.
        scale: resize factor for the axis.
        kernel: not used in this function; the cubic kernel is always applied.
        kernel_width: support width of the kernel; widened by 1/scale when
            downscaling with antialiasing.
        antialiasing: whether to stretch the kernel when scale < 1.

    Returns:
        (weights, indices, sym_len_s, sym_len_e): `weights` and `indices` have
        one row per output pixel; `indices` is shifted so it addresses an input
        that has been padded with `sym_len_s` samples at the start and
        `sym_len_e` samples at the end.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # weights_zero_tmp[k] counts the zero entries in column k.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Padding needed before the first / after the last input sample so that
    # every index lands inside the padded signal.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Re-base the indices onto the padded signal.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
847 |
+
|
848 |
+
|
849 |
+
# --------------------------------------------
|
850 |
+
# imresize for tensor image [0, 1]
|
851 |
+
# --------------------------------------------
|
852 |
+
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with MATLAB-style bicubic interpolation.

    Now the scale should be the same for H and W.

    Args:
        img: pytorch tensor, CHW or HW [0,1]. The tensor is not modified.
        scale: resize factor applied to both H and W; the output size is
            ceil(size * scale).
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        float32 CPU tensor, CHW or HW [0,1] w/o round.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out of place on purpose: unsqueeze_ would permanently turn the
        # caller's HW tensor into a 1xHxW one.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad top and bottom with mirrored rows
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Each output row is a weighted sum of kernel_width consecutive padded rows.
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: pad left and right with mirrored columns
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # Each output column is a weighted sum of kernel_width consecutive padded columns.
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
|
920 |
+
|
921 |
+
|
922 |
+
# --------------------------------------------
|
923 |
+
# imresize for numpy image [0, 1]
|
924 |
+
# --------------------------------------------
|
925 |
+
def imresize_np(img, scale, antialiasing=True):
    """Resize a numpy image with MATLAB-style bicubic interpolation.

    Now the scale should be the same for H and W.

    Args:
        img: Numpy, HWC or HW [0,1].
        scale: resize factor applied to both H and W; the output size is
            ceil(size * scale).
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        numpy array, HWC or HW [0,1] w/o round.
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        # Adds a channel axis to the local tensor wrapper only.
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad top and bottom with mirrored rows
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Each output row is a weighted sum of kernel_width consecutive padded rows.
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: pad left and right with mirrored columns
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # Each output column is a weighted sum of kernel_width consecutive padded columns.
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
|
996 |
+
|
997 |
+
|
998 |
+
if __name__ == '__main__':
    # Manual scratch area, only run when the module is executed directly.
    # NOTE(review): imread_uint is defined elsewhere in this file; presumably
    # it reads 'test.bmp' as a 3-channel uint image - confirm.
    img = imread_uint('test.bmp', 3)
|
1000 |
+
# img = uint2single(img)
|
1001 |
+
# img_bicubic = imresize_np(img, 1/4)
|
1002 |
+
# imshow(single2uint(img_bicubic))
|
1003 |
+
#
|
1004 |
+
# img_tensor = single2tensor4(img)
|
1005 |
+
# for i in range(8):
|
1006 |
+
# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
|
1007 |
+
|
1008 |
+
# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
|
1009 |
+
# imssave(patches,'a.png')
|
1010 |
+
|
1011 |
+
|
1012 |
+
|
1013 |
+
|
1014 |
+
|
1015 |
+
|
1016 |
+
|
core/data/deg_kair_utils/utils_lmdb.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import lmdb
|
3 |
+
import sys
|
4 |
+
from multiprocessing import Pool
|
5 |
+
from os import path as osp
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
def make_lmdb_from_imgs(data_path,
|
10 |
+
lmdb_path,
|
11 |
+
img_path_list,
|
12 |
+
keys,
|
13 |
+
batch=5000,
|
14 |
+
compress_level=1,
|
15 |
+
multiprocessing_read=False,
|
16 |
+
n_thread=40,
|
17 |
+
map_size=None):
|
18 |
+
"""Make lmdb from images.
|
19 |
+
|
20 |
+
Contents of lmdb. The file structure is:
|
21 |
+
example.lmdb
|
22 |
+
├── data.mdb
|
23 |
+
├── lock.mdb
|
24 |
+
├── meta_info.txt
|
25 |
+
|
26 |
+
The data.mdb and lock.mdb are standard lmdb files and you can refer to
|
27 |
+
https://lmdb.readthedocs.io/en/release/ for more details.
|
28 |
+
|
29 |
+
The meta_info.txt is a specified txt file to record the meta information
|
30 |
+
of our datasets. It will be automatically created when preparing
|
31 |
+
datasets by our provided dataset tools.
|
32 |
+
Each line in the txt file records 1)image name (with extension),
|
33 |
+
2)image shape, and 3)compression level, separated by a white space.
|
34 |
+
|
35 |
+
For example, the meta information could be:
|
36 |
+
`000_00000000.png (720,1280,3) 1`, which means:
|
37 |
+
1) image name (with extension): 000_00000000.png;
|
38 |
+
2) image shape: (720,1280,3);
|
39 |
+
3) compression level: 1
|
40 |
+
|
41 |
+
We use the image name without extension as the lmdb key.
|
42 |
+
|
43 |
+
If `multiprocessing_read` is True, it will read all the images to memory
|
44 |
+
using multiprocessing. Thus, your server needs to have enough memory.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
data_path (str): Data path for reading images.
|
48 |
+
lmdb_path (str): Lmdb save path.
|
49 |
+
img_path_list (str): Image path list.
|
50 |
+
keys (str): Used for lmdb keys.
|
51 |
+
batch (int): After processing batch images, lmdb commits.
|
52 |
+
Default: 5000.
|
53 |
+
compress_level (int): Compress level when encoding images. Default: 1.
|
54 |
+
multiprocessing_read (bool): Whether use multiprocessing to read all
|
55 |
+
the images to memory. Default: False.
|
56 |
+
n_thread (int): For multiprocessing.
|
57 |
+
map_size (int | None): Map size for lmdb env. If None, use the
|
58 |
+
estimated size from images. Default: None
|
59 |
+
"""
|
60 |
+
|
61 |
+
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
|
62 |
+
f'but got {len(img_path_list)} and {len(keys)}')
|
63 |
+
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
|
64 |
+
print(f'Totoal images: {len(img_path_list)}')
|
65 |
+
if not lmdb_path.endswith('.lmdb'):
|
66 |
+
raise ValueError("lmdb_path must end with '.lmdb'.")
|
67 |
+
if osp.exists(lmdb_path):
|
68 |
+
print(f'Folder {lmdb_path} already exists. Exit.')
|
69 |
+
sys.exit(1)
|
70 |
+
|
71 |
+
if multiprocessing_read:
|
72 |
+
# read all the images to memory (multiprocessing)
|
73 |
+
dataset = {} # use dict to keep the order for multiprocessing
|
74 |
+
shapes = {}
|
75 |
+
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
|
76 |
+
pbar = tqdm(total=len(img_path_list), unit='image')
|
77 |
+
|
78 |
+
def callback(arg):
|
79 |
+
"""get the image data and update pbar."""
|
80 |
+
key, dataset[key], shapes[key] = arg
|
81 |
+
pbar.update(1)
|
82 |
+
pbar.set_description(f'Read {key}')
|
83 |
+
|
84 |
+
pool = Pool(n_thread)
|
85 |
+
for path, key in zip(img_path_list, keys):
|
86 |
+
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
|
87 |
+
pool.close()
|
88 |
+
pool.join()
|
89 |
+
pbar.close()
|
90 |
+
print(f'Finish reading {len(img_path_list)} images.')
|
91 |
+
|
92 |
+
# create lmdb environment
|
93 |
+
if map_size is None:
|
94 |
+
# obtain data size for one image
|
95 |
+
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
|
96 |
+
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
|
97 |
+
data_size_per_img = img_byte.nbytes
|
98 |
+
print('Data size per image is: ', data_size_per_img)
|
99 |
+
data_size = data_size_per_img * len(img_path_list)
|
100 |
+
map_size = data_size * 10
|
101 |
+
|
102 |
+
env = lmdb.open(lmdb_path, map_size=map_size)
|
103 |
+
|
104 |
+
# write data to lmdb
|
105 |
+
pbar = tqdm(total=len(img_path_list), unit='chunk')
|
106 |
+
txn = env.begin(write=True)
|
107 |
+
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
|
108 |
+
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
|
109 |
+
pbar.update(1)
|
110 |
+
pbar.set_description(f'Write {key}')
|
111 |
+
key_byte = key.encode('ascii')
|
112 |
+
if multiprocessing_read:
|
113 |
+
img_byte = dataset[key]
|
114 |
+
h, w, c = shapes[key]
|
115 |
+
else:
|
116 |
+
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
|
117 |
+
h, w, c = img_shape
|
118 |
+
|
119 |
+
txn.put(key_byte, img_byte)
|
120 |
+
# write meta information
|
121 |
+
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
|
122 |
+
if idx % batch == 0:
|
123 |
+
txn.commit()
|
124 |
+
txn = env.begin(write=True)
|
125 |
+
pbar.close()
|
126 |
+
txn.commit()
|
127 |
+
env.close()
|
128 |
+
txt_file.close()
|
129 |
+
print('\nFinish writing lmdb.')
|
130 |
+
|
131 |
+
|
132 |
+
def read_img_worker(path, key, compress_level):
    """Read one image and re-encode it as PNG bytes.

    The image is first loaded with OpenCV. If OpenCV returns ``None`` (the
    case this code associates with ``libpng error: Read Error``), the image
    is loaded with PIL instead and its colour channels are reordered from
    PIL's RGB order to OpenCV's BGR order.

    Args:
        path (str): Image path.
        key (str): Image key. It is returned unchanged so asynchronous
            callers can pair the result with the request.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape as ``(h, w, c)``; ``c`` is 1 for 2-D images.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # deal with `libpng error: Read Error`
    if img is None:
        print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
        from PIL import Image
        import numpy as np
        img = np.asanyarray(Image.open(path))
        # Only arrays with at least three channels can be swapped RGB -> BGR.
        # Grayscale (2-D) and two-channel arrays are passed through as-is
        # instead of failing on an out-of-range channel index.
        if img.ndim == 3 and img.shape[2] >= 3:
            img = img[:, :, [2, 1, 0]]

    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))
|
163 |
+
|
164 |
+
|
165 |
+
class LmdbMaker():
    """LMDB Maker.

    Opens an lmdb environment plus a ``meta_info.txt`` file inside it and
    lets the caller add encoded images one at a time. The instance can also
    be used as a context manager, in which case :meth:`close` is called on
    exit.

    Args:
        lmdb_path (str): Lmdb save path. Must end with ``.lmdb``.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0  # number of images written so far

    def __enter__(self):
        """Return the maker itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the maker on leaving the ``with`` block; never swallows errors."""
        self.close()
        return False

    def put(self, img_byte, key, img_shape):
        """Store one encoded image and append its line to ``meta_info.txt``.

        Args:
            img_byte: Encoded image data stored as the lmdb value.
            key (str): ASCII key; also used as the image name in the meta file.
            img_shape (tuple[int]): ``(h, w, c)`` of the image.
        """
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        # Commit every `batch` images and start a fresh write transaction.
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Commit the pending transaction and release the env and meta file.

        The environment and the meta file are closed even if the final
        commit (or closing the environment) raises.
        """
        try:
            self.txn.commit()
        finally:
            try:
                self.env.close()
            finally:
                self.txt_file.close()
|
core/data/deg_kair_utils/utils_logger.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import datetime
|
3 |
+
import logging
|
4 |
+
|
5 |
+
|
6 |
+
'''
|
7 |
+
# --------------------------------------------
|
8 |
+
# Kai Zhang (github: https://github.com/cszn)
|
9 |
+
# 03/Mar/2019
|
10 |
+
# --------------------------------------------
|
11 |
+
# https://github.com/xinntao/BasicSR
|
12 |
+
# --------------------------------------------
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
def log(*args, **kwargs):
    """Print ``args`` prefixed with the current local time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS:``. All positional and
    keyword arguments are forwarded to ``print`` unchanged.
    """
    timestamp = format(datetime.datetime.now(), "%Y-%m-%d %H:%M:%S:")
    print(timestamp, *args, **kwargs)
|
18 |
+
|
19 |
+
|
20 |
+
'''
|
21 |
+
# --------------------------------------------
|
22 |
+
# logger
|
23 |
+
# --------------------------------------------
|
24 |
+
'''
|
25 |
+
|
26 |
+
|
27 |
+
def logger_info(logger_name, log_path='default_logger.log'):
    """Set up a named logger that writes to a file and to the console.

    On the first call for ``logger_name`` the logger level is set to INFO
    and two handlers sharing one formatter are attached: a ``FileHandler``
    appending to ``log_path`` and a ``StreamHandler``. Later calls for the
    same name leave the logger untouched.

    Args:
        logger_name (str): Name passed to ``logging.getLogger``.
        log_path (str): File the log records are appended to.
    """
    log = logging.getLogger(logger_name)
    # Look only at this logger's own handlers. `hasHandlers()` also reports
    # handlers attached to ancestors such as the root logger, which would
    # skip the setup below whenever anything else had configured logging.
    if log.handlers:
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    level = logging.INFO
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    fh = logging.FileHandler(log_path, mode='a')
    fh.setFormatter(formatter)
    log.setLevel(level)
    log.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    log.addHandler(sh)
|
47 |
+
|
48 |
+
|
49 |
+
'''
|
50 |
+
# --------------------------------------------
|
51 |
+
# print to file and std_out simultaneously
|
52 |
+
# --------------------------------------------
|
53 |
+
'''
|
54 |
+
|
55 |
+
|
56 |
+
class logger_print(object):
    """File-like object that duplicates writes to stdout and a log file.

    The ``sys.stdout`` that is current at construction time is kept as the
    terminal stream, and ``log_path`` is opened in append mode.

    Args:
        log_path (str): File that receives a copy of everything written.
    """

    def __init__(self, log_path="default.log"):
        self.terminal = sys.stdout
        self.log = open(log_path, 'a')

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)  # write the message

    def flush(self):
        """Flush both underlying streams so buffered output reaches them."""
        self.terminal.flush()
        self.log.flush()
|
core/data/deg_kair_utils/utils_mat.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import scipy.io as spio
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
def loadmat(filename):
    """Load a ``.mat`` file as nested dictionaries.

    Use this instead of calling ``spio.loadmat`` directly: top-level entries
    that are still mat-objects are converted by ``_check_keys``, and the
    result is then passed through ``dict_to_nonedict``.
    """
    raw = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
    converted = _check_keys(raw)
    return dict_to_nonedict(converted)
|
16 |
+
|
17 |
+
def _check_keys(dict):
|
18 |
+
'''
|
19 |
+
checks if entries in dictionary are mat-objects. If yes
|
20 |
+
todict is called to change them to nested dictionaries
|
21 |
+
'''
|
22 |
+
for key in dict:
|
23 |
+
if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
|
24 |
+
dict[key] = _todict(dict[key])
|
25 |
+
return dict
|
26 |
+
|
27 |
+
def _todict(matobj):
|
28 |
+
'''
|
29 |
+
A recursive function which constructs from matobjects nested dictionaries
|
30 |
+
'''
|
31 |
+
dict = {}
|
32 |
+
for strg in matobj._fieldnames:
|
33 |
+
elem = matobj.__dict__[strg]
|
34 |
+
if isinstance(elem, spio.matlab.mio5_params.mat_struct):
|
35 |
+
dict[strg] = _todict(elem)
|
36 |
+
else:
|
37 |
+
dict[strg] = elem
|
38 |
+
return dict
|
39 |
+
|
40 |
+
|
41 |
+
def dict_to_nonedict(opt):
    """Recursively replace dictionaries with ``NoneDict`` instances.

    Dictionaries are rebuilt as ``NoneDict`` with converted values, lists
    are rebuilt element by element, and anything else is returned unchanged.
    """
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    if not isinstance(opt, dict):
        return opt
    converted = {key: dict_to_nonedict(value) for key, value in opt.items()}
    return NoneDict(**converted)
|
51 |
+
|
52 |
+
|
53 |
+
class NoneDict(dict):
    """A ``dict`` whose missing keys read as ``None`` instead of raising."""

    def __missing__(self, key):
        """Return ``None`` for an absent ``key``; the key is not inserted."""
        return None
|
56 |
+
|
57 |
+
|
58 |
+
def mat2json(mat_path=None, filepath=None):
    """Convert a ``.mat`` file to JSON text and optionally save it.

    Parameters
    ----------
    mat_path: str
        Path of the ``.mat`` file to read.
    filepath: str or None
        Where to save the JSON text. If it names an existing directory, the
        file is written there as ``<mat file stem>.json``; otherwise it is
        used as the output file path. A truthy non-path value (legacy flag
        usage) writes ``<mat file stem>.json`` in the current directory.
        If falsy, nothing is written.

    Returns
    -------
    str
        The JSON text produced by ``pandas.Series.to_json``.

    Examples
    --------
    >>> mat2json('data.mat', 'data.json')  # doctest: +SKIP
    """
    matlabFile = loadmat(mat_path)
    # Drop the MATLAB metadata entries, which are not meant to be serialised.
    for meta_key in ('__header__', '__version__', '__globals__'):
        matlabFile.pop(meta_key, None)
    # jsonize the content - orientation is 'index'
    json_text = pd.Series(matlabFile).to_json()

    if filepath:
        default_name = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
        if isinstance(filepath, (str, os.PathLike)):
            if os.path.isdir(filepath):
                json_path = os.path.join(filepath, default_name)
            else:
                json_path = filepath
        else:
            # Legacy usage: `filepath` passed as a plain truthy flag.
            json_path = default_name
        with open(json_path, 'w') as f:
            f.write(json_text)
    return json_text
|
core/data/deg_kair_utils/utils_matconvnet.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
# import scipy.io as io
|
7 |
+
import hdf5storage
|
8 |
+
|
9 |
+
"""
|
10 |
+
# --------------------------------------------
|
11 |
+
# Convert matconvnet SimpleNN model into pytorch model
|
12 |
+
# --------------------------------------------
|
13 |
+
# Kai Zhang ([email protected])
|
14 |
+
# https://github.com/cszn
|
15 |
+
# 28/Nov/2019
|
16 |
+
# --------------------------------------------
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
    """Modified version of https://github.com/albanie/pytorch-mcn

    Adjust the memory layout of a weight array and return it as a tensor.

    Args:
        x (ndarray): a numpy array, corresponding to a set of network weights
            stored in column major order
        squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
            singleton dimensions). So after converting to pytorch layout
            (C_out, C_in, H, W), if the shape is (A, B, 1, 1) it will be
            reshaped to a matrix with shape (A, B).
        in_features (int :: None): used to reshape weights for a linear block.
        out_features (int :: None): used to reshape weights for a linear block.

    Returns:
        torch.tensor: a permuted set of weights, matching the pytorch layout
            convention

    NOTE(review): earlier revisions carried commented-out channel
    reorderings for FFDNet / SRMD pixel-shuffle layers at this point; none
    of them is applied here.
    """
    to_torch_layout = (3, 2, 0, 1)
    if x.ndim == 4:
        x = np.transpose(x, to_torch_layout)
    elif x.ndim == 3:  # add by Kai
        # Append a trailing singleton axis, then permute like the 4-D case.
        x = np.transpose(x[..., np.newaxis], to_torch_layout)
    elif x.ndim == 2 and x.shape[1] == 1:
        # A single column becomes a 1-D vector.
        x = x.flatten()

    if squeeze:
        if in_features and out_features:
            x = x.reshape((out_features, in_features))
        x = np.squeeze(x)
    return torch.from_numpy(np.ascontiguousarray(x))
|
66 |
+
|
67 |
+
|
68 |
+
def save_model(network, save_path):
    """Save the state dict of ``network`` with every entry moved to the CPU.

    Args:
        network: object exposing ``state_dict()``.
        save_path: destination passed to ``torch.save``.
    """
    weights = network.state_dict()
    for name in list(weights):
        weights[name] = weights[name].cpu()
    torch.save(weights, save_path)
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
    # Script: read 'models/modelcolor.mat', collect the weights and biases of
    # the 'conv' layers of each of its 25 'CNNdenoiser' entries, and save them
    # as nested OrderedDicts keyed by the entry index.
    #
    # Optional logging and an alternative input used during development:
    # from utils import utils_logger
    # import logging
    # utils_logger.logger_info('a', 'a.log')
    # logger = logging.getLogger('a')
    # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
    mcn = hdf5storage.loadmat('models/modelcolor.mat')
    denoisers = mcn['CNNdenoiser'][0]

    mat_net = OrderedDict()
    for idx in range(25):
        net_key = str(idx)
        mat_net[net_key] = OrderedDict()
        conv_index = -1  # index of the most recent 'conv' layer in this entry

        print(idx)
        layers = denoisers[idx][0]
        for i in range(13):
            layer = layers[i][0]
            if layer[0][0][0] == 'conv':
                conv_index += 1
                layer_params = layer[1][0]

                w = weights2tensor(layer_params[0])
                b = weights2tensor(layer_params[1])
                print(b.shape)

                # Keys use even indices only: model.0, model.2, model.4, ...
                mat_net[net_key]['model.{:d}.weight'.format(conv_index*2)] = w
                mat_net[net_key]['model.{:d}.bias'.format(conv_index*2)] = b

    torch.save(mat_net, 'model_zoo/modelcolor.pth')
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
# from models.network_dncnn import IRCNN as net
|
117 |
+
# network = net(in_nc=3, out_nc=3, nc=64)
|
118 |
+
# state_dict = network.state_dict()
|
119 |
+
#
|
120 |
+
# #show_kv(state_dict)
|
121 |
+
#
|
122 |
+
# for i in range(len(mcn['net'][0][0][0])):
|
123 |
+
# print(mcn['net'][0][0][0][i][0][0][0][0])
|
124 |
+
#
|
125 |
+
# count = -1
|
126 |
+
# mat_net = OrderedDict()
|
127 |
+
# for i in range(len(mcn['net'][0][0][0])):
|
128 |
+
# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
|
129 |
+
#
|
130 |
+
# count += 1
|
131 |
+
# w = mcn['net'][0][0][0][i][0][1][0][0]
|
132 |
+
# print(w.shape)
|
133 |
+
# w = weights2tensor(w)
|
134 |
+
# print(w.shape)
|
135 |
+
#
|
136 |
+
# b = mcn['net'][0][0][0][i][0][1][0][1]
|
137 |
+
# b = weights2tensor(b)
|
138 |
+
# print(b.shape)
|
139 |
+
#
|
140 |
+
# mat_net['model.{:d}.weight'.format(count*2)] = w
|
141 |
+
# mat_net['model.{:d}.bias'.format(count*2)] = b
|
142 |
+
#
|
143 |
+
# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
|
144 |
+
#
|
145 |
+
#
|
146 |
+
#
|
147 |
+
# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
|
148 |
+
# def show_kv(net):
|
149 |
+
# for k, v in net.items():
|
150 |
+
# print(k)
|
151 |
+
#
|
152 |
+
# show_kv(crt_net)
|
153 |
+
|
154 |
+
|
155 |
+
# from models.network_dncnn import DnCNN as net
|
156 |
+
# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
|
157 |
+
|
158 |
+
# from models.network_srmd import SRMD as net
|
159 |
+
# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
|
160 |
+
# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
|
161 |
+
#
|
162 |
+
# from models.network_rrdb import RRDB as net
|
163 |
+
# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
|
164 |
+
#
|
165 |
+
# state_dict = network.state_dict()
|
166 |
+
# for key, param in state_dict.items():
|
167 |
+
# print(key)
|
168 |
+
# from models.network_imdn import IMDN as net
|
169 |
+
# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
|
170 |
+
# state_dict = network.state_dict()
|
171 |
+
# mat_net = OrderedDict()
|
172 |
+
# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
|
173 |
+
# mat_net[key] = param2
|
174 |
+
# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
|
175 |
+
#
|
176 |
+
|
177 |
+
# net_old = torch.load('net_old.pth')
|
178 |
+
# def show_kv(net):
|
179 |
+
# for k, v in net.items():
|
180 |
+
# print(k)
|
181 |
+
#
|
182 |
+
# show_kv(net_old)
|
183 |
+
# from models.network_dpsr import MSRResNet_prior as net
|
184 |
+
# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
|
185 |
+
# state_dict = network.state_dict()
|
186 |
+
# net_new = OrderedDict()
|
187 |
+
# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
|
188 |
+
# net_new[key] = param_old
|
189 |
+
# torch.save(net_new, 'net_new.pth')
|
190 |
+
|
191 |
+
|
192 |
+
# print(key)
|
193 |
+
# print(param.size())
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
# run utils/utils_matconvnet.py
|
core/data/deg_kair_utils/utils_model.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from utils import utils_image as util
|
5 |
+
import re
|
6 |
+
import glob
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
'''
|
11 |
+
# --------------------------------------------
|
12 |
+
# Model
|
13 |
+
# --------------------------------------------
|
14 |
+
# Kai Zhang (github: https://github.com/cszn)
|
15 |
+
# 03/Mar/2019
|
16 |
+
# --------------------------------------------
|
17 |
+
'''
|
18 |
+
|
19 |
+
|
20 |
+
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Find the checkpoint with the largest iteration number in ``save_dir``.

    Only files whose base name is exactly ``<digits>_<net_type>.pth`` are
    considered; other files matching ``*_<net_type>.pth`` (for example
    ``best_G.pth``) are ignored.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no checkpoint is found)
        init_path: model path (``pretrained_path`` when no checkpoint is found)
    # ---------------------------------------
    """
    # Escape net_type and anchor the pattern so it is matched literally
    # against the file name only, never against directory components.
    name_pattern = re.compile(r'^(\d+)_{}\.pth$'.format(re.escape(net_type)))

    init_iter = 0
    init_path = pretrained_path
    found = False
    for file_ in glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))):
        match = name_pattern.match(os.path.basename(file_))
        if match is None:
            continue
        iter_current = int(match.group(1))
        if not found or iter_current > init_iter:
            # Keep the real file path so zero-padded names stay valid.
            found = True
            init_iter = iter_current
            init_path = file_
    return init_iter, init_path
|
49 |
+
|
50 |
+
|
51 |
+
def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
|
52 |
+
'''
|
53 |
+
# ---------------------------------------
|
54 |
+
# Kai Zhang (github: https://github.com/cszn)
|
55 |
+
# 03/Mar/2019
|
56 |
+
# ---------------------------------------
|
57 |
+
Args:
|
58 |
+
model: trained model
|
59 |
+
L: input Low-quality image
|
60 |
+
mode:
|
61 |
+
(0) normal: test(model, L)
|
62 |
+
(1) pad: test_pad(model, L, modulo=16)
|
63 |
+
(2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
|
64 |
+
(3) x8: test_x8(model, L, modulo=1) ^_^
|
65 |
+
(4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
|
66 |
+
refield: effective receptive filed of the network, 32 is enough
|
67 |
+
useful when split, i.e., mode=2, 4
|
68 |
+
min_size: min_sizeXmin_size image, e.g., 256X256 image
|
69 |
+
useful when split, i.e., mode=2, 4
|
70 |
+
sf: scale factor for super-resolution, otherwise 1
|
71 |
+
modulo: 1 if split
|
72 |
+
useful when pad, i.e., mode=1
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
E: estimated image
|
76 |
+
# ---------------------------------------
|
77 |
+
'''
|
78 |
+
if mode == 0:
|
79 |
+
E = test(model, L)
|
80 |
+
elif mode == 1:
|
81 |
+
E = test_pad(model, L, modulo, sf)
|
82 |
+
elif mode == 2:
|
83 |
+
E = test_split(model, L, refield, min_size, sf, modulo)
|
84 |
+
elif mode == 3:
|
85 |
+
E = test_x8(model, L, modulo, sf)
|
86 |
+
elif mode == 4:
|
87 |
+
E = test_split_x8(model, L, refield, min_size, sf, modulo)
|
88 |
+
return E
|
89 |
+
|
90 |
+
|
91 |
+
'''
|
92 |
+
# --------------------------------------------
|
93 |
+
# normal (0)
|
94 |
+
# --------------------------------------------
|
95 |
+
'''
|
96 |
+
|
97 |
+
|
98 |
+
def test(model, L):
|
99 |
+
E = model(L)
|
100 |
+
return E
|
101 |
+
|
102 |
+
|
103 |
+
'''
|
104 |
+
# --------------------------------------------
|
105 |
+
# pad (1)
|
106 |
+
# --------------------------------------------
|
107 |
+
'''
|
108 |
+
|
109 |
+
|
110 |
+
def test_pad(model, L, modulo=16, sf=1):
|
111 |
+
h, w = L.size()[-2:]
|
112 |
+
paddingBottom = int(np.ceil(h/modulo)*modulo-h)
|
113 |
+
paddingRight = int(np.ceil(w/modulo)*modulo-w)
|
114 |
+
L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
|
115 |
+
E = model(L)
|
116 |
+
E = E[..., :h*sf, :w*sf]
|
117 |
+
return E
|
118 |
+
|
119 |
+
|
120 |
+
'''
|
121 |
+
# --------------------------------------------
|
122 |
+
# split (function)
|
123 |
+
# --------------------------------------------
|
124 |
+
'''
|
125 |
+
|
126 |
+
|
127 |
+
def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
|
128 |
+
"""
|
129 |
+
Args:
|
130 |
+
model: trained model
|
131 |
+
L: input Low-quality image
|
132 |
+
refield: effective receptive filed of the network, 32 is enough
|
133 |
+
min_size: min_sizeXmin_size image, e.g., 256X256 image
|
134 |
+
sf: scale factor for super-resolution, otherwise 1
|
135 |
+
modulo: 1 if split
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
E: estimated result
|
139 |
+
"""
|
140 |
+
h, w = L.size()[-2:]
|
141 |
+
if h*w <= min_size**2:
|
142 |
+
L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
|
143 |
+
E = model(L)
|
144 |
+
E = E[..., :h*sf, :w*sf]
|
145 |
+
else:
|
146 |
+
top = slice(0, (h//2//refield+1)*refield)
|
147 |
+
bottom = slice(h - (h//2//refield+1)*refield, h)
|
148 |
+
left = slice(0, (w//2//refield+1)*refield)
|
149 |
+
right = slice(w - (w//2//refield+1)*refield, w)
|
150 |
+
Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
|
151 |
+
|
152 |
+
if h * w <= 4*(min_size**2):
|
153 |
+
Es = [model(Ls[i]) for i in range(4)]
|
154 |
+
else:
|
155 |
+
Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
|
156 |
+
|
157 |
+
b, c = Es[0].size()[:2]
|
158 |
+
E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
|
159 |
+
|
160 |
+
E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
|
161 |
+
E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
|
162 |
+
E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
|
163 |
+
E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
|
164 |
+
return E
|
165 |
+
|
166 |
+
|
167 |
+
'''
|
168 |
+
# --------------------------------------------
|
169 |
+
# split (2)
|
170 |
+
# --------------------------------------------
|
171 |
+
'''
|
172 |
+
|
173 |
+
|
174 |
+
def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Run ``model`` on ``L`` through ``test_split_fn`` with the same arguments."""
    return test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
|
177 |
+
|
178 |
+
|
179 |
+
'''
|
180 |
+
# --------------------------------------------
|
181 |
+
# x8 (3)
|
182 |
+
# --------------------------------------------
|
183 |
+
'''
|
184 |
+
|
185 |
+
|
186 |
+
def test_x8(model, L, modulo=1, sf=1):
    """Average ``test_pad`` outputs over the eight augmentation modes.

    ``L`` is augmented with ``util.augment_img_tensor4`` for modes 0..7 and
    each variant is processed with ``test_pad``. Every output is then passed
    through ``util.augment_img_tensor4`` again (modes 3 and 5 use mode
    ``8 - i``, all others reuse their own mode) before the eight results are
    averaged.
    """
    estimates = []
    for mode in range(8):
        augmented = util.augment_img_tensor4(L, mode=mode)
        estimates.append(test_pad(model, augmented, modulo=modulo, sf=sf))

    restored = []
    for mode, estimate in enumerate(estimates):
        back_mode = 8 - mode if mode in (3, 5) else mode
        restored.append(util.augment_img_tensor4(estimate, mode=back_mode))

    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
|
196 |
+
|
197 |
+
|
198 |
+
'''
|
199 |
+
# --------------------------------------------
|
200 |
+
# split and x8 (4)
|
201 |
+
# --------------------------------------------
|
202 |
+
'''
|
203 |
+
|
204 |
+
|
205 |
+
def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Average ``test_split_fn`` outputs over the eight augmentation modes.

    ``L`` is augmented with ``util.augment_img_tensor4`` for modes 0..7 and
    each variant is processed with ``test_split_fn``. Every output is then
    passed through ``util.augment_img_tensor4`` again (modes 3 and 5 use mode
    ``8 - i``, all others reuse their own mode) before the eight results are
    averaged.
    """
    estimates = []
    for mode in range(8):
        augmented = util.augment_img_tensor4(L, mode=mode)
        estimates.append(
            test_split_fn(model, augmented, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
        )

    restored = []
    for mode, estimate in enumerate(estimates):
        back_mode = 8 - mode if mode in (3, 5) else mode
        restored.append(util.augment_img_tensor4(estimate, mode=back_mode))

    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
|
215 |
+
|
216 |
+
|
217 |
+
'''
|
218 |
+
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
|
219 |
+
# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
|
220 |
+
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
|
221 |
+
'''
|
222 |
+
|
223 |
+
|
224 |
+
'''
|
225 |
+
# --------------------------------------------
|
226 |
+
# print
|
227 |
+
# --------------------------------------------
|
228 |
+
'''
|
229 |
+
|
230 |
+
|
231 |
+
# --------------------------------------------
|
232 |
+
# print model
|
233 |
+
# --------------------------------------------
|
234 |
+
def print_model(model):
    """Print the text returned by ``describe_model`` for ``model``."""
    print(describe_model(model))
|
237 |
+
|
238 |
+
|
239 |
+
# --------------------------------------------
|
240 |
+
# print params
|
241 |
+
# --------------------------------------------
|
242 |
+
def print_params(model):
    """Print the text returned by ``describe_params`` for ``model``."""
    print(describe_params(model))
|
245 |
+
|
246 |
+
|
247 |
+
'''
|
248 |
+
# --------------------------------------------
|
249 |
+
# information
|
250 |
+
# --------------------------------------------
|
251 |
+
'''
|
252 |
+
|
253 |
+
|
254 |
+
# --------------------------------------------
|
255 |
+
# model information
|
256 |
+
# --------------------------------------------
|
257 |
+
def info_model(model):
    """Return the text produced by ``describe_model`` for ``model``."""
    return describe_model(model)
|
260 |
+
|
261 |
+
|
262 |
+
# --------------------------------------------
|
263 |
+
# params information
|
264 |
+
# --------------------------------------------
|
265 |
+
def info_params(model):
    """Return the text produced by ``describe_params`` for ``model``."""
    return describe_params(model)
|
268 |
+
|
269 |
+
|
270 |
+
'''
|
271 |
+
# --------------------------------------------
|
272 |
+
# description
|
273 |
+
# --------------------------------------------
|
274 |
+
'''
|
275 |
+
|
276 |
+
|
277 |
+
# --------------------------------------------
|
278 |
+
# model name and total number of parameters
|
279 |
+
# --------------------------------------------
|
280 |
+
def describe_model(model):
    """Return a text summary of ``model``.

    The summary lists the class name, the total number of parameter
    elements and the printed module structure. A ``DataParallel`` wrapper is
    unwrapped first so the underlying module is described.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    n_params = sum(p.numel() for p in model.parameters())
    # Leading and trailing empty strings give the surrounding newlines.
    parts = [
        '',
        'models name: {}'.format(model.__class__.__name__),
        'Params number: {}'.format(n_params),
        'Net structure:\n{}'.format(str(model)),
        '',
    ]
    return '\n'.join(parts)
|
288 |
+
|
289 |
+
|
290 |
+
# --------------------------------------------
|
291 |
+
# parameters description
|
292 |
+
# --------------------------------------------
|
293 |
+
def describe_params(model):
    """Return a text table with statistics of every entry in the state dict.

    One row per state-dict entry (entries whose name contains
    ``num_batches_tracked`` are skipped) with the mean, min, max, standard
    deviation, shape and name of the tensor. A ``DataParallel`` wrapper is
    unwrapped first.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    rows = ['\n']
    # Six header cells, matching the six columns of every data row.
    rows.append(' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n')
    for name, param in model.state_dict().items():
        if 'num_batches_tracked' not in name:
            # Work on a float copy so integer buffers can be summarised too.
            v = param.data.clone().float()
            rows.append(' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n')
    return ''.join(rows)
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == '__main__':
    # Smoke test: run every test mode on a single-convolution network and
    # print the resulting output shapes.

    class Net(torch.nn.Module):
        """Minimal network: one 3x3 convolution that keeps the spatial size."""

        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            return x

    model = Net()
    model = model.eval()
    print_model(model)
    print_params(model)
    x = torch.randn((2,3,401,401))
    torch.cuda.empty_cache()
    with torch.no_grad():
        for mode in range(5):
            y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
            print(y.shape)
|
329 |
+
|
330 |
+
# run utils/utils_model.py
|
core/data/deg_kair_utils/utils_modelsummary.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
'''
|
6 |
+
---- 1) FLOPs: floating point operations
|
7 |
+
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
|
8 |
+
---- 3) #Conv2d: the number of ‘Conv2d’ layers
|
9 |
+
# --------------------------------------------
|
10 |
+
# Kai Zhang (github: https://github.com/cszn)
|
11 |
+
# 21/July/2020
|
12 |
+
# --------------------------------------------
|
13 |
+
# Reference
|
14 |
+
https://github.com/sovrasov/flops-counter.pytorch.git
|
15 |
+
|
16 |
+
# If you use this code, please consider the following citation:
|
17 |
+
|
18 |
+
@inproceedings{zhang2020aim, %
|
19 |
+
title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
|
20 |
+
author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
|
21 |
+
booktitle={European Conference on Computer Vision Workshops},
|
22 |
+
year={2020}
|
23 |
+
}
|
24 |
+
# --------------------------------------------
|
25 |
+
'''
|
26 |
+
|
27 |
+
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Run one forward pass and return the operation count gathered by the hooks.

    Args:
        model: ``nn.Module`` to analyse. It is switched to ``eval()`` mode and
            gets the flops-counting methods attached.
        input_res: tuple with the input size without the batch axis; a batch
            axis of size 1 is prepended.
        print_per_layer_stat: if true, print the model annotated per layer.
        input_constructor: optional callable ``input_res -> dict``; when given,
            the model is called as ``model(**dict)`` instead of with a tensor.

    Returns:
        The value of ``compute_average_flops_cost()`` after the forward pass.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        if input_constructor:
            inputs = input_constructor(input_res)
            _ = flops_model(**inputs)
        else:
            # A model without parameters has no device to copy; use the CPU.
            params = list(flops_model.parameters())
            device = params[-1].device if params else torch.device('cpu')
            # Zeros instead of uninitialised memory: values do not affect the
            # count, but the run becomes deterministic.
            batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
            _ = flops_model(batch)

        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
    finally:
        # Always detach the hooks, even when the forward pass raised.
        flops_model.stop_flops_count()

    return flops_count
|
47 |
+
|
48 |
+
def get_model_activation(model, input_res, input_constructor=None):
    """Run one forward pass and return the activation statistics from the hooks.

    Args:
        model: ``nn.Module`` to analyse. It is switched to ``eval()`` mode and
            gets the activation-counting methods attached.
        input_res: tuple with the input size without the batch axis; a batch
            axis of size 1 is prepended.
        input_constructor: optional callable ``input_res -> dict``; when given,
            the model is called as ``model(**dict)`` instead of with a tensor.

    Returns:
        tuple: ``(activation_count, num_conv)`` as returned by
        ``compute_average_activation_cost()``.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    try:
        if input_constructor:
            inputs = input_constructor(input_res)
            _ = activation_model(**inputs)
        else:
            # A model without parameters has no device to copy; use the CPU.
            params = list(activation_model.parameters())
            device = params[-1].device if params else torch.device('cpu')
            # Zeros instead of uninitialised memory keeps the run deterministic.
            batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
            _ = activation_model(batch)

        activation_count, num_conv = activation_model.compute_average_activation_cost()
    finally:
        # Always detach the hooks, even when the forward pass raised.
        activation_model.stop_activation_count()

    return activation_count, num_conv
|
65 |
+
|
66 |
+
|
67 |
+
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Return the operation count and trainable-parameter count of ``model``.

    Args:
        model: ``nn.Module`` to analyse. It is switched to ``eval()`` mode and
            gets the flops-counting methods attached.
        input_res: tuple with the input size without the batch axis; a batch
            axis of size 1 is prepended.
        print_per_layer_stat: if true, print the model annotated per layer.
        as_strings: if true, return both numbers formatted by
            ``flops_to_string`` / ``params_to_string``; otherwise raw numbers.
        input_constructor: optional callable ``input_res -> dict``; when given,
            the model is called as ``model(**dict)`` instead of with a tensor.

    Returns:
        tuple: ``(flops, params)`` as strings or numbers (see ``as_strings``).
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        if input_constructor:
            inputs = input_constructor(input_res)
            _ = flops_model(**inputs)
        else:
            # Create the dummy input on the model's device (the tensor used to
            # stay on the CPU, which fails for a model living on another
            # device). Parameter-less models fall back to the CPU.
            params = list(flops_model.parameters())
            device = params[-1].device if params else torch.device('cpu')
            batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
            _ = flops_model(batch)

        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
        params_count = get_model_parameters_number(flops_model)
    finally:
        # Always detach the hooks, even when the forward pass raised.
        flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count
|
90 |
+
|
91 |
+
|
92 |
+
def flops_to_string(flops, units='GMac', precision=2):
    """Format an operation count as a string with a Mac-based unit.

    Args:
        flops: the count to format.
        units: ``'GMac'``, ``'MMac'`` or ``'KMac'`` to force a unit; ``None``
            picks the largest unit whose scaled value is at least 1. Any other
            value yields the unscaled count followed by ``' Mac'``.
        precision: number of decimals passed to ``round``.

    Returns:
        str: e.g. ``'2.0 GMac'``.
    """
    # (label, power of ten), largest first.
    prefixes = (('GMac', 9), ('MMac', 6), ('KMac', 3))

    if units is None:
        for label, exponent in prefixes:
            if flops // 10 ** exponent > 0:
                return str(round(flops / 10. ** exponent, precision)) + ' ' + label
        return str(flops) + ' Mac'

    for label, exponent in prefixes:
        if units == label:
            return str(round(flops / 10. ** exponent, precision)) + ' ' + units
    return str(flops) + ' Mac'
|
111 |
+
|
112 |
+
|
113 |
+
def params_to_string(params_num):
    """Format a parameter count as ``'<x> M'``, ``'<x> k'`` or the plain number.

    Values are rounded to two decimals when scaled to millions or thousands.
    """
    if params_num // 10 ** 6 > 0:
        return '{} M'.format(round(params_num / 10 ** 6, 2))
    if params_num // 10 ** 3:
        return '{} k'.format(round(params_num / 10 ** 3, 2))
    return str(params_num)
|
120 |
+
|
121 |
+
|
122 |
+
def print_model_with_flops(model, units='GMac', precision=3):
    """Print ``model`` with every module's repr prefixed by its operation count.

    Each module's ``extra_repr`` is temporarily replaced by one that shows the
    accumulated count of the module (and its children) plus its share of the
    total; the original ``extra_repr`` is restored afterwards.

    Args:
        model: module prepared by ``add_flops_counting_methods`` (it must
            provide ``compute_average_flops_cost``).
        units: unit passed to ``flops_to_string``.
        precision: rounding precision passed to ``flops_to_string``.
    """
    total_flops = model.compute_average_flops_cost()
    # Nothing in this module sets ``__batch_counter__``; reading it directly
    # raised AttributeError. Use it when present, otherwise assume one batch.
    batch_count = getattr(model, '__batch_counter__', 1) or 1

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / batch_count
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        # Guard against a model in which no supported layer was executed.
        share = accumulated_flops_cost / total_flops if total_flops else 0.0
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(share),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    try:
        print(model)
    finally:
        # Restore the original repr even if printing failed.
        model.apply(del_extra_repr)
|
158 |
+
|
159 |
+
|
160 |
+
def get_model_parameters_number(model):
    """Return the total number of elements over all parameters with ``requires_grad``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
163 |
+
|
164 |
+
|
165 |
+
def add_flops_counting_methods(net_main_module):
    """Attach the flops-counting methods to ``net_main_module`` and reset its counters.

    The functions are bound to this particular object (not to its class) so
    that each of them receives the module as ``self``.

    Returns:
        The same module object, now exposing ``start_flops_count``,
        ``stop_flops_count``, ``reset_flops_count`` and
        ``compute_average_flops_cost``.
    """
    bound_methods = (
        start_flops_count,
        stop_flops_count,
        reset_flops_count,
        compute_average_flops_cost,
    )
    for func in bound_methods:
        # Each attribute is named after the function it binds.
        setattr(net_main_module, func.__name__, func.__get__(net_main_module))

    net_main_module.reset_flops_count()
    return net_main_module
|
176 |
+
|
177 |
+
|
178 |
+
def compute_average_flops_cost(self):
    """
    Method made available on a net object by add_flops_counting_methods().

    Returns the sum of ``__flops__`` over all supported sub-modules.

    """
    return sum(
        module.__flops__
        for module in self.modules()
        if is_supported_instance(module)
    )
|
193 |
+
|
194 |
+
|
195 |
+
def start_flops_count(self):
    """
    Method made available on a net object by add_flops_counting_methods().

    Registers the flops-counting forward hooks on every sub-module.
    Call it before running the network.

    """
    register_hooks = add_flops_counter_hook_function
    self.apply(register_hooks)
|
205 |
+
|
206 |
+
|
207 |
+
def stop_flops_count(self):
    """
    Method made available on a net object by add_flops_counting_methods().

    Removes the flops-counting forward hooks from every sub-module.
    Call it whenever counting should be paused.

    """
    unregister_hooks = remove_flops_counter_hook_function
    self.apply(unregister_hooks)
|
217 |
+
|
218 |
+
|
219 |
+
def reset_flops_count(self):
    """
    Method made available on a net object by add_flops_counting_methods().

    Creates (or zeroes) the flops counter on every supported sub-module.

    """
    reset_counter = add_flops_counter_variable_or_reset
    self.apply(reset_counter)
|
228 |
+
|
229 |
+
|
230 |
+
def add_flops_counter_hook_function(module):
    """Register the matching flops forward hook on ``module``.

    Does nothing for unsupported modules or when a hook handle is already
    stored in ``module.__flops_handle__``.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return

    # (module types, hook) pairs, checked in order; first match wins.
    hook_table = (
        ((nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d), conv_flops_counter_hook),
        ((nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6), relu_flops_counter_hook),
        (nn.Linear, linear_flops_counter_hook),
        (nn.BatchNorm2d, bn_flops_counter_hook),
    )
    hook = empty_flops_counter_hook
    for module_types, candidate in hook_table:
        if isinstance(module, module_types):
            hook = candidate
            break

    module.__flops_handle__ = module.register_forward_hook(hook)
|
246 |
+
|
247 |
+
|
248 |
+
def remove_flops_counter_hook_function(module):
    """Remove and forget the flops hook handle of ``module``, if it has one."""
    if not is_supported_instance(module):
        return
    if not hasattr(module, '__flops_handle__'):
        return
    module.__flops_handle__.remove()
    del module.__flops_handle__
|
253 |
+
|
254 |
+
|
255 |
+
def add_flops_counter_variable_or_reset(module):
    """Set ``module.__flops__`` to 0 when ``module`` is a supported layer."""
    if not is_supported_instance(module):
        return
    module.__flops__ = 0
|
258 |
+
|
259 |
+
|
260 |
+
# ---- Internal functions
|
261 |
+
def is_supported_instance(module):
    """Return True when ``module`` is one of the layer types counted for flops."""
    counted_types = (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    )
    return isinstance(module, counted_types)
|
272 |
+
|
273 |
+
|
274 |
+
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: add the multiply-accumulate count of one conv call.

    The count is ``prod(kernel_size) * in_channels * (out_channels // groups)``
    per output position, times ``batch * prod(output spatial dims)``.
    Only ``output`` is inspected; ``input`` is unused.
    """
    out_shape = output.shape
    positions = out_shape[0] * np.prod(list(out_shape[2:]))

    per_position = (
        np.prod(list(conv_module.kernel_size))
        * conv_module.in_channels
        * (conv_module.out_channels // conv_module.groups)
    )

    conv_module.__flops__ += int(int(per_position) * int(positions))
|
296 |
+
|
297 |
+
|
298 |
+
def relu_flops_counter_hook(module, input, output):
    """Forward hook: count one operation per element of ``output``."""
    element_count = int(output.numel())
    module.__flops__ += element_count
|
303 |
+
|
304 |
+
|
305 |
+
def linear_flops_counter_hook(module, input, output):
    """Forward hook: add the multiply-accumulate count of one linear call.

    Every input element is multiplied once per output feature, so the count is
    ``input.numel() * output.shape[-1]``. For 1-D and 2-D inputs this equals
    the previous formula; for inputs with extra leading dimensions (e.g.
    ``(batch, seq, features)``) the previous code multiplied ``shape[1]``
    values, which are not the feature sizes.
    """
    features_in = input[0]  # hooks receive the positional inputs as a tuple
    module.__flops__ += int(features_in.numel() * output.shape[-1])
|
313 |
+
|
314 |
+
|
315 |
+
def bn_flops_counter_hook(module, input, output):
    """Forward hook: count batch-norm operations from the output size.

    The count is ``batch * num_features * prod(spatial dims)``, doubled when
    the module is affine. Only ``output`` is inspected; ``input`` is unused.
    """
    out_shape = output.shape
    elements = out_shape[0] * module.num_features * np.prod(out_shape[2:])
    factor = 2 if module.affine else 1
    module.__flops__ += int(elements * factor)
|
329 |
+
|
330 |
+
|
331 |
+
# ---- Count the number of convolutional layers and the activation
|
332 |
+
def add_activation_counting_methods(net_main_module):
    """Attach the activation-counting methods to ``net_main_module`` and reset its counters.

    The functions are bound to this particular object (not to its class) so
    that each of them receives the module as ``self``.

    Returns:
        The same module object, now exposing ``start_activation_count``,
        ``stop_activation_count``, ``reset_activation_count`` and
        ``compute_average_activation_cost``.
    """
    bound_methods = (
        start_activation_count,
        stop_activation_count,
        reset_activation_count,
        compute_average_activation_cost,
    )
    for func in bound_methods:
        # Each attribute is named after the function it binds.
        setattr(net_main_module, func.__name__, func.__get__(net_main_module))

    net_main_module.reset_activation_count()
    return net_main_module
|
343 |
+
|
344 |
+
|
345 |
+
def compute_average_activation_cost(self):
    """
    Method made available on a net object by add_activation_counting_methods().

    Returns ``(activation_sum, num_conv)``: the sums of ``__activation__`` and
    ``__num_conv__`` over all supported sub-modules.

    """
    counted = [
        module for module in self.modules()
        if is_supported_instance_for_activation(module)
    ]
    activation_sum = sum(module.__activation__ for module in counted)
    num_conv = sum(module.__num_conv__ for module in counted)
    return activation_sum, num_conv
|
361 |
+
|
362 |
+
|
363 |
+
def start_activation_count(self):
    """
    Method made available on a net object by add_activation_counting_methods().

    Registers the activation-counting forward hooks on every sub-module.
    Call it before running the network.

    """
    register_hooks = add_activation_counter_hook_function
    self.apply(register_hooks)
|
373 |
+
|
374 |
+
|
375 |
+
def stop_activation_count(self):
    """
    Method made available on a net object by add_activation_counting_methods().

    Removes the activation-counting forward hooks from every sub-module.
    Call it whenever counting should be paused.

    """
    unregister_hooks = remove_activation_counter_hook_function
    self.apply(unregister_hooks)
|
385 |
+
|
386 |
+
|
387 |
+
def reset_activation_count(self):
    """
    Method made available on a net object by add_activation_counting_methods().

    Creates (or zeroes) the activation counters on every supported sub-module.

    """
    reset_counters = add_activation_counter_variable_or_reset
    self.apply(reset_counters)
|
396 |
+
|
397 |
+
|
398 |
+
def add_activation_counter_hook_function(module):
    """Register the activation forward hook on a supported conv ``module``.

    Does nothing for unsupported modules or when a hook handle is already
    stored in ``module.__activation_handle__``.
    """
    if not is_supported_instance_for_activation(module):
        return
    if hasattr(module, '__activation_handle__'):
        return

    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        module.__activation_handle__ = module.register_forward_hook(
            conv_activation_counter_hook)
|
406 |
+
|
407 |
+
|
408 |
+
def remove_activation_counter_hook_function(module):
    """Remove and forget the activation hook handle of ``module``, if it has one."""
    if not is_supported_instance_for_activation(module):
        return
    if not hasattr(module, '__activation_handle__'):
        return
    module.__activation_handle__.remove()
    del module.__activation_handle__
|
413 |
+
|
414 |
+
|
415 |
+
def add_activation_counter_variable_or_reset(module):
    """Zero ``__activation__`` and ``__num_conv__`` on a supported ``module``."""
    if not is_supported_instance_for_activation(module):
        return
    module.__activation__ = 0
    module.__num_conv__ = 0
|
419 |
+
|
420 |
+
|
421 |
+
def is_supported_instance_for_activation(module):
    """Return True when ``module`` is a ``Conv2d`` or ``ConvTranspose2d`` layer."""
    counted_types = (nn.Conv2d, nn.ConvTranspose2d)
    return isinstance(module, counted_types)
|
429 |
+
|
430 |
+
def conv_activation_counter_hook(module, input, output):
    """
    Forward hook that records the activations of a convolution call.

    Adds the number of elements of ``output`` to ``module.__activation__`` and
    increments ``module.__num_conv__`` by one.
    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
    :param module: the convolution module the hook is attached to
    :param input: inputs of the forward call (unused)
    :param output: output tensor of the forward call
    :return: None
    """
    produced = output.numel()
    module.__activation__ += produced
    module.__num_conv__ += 1
|
441 |
+
|
442 |
+
|
443 |
+
def empty_flops_counter_hook(module, input, output):
    """Forward hook for layers that contribute no operations (adds 0)."""
    no_cost = 0
    module.__flops__ += no_cost
|
445 |
+
|
446 |
+
|
447 |
+
def upsample_flops_counter_hook(module, input, output):
    """Forward hook: add the element count of ``output[0]`` to the counter.

    Note that only the first item along the leading axis of ``output`` is
    measured; the product of its dimensions is what gets added.
    """
    first_item = output[0]
    dims = first_item.shape
    element_count = dims[0]
    for extent in dims[1:]:
        element_count *= extent
    module.__flops__ += int(element_count)
|
454 |
+
|
455 |
+
|
456 |
+
def pool_flops_counter_hook(module, input, output):
    """Forward hook: count one operation per element of the first input."""
    first_input = input[0]
    element_count = np.prod(first_input.shape)
    module.__flops__ += int(element_count)
|
459 |
+
|
460 |
+
|
461 |
+
def dconv_flops_counter_hook(dconv_module, input, output):
    """Forward hook for a module holding ``weight`` and ``projection`` kernels.

    Per output position the cost is
    ``k1**2 * in_channels * m_channels + k2**2 * out_channels * m_channels``,
    where the sizes are read from the shapes of ``dconv_module.weight`` and
    ``dconv_module.projection``. It is multiplied by the batch size of the
    first input times the product of the output's spatial dimensions.
    """
    first_input = input[0]
    batch_size = first_input.shape[0]
    spatial_dims = list(output.shape[2:])

    m_channels, in_channels, kernel_dim1, _ = dconv_module.weight.shape
    out_channels, _, kernel_dim2, _ = dconv_module.projection.shape

    per_position = (
        kernel_dim1 ** 2 * in_channels * m_channels
        + kernel_dim2 ** 2 * out_channels * m_channels
    )
    positions = batch_size * np.prod(spatial_dims)

    dconv_module.__flops__ += int(per_position * positions)
|
481 |
+
|
482 |
+
|
483 |
+
|
484 |
+
|
485 |
+
|
core/data/deg_kair_utils/utils_option.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import OrderedDict
|
3 |
+
from datetime import datetime
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import glob
|
7 |
+
|
8 |
+
|
9 |
+
'''
|
10 |
+
# --------------------------------------------
|
11 |
+
# Kai Zhang (github: https://github.com/cszn)
|
12 |
+
# 03/Mar/2019
|
13 |
+
# --------------------------------------------
|
14 |
+
# https://github.com/xinntao/BasicSR
|
15 |
+
# --------------------------------------------
|
16 |
+
'''
|
17 |
+
|
18 |
+
|
19 |
+
def get_timestamp():
    """Return the current local time formatted as ``_yymmdd_HHMMSS``."""
    # datetime.__format__ delegates a non-empty format spec to strftime.
    return '{:_%y%m%d_%H%M%S}'.format(datetime.now())
|
21 |
+
|
22 |
+
|
23 |
+
def parse(opt_path, is_train=True):
    """Load a JSON option file (``//`` comments allowed) and fill in defaults.

    Args:
        opt_path: path of the JSON option file.
        is_train: stored as ``opt['is_train']``; also selects which output
            folders are registered under ``opt['path']``.

    Returns:
        OrderedDict: the parsed options plus derived and default entries.

    Side effects:
        Sets ``os.environ['CUDA_VISIBLE_DEVICES']`` from ``opt['gpu_ids']``
        and prints the GPU selection.
    """
    # ----------------------------------------
    # remove comments starting with '//'
    # NOTE: everything after the first '//' on a line is dropped, including
    # when '//' occurs inside a string value.
    # ----------------------------------------
    with open(opt_path, 'r') as f:
        json_str = ''.join(line.split('//')[0] + '\n' for line in f)

    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1

    opt.setdefault('scale', 1)

    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        dataset['phase'] = phase.split('_')[0]
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        for root_key in ('dataroot_H', 'dataroot_L'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = os.path.expanduser(dataset[root_key])

    # ----------------------------------------
    # path
    # ----------------------------------------
    for key, path in opt['path'].items():
        # Only string entries are paths; other truthy values (e.g. booleans)
        # used to make expanduser raise TypeError.
        if path and isinstance(path, str):
            opt['path'][key] = os.path.expanduser(path)

    path_task = os.path.join(opt['path']['root'], opt['task'])
    opt['path']['task'] = path_task
    opt['path']['log'] = path_task
    opt['path']['options'] = os.path.join(path_task, 'options')

    if is_train:
        opt['path']['models'] = os.path.join(path_task, 'models')
        opt['path']['images'] = os.path.join(path_task, 'images')
    else:  # test
        opt['path']['images'] = os.path.join(path_task, 'test_images')

    # ----------------------------------------
    # network
    # ----------------------------------------
    opt['netG']['scale'] = opt['scale']

    # ----------------------------------------
    # GPU devices
    # ----------------------------------------
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    # ----------------------------------------
    # default setting for distributeddataparallel
    # ----------------------------------------
    opt.setdefault('find_unused_parameters', True)
    opt.setdefault('use_static_graph', False)
    opt.setdefault('dist', False)
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))

    train_opt = opt['train']
    has_netD = 'netD' in opt

    # ----------------------------------------
    # default setting for perceptual loss
    # ----------------------------------------
    train_opt.setdefault('F_feature_layer', 34)  # 25; [2,7,16,25,34]
    train_opt.setdefault('F_weights', 1.0)  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    train_opt.setdefault('F_lossfn_type', 'l1')
    train_opt.setdefault('F_use_input_norm', True)
    train_opt.setdefault('F_use_range_norm', False)

    # ----------------------------------------
    # default setting for optimizer
    # ----------------------------------------
    train_opt.setdefault('G_optimizer_type', "adam")
    train_opt.setdefault('G_optimizer_betas', [0.9, 0.999])
    train_opt.setdefault('G_scheduler_restart_weights', 1)
    train_opt.setdefault('G_optimizer_wd', 0)
    train_opt.setdefault('G_optimizer_reuse', False)
    if has_netD:
        train_opt.setdefault('D_optimizer_reuse', False)

    # ----------------------------------------
    # default setting of strict for model loading
    # The D/E defaults used to be tested against opt['path'] only but written
    # to opt['train'], overwriting a value the user had set in 'train'. The
    # default is now applied only when the key is absent from 'train' (and,
    # as before, from 'path').
    # ----------------------------------------
    train_opt.setdefault('G_param_strict', True)
    if has_netD and 'D_param_strict' not in opt['path']:
        train_opt.setdefault('D_param_strict', True)
    if 'E_param_strict' not in opt['path']:
        train_opt.setdefault('E_param_strict', True)

    # ----------------------------------------
    # Exponential Moving Average
    # ----------------------------------------
    train_opt.setdefault('E_decay', 0)

    # ----------------------------------------
    # default setting for discriminator
    # ----------------------------------------
    if has_netD:
        netD_opt = opt['netD']
        netD_opt.setdefault('net_type', 'discriminator_patchgan')  # discriminator_unet
        netD_opt.setdefault('in_nc', 3)
        netD_opt.setdefault('base_nc', 64)
        netD_opt.setdefault('n_layers', 3)
        netD_opt.setdefault('norm_type', 'spectral')

    return opt
|
170 |
+
|
171 |
+
|
172 |
+
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    Find the checkpoint with the highest iteration number in ``save_dir``.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no numbered checkpoint exists)
        init_path: model path (``pretrained_path`` when no numbered checkpoint exists)
    """
    # Match "<digits>_<net_type>.pth" at the end of the file name only, with
    # net_type and the dot escaped, so digits elsewhere in the path and files
    # such as "best_G.pth" (previously an IndexError) are ignored.
    pattern = re.compile(r'(\d+)_{}\.pth$'.format(re.escape(net_type)))

    init_iter, init_path = None, None
    for file_ in glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))):
        match = pattern.search(os.path.basename(file_))
        if match is None:
            continue
        iter_current = int(match.group(1))
        if init_iter is None or iter_current > init_iter:
            init_iter, init_path = iter_current, file_

    if init_iter is None:
        return 0, pretrained_path
    return init_iter, init_path
|
195 |
+
|
196 |
+
|
197 |
+
'''
|
198 |
+
# --------------------------------------------
|
199 |
+
# convert the opt into json file
|
200 |
+
# --------------------------------------------
|
201 |
+
'''
|
202 |
+
|
203 |
+
|
204 |
+
def save(opt):
    """Dump ``opt`` as JSON into the options folder with a timestamped name.

    The file name is the base name of ``opt['opt_path']`` with the result of
    ``get_timestamp()`` inserted before the extension; the folder is
    ``opt['path']['options']``.
    """
    source_path = opt['opt_path']
    target_dir = opt['path']['options']
    stem, ext = os.path.splitext(os.path.basename(source_path))
    dump_path = os.path.join(target_dir, stem + get_timestamp() + ext)
    with open(dump_path, 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)
|
212 |
+
|
213 |
+
|
214 |
+
'''
|
215 |
+
# --------------------------------------------
|
216 |
+
# dict to string for logger
|
217 |
+
# --------------------------------------------
|
218 |
+
'''
|
219 |
+
|
220 |
+
|
221 |
+
def dict2str(opt, indent_l=1):
    """Render a (nested) dict as indented text, one entry per line.

    Nested dicts are wrapped in ``key:[`` ... ``]`` and indented by two more
    spaces per level; other values are written as ``key: value``.
    """
    pad = ' ' * (indent_l * 2)
    pieces = []
    for key, value in opt.items():
        if not isinstance(value, dict):
            pieces.append(pad + key + ': ' + str(value) + '\n')
            continue
        pieces.append(pad + key + ':[\n')
        pieces.append(dict2str(value, indent_l + 1))
        pieces.append(pad + ']\n')
    return ''.join(pieces)
|
231 |
+
|
232 |
+
|
233 |
+
'''
|
234 |
+
# --------------------------------------------
|
235 |
+
# convert OrderedDict to NoneDict,
|
236 |
+
# return None for missing key
|
237 |
+
# --------------------------------------------
|
238 |
+
'''
|
239 |
+
|
240 |
+
|
241 |
+
def dict_to_nonedict(opt):
    """Recursively convert dicts (also inside lists) into ``NoneDict`` objects.

    Non-dict, non-list values are returned unchanged.
    """
    if isinstance(opt, dict):
        # Pass the mapping positionally: ``NoneDict(**mapping)`` raised
        # TypeError for keys that are not strings.
        return NoneDict({key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt


class NoneDict(dict):
    """dict that returns ``None`` for a missing key instead of raising KeyError."""

    def __missing__(self, key):
        return None
|
core/data/deg_kair_utils/utils_params.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import torchvision
|
4 |
+
|
5 |
+
from models import basicblock as B
|
6 |
+
|
7 |
+
def show_kv(net):
    '''Print every key of the mapping ``net``, one per line.'''
    for name, _ in net.items():
        print(name)
|
10 |
+
|
11 |
+
# should run train debug mode first to get an initial model
|
12 |
+
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
|
13 |
+
#
|
14 |
+
#for k, v in crt_net.items():
|
15 |
+
# print(k)
|
16 |
+
#for k, v in crt_net.items():
|
17 |
+
# if k in pretrained_net:
|
18 |
+
# crt_net[k] = pretrained_net[k]
|
19 |
+
# print('replace ... ', k)
|
20 |
+
|
21 |
+
# x2 -> x4
|
22 |
+
#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
|
23 |
+
#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
|
24 |
+
#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
|
25 |
+
#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
|
26 |
+
#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
|
27 |
+
#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
|
28 |
+
#torch.save(crt_net, '../pretrained_tmp.pth')
|
29 |
+
|
30 |
+
# x2 -> x3
|
31 |
+
'''
|
32 |
+
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
|
33 |
+
new_filter = torch.Tensor(576, 64, 3, 3)
|
34 |
+
new_filter[0:256, :, :, :] = in_filter
|
35 |
+
new_filter[256:512, :, :, :] = in_filter
|
36 |
+
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
|
37 |
+
crt_net['model.2.weight'] = new_filter
|
38 |
+
|
39 |
+
in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
|
40 |
+
new_bias = torch.Tensor(576)
|
41 |
+
new_bias[0:256] = in_bias
|
42 |
+
new_bias[256:512] = in_bias
|
43 |
+
new_bias[512:] = in_bias[0:576 - 512]
|
44 |
+
crt_net['model.2.bias'] = new_bias
|
45 |
+
|
46 |
+
torch.save(crt_net, '../pretrained_tmp.pth')
|
47 |
+
'''
|
48 |
+
|
49 |
+
# x2 -> x8
|
50 |
+
'''
|
51 |
+
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
|
52 |
+
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
|
53 |
+
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
|
54 |
+
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
|
55 |
+
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
|
56 |
+
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
|
57 |
+
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
|
58 |
+
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
|
59 |
+
torch.save(crt_net, '../pretrained_tmp.pth')
|
60 |
+
'''
|
61 |
+
|
62 |
+
# x3/4/8 RGB -> Y
|
63 |
+
|
64 |
+
def rgb2gray_net(net, only_input=True):
|
65 |
+
|
66 |
+
if only_input:
|
67 |
+
in_filter = net['0.weight']
|
68 |
+
in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
|
69 |
+
in_new_filter.unsqueeze_(1)
|
70 |
+
net['0.weight'] = in_new_filter
|
71 |
+
|
72 |
+
# out_filter = pretrained_net['model.13.weight']
|
73 |
+
# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
|
74 |
+
# out_filter[2, :, :, :] * 0.114
|
75 |
+
# out_new_filter.unsqueeze_(0)
|
76 |
+
# crt_net['model.13.weight'] = out_new_filter
|
77 |
+
# out_bias = pretrained_net['model.13.bias']
|
78 |
+
# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
|
79 |
+
# out_new_bias = torch.Tensor(1).fill_(out_new_bias)
|
80 |
+
# crt_net['model.13.bias'] = out_new_bias
|
81 |
+
|
82 |
+
# torch.save(crt_net, '../pretrained_tmp.pth')
|
83 |
+
|
84 |
+
return net
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == '__main__':
    # Script: turn the first conv layer of torchvision's VGG19 into a
    # single-channel-input layer, then copy parameters between state dicts.

    net = torchvision.models.vgg19(pretrained=True)
    for k,v in net.features.named_parameters():
        if k=='0.weight':
            # Weighted sum of the three input channels (0.2989/0.587/0.114).
            in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
            in_new_filter.unsqueeze_(1)
            # Rebinding the loop variable only affects the prints below;
            # the network parameter itself is not replaced here.
            v = in_new_filter
            print(v.shape)
            print(v[0,0,0,0])
        if k=='0.bias':
            in_new_bias = v
            print(v[0])

    print(net.features[0])

    # Replace the first layer with a 1-input-channel conv built by the
    # project helper B.conv, then load the converted filter and the bias.
    net.features[0] = B.conv(1, 64, mode='C')

    print(net.features[0])
    net.features[0].weight.data=in_new_filter
    net.features[0].bias.data=in_new_bias

    for k,v in net.features.named_parameters():
        if k=='0.weight':
            print(v[0,0,0,0])
        if k=='0.bias':
            print(v[0])

    # transfer parameters of old model to new one
    # NOTE(review): `model_path` and `model` are not defined in the visible
    # part of this file, so the lines below raise NameError as written --
    # confirm what they are supposed to refer to.
    model_old = torch.load(model_path)
    state_dict = model.state_dict()
    # Entries are paired purely by iteration order, not by key name.
    for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
        state_dict[key2] = param
        print([key, key2])
        # print([param.size(), param2.size()])
    torch.save(state_dict, 'model_new.pth')
|
124 |
+
|
125 |
+
|
126 |
+
# rgb2gray_net(net)
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
core/data/deg_kair_utils/utils_receptivefield.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# online calculation: https://fomoro.com/research/article/receptive-field-calculator#
|
4 |
+
|
5 |
+
# [filter size, stride, padding]
|
6 |
+
#Assume the two dimensions are the same
|
7 |
+
#Each kernel requires the following parameters:
|
8 |
+
# - k_i: kernel size
|
9 |
+
# - s_i: stride
|
10 |
+
# - p_i: padding (if padding is uneven, right padding will be higher than left padding; "SAME" option in tensorflow)
|
11 |
+
#
|
12 |
+
#Each layer i requires the following parameters to be fully represented:
|
13 |
+
# - n_i: number of feature (data layer has n_1 = imagesize )
|
14 |
+
# - j_i: distance (projected to image pixel distance) between center of two adjacent features
|
15 |
+
# - r_i: receptive field of a feature in layer i
|
16 |
+
# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
|
17 |
+
|
18 |
+
import math
|
19 |
+
|
20 |
+
def outFromIn(conv, layerIn):
    '''
    Propagate receptive-field statistics through one conv layer.

    Args:
        conv: [kernel size, stride, padding]
        layerIn: [n features, jump, receptive field size, start] of the input

    Returns:
        (n_out, j_out, r_out, start_out) for the layer output
    '''
    n_in, j_in, r_in, start_in = layerIn[0], layerIn[1], layerIn[2], layerIn[3]
    k, s, p = conv[0], conv[1], conv[2]

    n_out = math.floor((n_in - k + 2*p)/s) + 1
    # total padding actually consumed; the left share is its floored half
    actualP = (n_out-1)*s - n_in + k
    pL = math.floor(actualP/2)

    return (
        n_out,
        j_in * s,
        r_in + (k - 1)*j_in,
        start_in + ((k-1)/2 - pL)*j_in,
    )
|
38 |
+
|
39 |
+
def printLayer(layer, layer_name):
    '''Print the layer name followed by its first four statistics.'''
    print(layer_name + ":")
    stats = (layer[0], layer[1], layer[2], layer[3])
    print(" n features: %s jump: %s receptive size: %s start: %s " % stats)
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
# Filled with one (n, jump, receptive size, start) tuple per conv layer when
# the module is executed as a script.
layerInfos = []
if __name__ == '__main__':

    convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
    layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
    imsize = 128

    print ("-------Net summary------")
    currentLayer = [imsize, 1, 1, 0.5]
    printLayer(currentLayer, "input image")
    # layer_names is longer than convnet; zip stops after the last conv.
    for conv_spec, conv_name in zip(convnet, layer_names):
        currentLayer = outFromIn(conv_spec, currentLayer)
        layerInfos.append(currentLayer)
        printLayer(currentLayer, conv_name)
|
59 |
+
|
60 |
+
|
61 |
+
# run utils/utils_receptivefield.py
|
62 |
+
|
core/data/deg_kair_utils/utils_regularizers.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
'''
|
6 |
+
# --------------------------------------------
|
7 |
+
# Kai Zhang (github: https://github.com/cszn)
|
8 |
+
# 03/Mar/2019
|
9 |
+
# --------------------------------------------
|
10 |
+
'''
|
11 |
+
|
12 |
+
|
13 |
+
# --------------------------------------------
|
14 |
+
# SVD Orthogonal Regularization
|
15 |
+
# --------------------------------------------
|
16 |
+
def regularizer_orth(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization
    # ----------------------------------------
    # Nudges the singular values of a conv weight towards [0.5, 1.5]:
    # values above 1.5 are decreased by 1e-4, values below 0.5 are
    # increased by 1e-4, and the weight is rebuilt from the SVD factors.
    # Intended for torch.nn.Module.apply().
    # usage: net.apply(regularizer_orth)
    #
    # Modules whose class name contains 'Conv' but that have no 4-D
    # ``weight`` (e.g. Conv1d/Conv3d, or container blocks named "Conv...")
    # are skipped instead of raising.
    # ----------------------------------------
    """
    if m.__class__.__name__.find('Conv') == -1:
        return
    weight = getattr(m, 'weight', None)
    if weight is None or weight.dim() != 4:
        return

    w = weight.data.clone()
    c_out, c_in, f1, f2 = w.size()
    # flatten to a (f1*f2*c_in) x c_out matrix
    w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
    u, s, v = torch.svd(w)
    s[s > 1.5] = s[s > 1.5] - 1e-4
    s[s < 0.5] = s[s < 0.5] + 1e-4
    w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    # restore the (c_out, c_in, f1, f2) layout with contiguous storage
    m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1).contiguous()
|
42 |
+
|
43 |
+
|
44 |
+
# --------------------------------------------
|
45 |
+
# SVD Orthogonal Regularization
|
46 |
+
# --------------------------------------------
|
47 |
+
def regularizer_orth2(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization (relative thresholds)
    # ----------------------------------------
    # Same as regularizer_orth, but the thresholds are relative to the
    # mean singular value: values above 1.5*mean are decreased by 1e-4,
    # values below 0.5*mean are increased by 1e-4.
    # Intended for torch.nn.Module.apply().
    # usage: net.apply(regularizer_orth2)
    #
    # Modules whose class name contains 'Conv' but that have no 4-D
    # ``weight`` are skipped instead of raising.
    # ----------------------------------------
    """
    if m.__class__.__name__.find('Conv') == -1:
        return
    weight = getattr(m, 'weight', None)
    if weight is None or weight.dim() != 4:
        return

    w = weight.data.clone()
    c_out, c_in, f1, f2 = w.size()
    # flatten to a (f1*f2*c_in) x c_out matrix
    w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
    u, s, v = torch.svd(w)
    s_mean = s.mean()
    s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
    s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
    w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    # restore the (c_out, c_in, f1, f2) layout with contiguous storage
    m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1).contiguous()
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
def regularizer_clip(m):
    """
    # ----------------------------------------
    # Soft clipping of Conv / Linear parameters.
    # Every weight (and bias, if present) entry above 1.5 is decreased by
    # 1e-4 and every entry below -1.5 is increased by 1e-4; entries are
    # nudged, not clamped.
    # usage: net.apply(regularizer_clip)
    #
    # Modules whose class name matches but that carry no ``weight`` /
    # ``bias`` attribute (e.g. container blocks) are skipped.
    # ----------------------------------------
    """
    eps = 1e-4
    c_min = -1.5
    c_max = 1.5

    classname = m.__class__.__name__
    if classname.find('Conv') == -1 and classname.find('Linear') == -1:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        w = weight.data.clone()
        w[w > c_max] -= eps
        w[w < c_min] += eps
        m.weight.data = w

    bias = getattr(m, 'bias', None)
    if bias is not None:
        b = bias.data.clone()
        b[b > c_max] -= eps
        b[b < c_min] += eps
        m.bias.data = b
|
96 |
+
|
97 |
+
# elif classname.find('BatchNorm2d') != -1:
|
98 |
+
#
|
99 |
+
# rv = m.running_var.data.clone()
|
100 |
+
# rm = m.running_mean.data.clone()
|
101 |
+
#
|
102 |
+
# if m.affine:
|
103 |
+
# m.weight.data
|
104 |
+
# m.bias.data
|
core/data/deg_kair_utils/utils_sisr.py
ADDED
@@ -0,0 +1,848 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from utils import utils_image as util
|
3 |
+
import random
|
4 |
+
|
5 |
+
import scipy
|
6 |
+
import scipy.stats as ss
|
7 |
+
import scipy.io as io
|
8 |
+
from scipy import ndimage
|
9 |
+
from scipy.interpolate import interp2d
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
"""
|
16 |
+
# --------------------------------------------
|
17 |
+
# Super-Resolution
|
18 |
+
# --------------------------------------------
|
19 |
+
#
|
20 |
+
# Kai Zhang ([email protected])
|
21 |
+
# https://github.com/cszn
|
22 |
+
# modified by Kai Zhang (github: https://github.com/cszn)
|
23 |
+
# 03/03/2020
|
24 |
+
# --------------------------------------------
|
25 |
+
"""
|
26 |
+
|
27 |
+
|
28 |
+
"""
|
29 |
+
# --------------------------------------------
|
30 |
+
# anisotropic Gaussian kernels
|
31 |
+
# --------------------------------------------
|
32 |
+
"""
|
33 |
+
|
34 |
+
|
35 |
+
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    # first principal direction: the unit x-axis rotated by theta
    v = np.dot(rotation, np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    # covariance assembled as V * D * V^-1
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
54 |
+
|
55 |
+
|
56 |
+
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian density on a size x size grid, normalised to sum 1.

    Args:
        mean: length-2 mean of the Gaussian, given as (x, y)
        cov: 2x2 covariance matrix
        size: side length of the square kernel
    Returns:
        k: (size, size) array; ``k[y, x]`` is the density at the point
           ``(x - center + 1, y - center + 1)`` with
           ``center = size / 2.0 + 0.5``, divided by the sum of all entries.
    """
    center = size / 2.0 + 0.5
    offsets = np.arange(size) - center + 1
    # cx[y, x] = offsets[x], cy[y, x] = offsets[y]
    cx, cy = np.meshgrid(offsets, offsets)
    points = np.stack([cx, cy], axis=-1)
    # One vectorised pdf call instead of size*size separate calls, each of
    # which would re-process the covariance matrix.
    k = ss.multivariate_normal.pdf(points, mean=mean, cov=cov)
    # pdf squeezes singleton dimensions (size == 1); restore the 2-D shape.
    k = np.asarray(k, dtype=float).reshape(size, size)

    k = k / np.sum(k)
    return k
|
67 |
+
|
68 |
+
|
69 |
+
"""
|
70 |
+
# --------------------------------------------
|
71 |
+
# calculate PCA projection matrix
|
72 |
+
# --------------------------------------------
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
def get_pca_matrix(x, dim_pca=15):
    """
    Args:
        x: 225x10000 matrix
        dim_pca: 15
    Returns:
        pca_matrix: 15x225
    """
    # eigh returns eigenvalues in ascending order, so the last `dim_pca`
    # columns are the eigenvectors with the largest eigenvalues of x x^T.
    _, eigvecs = scipy.linalg.eigh(np.dot(x, x.T))
    return eigvecs[:, -dim_pca:].T
|
89 |
+
|
90 |
+
|
91 |
+
def show_pca(x):
    """
    x: PCA projection matrix, e.g., 15x225

    Each row is reshaped (column-major) into a matrix with
    int(sqrt(x.shape[1])) rows and passed to util.surf.
    """
    side = int(np.sqrt(x.shape[1]))
    for idx in range(x.shape[0]):
        component = x[idx, :]
        util.surf(np.reshape(component, (side, -1), order="F"))
|
98 |
+
|
99 |
+
|
100 |
+
def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
    """Build a PCA projection matrix from random anisotropic Gaussian kernels.

    ``num_samples`` kernels of size ``ksize`` are generated with a random
    angle in [0, pi) and random eigenvalue scalings (l1 in [0.1, 0.1+l_max),
    l2 in [0.1, l1)). Each kernel is flattened column-major into one column
    of a float32 sample matrix, which is fed to get_pca_matrix. The result
    is saved to ``path`` under the key 'p' and returned.
    """
    kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
    for col in range(num_samples):
        # three draws per sample, in this order: angle, l1, l2
        angle = np.pi*np.random.rand()
        scale_1 = 0.1+l_max*np.random.rand()
        scale_2 = 0.1+(scale_1-0.1)*np.random.rand()

        kernel = anisotropic_Gaussian(ksize=ksize, theta=angle, l1=scale_1, l2=scale_2)
        kernels[:, col] = kernel.flatten(order='F')

    pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
    io.savemat(path, {'p': pca_matrix})
    return pca_matrix
|
121 |
+
|
122 |
+
|
123 |
+
"""
|
124 |
+
# --------------------------------------------
|
125 |
+
# shifted anisotropic Gaussian kernels
|
126 |
+
# --------------------------------------------
|
127 |
+
"""
|
128 |
+
|
129 |
+
|
130 |
+
def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Args:
        k_size: kernel size as (size_x, size_y); array or any sequence of two ints
        scale_factor: per-axis scale factor used to shift the Gaussian mean
        min_var, max_var: range the two eigenvalues are sampled from
        noise_level: amplitude of the multiplicative uniform noise
    Returns:
        kernel of shape (k_size[1], k_size[0]), normalised to sum 1
    """
    # Accept plain lists/tuples as well as arrays.
    k_size = np.asarray(k_size)
    scale_factor = np.asarray(scale_factor)

    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    # The meshgrid below has shape (k_size[1], k_size[0]); draw the noise in
    # the same layout so that non-square kernels multiply without a shape error.
    noise = -noise_level + np.random.rand(int(k_size[1]), int(k_size[0])) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    # explicit [..., 0, 0] keeps the 2-D shape even when a kernel side is 1
    quad_form = (ZZ_t @ INV_SIGMA @ ZZ)[..., 0, 0]
    raw_kernel = np.exp(-0.5 * quad_form) * (1 + noise)

    # Normalize the kernel and return
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
|
170 |
+
|
171 |
+
|
172 |
+
def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Random shifted anisotropic Gaussian kernel, normalised to sum 1.
    NOTE: the ``scale_factor`` argument is overwritten by a scale drawn from
    {1, 2, 3, 4}, and ``noise_level`` is not used (noise is fixed to 0).
    """
    sf = random.choice([1, 2, 3, 4])
    scale_factor = np.array([sf, sf])

    # random eigenvalues and rotation angle of the covariance matrix
    span = max_var - min_var
    eig_1 = min_var + np.random.rand() * span
    eig_2 = min_var + np.random.rand() * span
    angle = np.random.rand() * np.pi
    noise = 0

    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    cov = rot @ np.diag([eig_1, eig_2]) @ rot.T
    cov_inv = np.linalg.inv(cov)[None, None, :, :]

    # expectation position (shifting kernel for aligned image)
    mu = (k_size // 2 - 0.5*(scale_factor - 1))[None, None, :, None]

    # pixel coordinates of the kernel grid
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Gaussian value at every pixel of the kernel
    delta = coords - mu
    delta_t = delta.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(delta_t @ cov_inv @ delta)) * (1 + noise)

    return raw_kernel / np.sum(raw_kernel)
|
214 |
+
|
215 |
+
|
216 |
+
"""
|
217 |
+
# --------------------------------------------
|
218 |
+
# degradation models
|
219 |
+
# --------------------------------------------
|
220 |
+
"""
|
221 |
+
|
222 |
+
|
223 |
+
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image (util.imresize_np with scale 1/sf)
    '''
    return util.imresize_np(x, scale=1/sf)
|
233 |
+
|
234 |
+
|
235 |
+
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve is the public entry point; the ndimage.filters
    # namespace is deprecated in SciPy. The kernel gets a trailing singleton
    # axis so every channel is blurred independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
|
255 |
+
|
256 |
+
|
257 |
+
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve is the public entry point; the ndimage.filters
    # namespace is deprecated in SciPy. The kernel gets a trailing singleton
    # axis so every channel is blurred independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
|
278 |
+
|
279 |
+
|
280 |
+
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image (every sf-th pixel, starting at index 0)
    '''
    # ndimage.convolve is the public entry point; the ndimage.filters
    # namespace is deprecated in SciPy. The kernel gets a trailing singleton
    # axis so every channel is blurred independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0
    return x[st::sf, st::sf, ...]
|
295 |
+
|
296 |
+
|
297 |
+
def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        copy of the image cropped so that its first two dimensions are
        multiples of sf
    '''
    rows, cols = img.shape[:2]
    cropped = np.copy(img)
    keep_rows = rows - rows % sf
    keep_cols = cols - cols % sf
    return cropped[:keep_rows, :keep_cols, ...]
|
308 |
+
|
309 |
+
|
310 |
+
'''
|
311 |
+
# =================
|
312 |
+
# Numpy
|
313 |
+
# =================
|
314 |
+
'''
|
315 |
+
|
316 |
+
|
317 |
+
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH, image or kernel
        sf: scale factor
        upper_left: shift direction
    Returns:
        The input resampled with bilinear interpolation at positions shifted
        by (sf-1)/2 pixels along both axes (positions are clamped to the
        valid range). 2-D input yields a float64 array; 3-D input yields an
        array of the input dtype. Arrays of any other rank are returned
        unchanged. The input array is not modified.

    scipy.interpolate.interp2d, used previously, is removed in SciPy >= 1.14;
    its default linear mode on this regular grid is plain bilinear
    interpolation, which is done here with NumPy directly.
    """
    if x.ndim not in (2, 3):
        return x

    h, w = x.shape[:2]
    shift = (sf-1)*0.5
    if not upper_left:
        shift = -shift
    x1 = np.clip(np.arange(0, w, 1.0) + shift, 0, w-1)
    y1 = np.clip(np.arange(0, h, 1.0) + shift, 0, h-1)

    def _resample_axis(arr, pos, axis):
        # Linear interpolation of `arr` along `axis` at the fractional
        # positions `pos` (already clamped to [0, size-1]).
        size = arr.shape[axis]
        lo = np.floor(pos).astype(int)
        hi = np.minimum(lo + 1, size - 1)
        bshape = [1] * arr.ndim
        bshape[axis] = -1
        frac = (pos - lo).reshape(bshape)
        return np.take(arr, lo, axis=axis) * (1 - frac) + np.take(arr, hi, axis=axis) * frac

    shifted = _resample_axis(np.asarray(x, dtype=np.float64), y1, 0)
    shifted = _resample_axis(shifted, x1, 1)

    if x.ndim == 3:
        # The 3-D path has always produced the input dtype.
        shifted = shifted.astype(x.dtype, copy=False)
    return shifted
|
344 |
+
|
345 |
+
|
346 |
+
'''
|
347 |
+
# =================
|
348 |
+
# pytorch
|
349 |
+
# =================
|
350 |
+
'''
|
351 |
+
|
352 |
+
|
353 |
+
def splits(a, sf):
    '''
    a: tensor NxCxWxHx2
    sf: scale factor
    out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)

    The tensor is cut into sf chunks along dim 2, stacked on a new last
    dim, then cut into sf chunks along dim 3 which are concatenated on that
    same last dim.
    '''
    row_blocks = torch.chunk(a, sf, dim=2)
    stacked = torch.stack(row_blocks, dim=5)
    col_blocks = torch.chunk(stacked, sf, dim=3)
    return torch.cat(col_blocks, dim=5)
|
362 |
+
|
363 |
+
|
364 |
+
def c2c(x):
    '''Convert a complex numpy value into a float32 tensor with a trailing (real, imag) dim.'''
    real_part = np.float32(x.real)
    imag_part = np.float32(x.imag)
    return torch.from_numpy(np.stack([real_part, imag_part], axis=-1))
|
366 |
+
|
367 |
+
|
368 |
+
def r2c(x):
    '''Append a trailing dim of size 2 holding (x, 0), i.e. x as a complex tensor with zero imaginary part.'''
    imag = torch.zeros_like(x)
    return torch.stack((x, imag), dim=-1)
|
370 |
+
|
371 |
+
|
372 |
+
def cdiv(x, y):
    '''Complex division x / y for tensors whose last dim holds (real, imag).'''
    x_re, x_im = x[..., 0], x[..., 1]
    y_re, y_im = y[..., 0], y[..., 1]
    denom = y_re**2 + y_im**2
    out_re = (x_re*y_re+x_im*y_im)/denom
    out_im = (x_im*y_re-x_re*y_im)/denom
    return torch.stack([out_re, out_im], -1)
|
377 |
+
|
378 |
+
|
379 |
+
def csum(x, y):
    '''Add the real quantity y to the real part of x (last dim holds (real, imag)).'''
    real = x[..., 0] + y
    imag = x[..., 1]
    return torch.stack([real, imag], -1)
|
381 |
+
|
382 |
+
|
383 |
+
def cabs(x):
    '''Modulus of a complex tensor whose last dim holds (real, imag).'''
    squared = x[..., 0]**2+x[..., 1]**2
    return torch.pow(squared, 0.5)
|
385 |
+
|
386 |
+
|
387 |
+
def cmul(t1, t2):
    '''
    complex multiplication
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    a, b = t1[..., 0], t1[..., 1]
    c, d = t2[..., 0], t2[..., 1]
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    product_real = a * c - b * d
    product_imag = a * d + b * c
    return torch.stack([product_real, product_imag], dim=-1)
|
396 |
+
|
397 |
+
|
398 |
+
def cconj(t, inplace=False):
    '''
    # complex's conjugation
    t: NxCxHxWx2
    output: NxCxHxWx2

    With inplace=True the imaginary part of ``t`` itself is negated and
    ``t`` is returned; otherwise a clone is modified and returned.
    '''
    result = t if inplace else t.clone()
    result[..., 1].mul_(-1)
    return result
|
407 |
+
|
408 |
+
|
409 |
+
def rfft(t):
    '''
    Full (two-sided) 2-D FFT of a real tensor over its last two dims.

    Returns a real tensor with a trailing dim of size 2 holding
    (real, imag). Uses the legacy ``torch.rfft`` when it exists and the
    ``torch.fft`` module otherwise (the legacy function was removed in
    PyTorch 1.8).
    '''
    if callable(getattr(torch, 'rfft', None)):
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t))
|
411 |
+
|
412 |
+
|
413 |
+
def irfft(t):
    '''
    Inverse 2-D FFT returning a real tensor.

    ``t`` holds a full (two-sided) spectrum with a trailing dim of size 2
    for (real, imag). Uses the legacy ``torch.irfft`` when it exists and the
    ``torch.fft`` module otherwise (the legacy function was removed in
    PyTorch 1.8).
    '''
    if callable(getattr(torch, 'irfft', None)):
        return torch.irfft(t, 2, onesided=False)
    spectrum = torch.view_as_complex(t.contiguous())
    h, w = spectrum.shape[-2:]
    # complex-to-real transform on the one-sided half, sized back to (h, w)
    return torch.fft.irfft2(spectrum[..., : w // 2 + 1], s=(h, w))
|
415 |
+
|
416 |
+
|
417 |
+
def fft(t):
    '''
    2-D complex-to-complex FFT of a tensor whose last dim holds (real, imag).

    The transform runs over the two dims before the last one. Uses the
    legacy ``torch.fft`` function when ``torch.fft`` is callable and the
    ``torch.fft`` module otherwise (the legacy function was removed in
    PyTorch 1.8).
    '''
    if callable(torch.fft):
        return torch.fft(t, 2)
    spectrum = torch.fft.fft2(torch.view_as_complex(t.contiguous()))
    return torch.view_as_real(spectrum)
|
419 |
+
|
420 |
+
|
421 |
+
def ifft(t):
    '''
    2-D complex-to-complex inverse FFT of a tensor whose last dim holds
    (real, imag).

    The transform runs over the two dims before the last one. Uses the
    legacy ``torch.ifft`` when it exists and the ``torch.fft`` module
    otherwise (the legacy function was removed in PyTorch 1.8).
    '''
    if callable(getattr(torch, 'ifft', None)):
        return torch.ifft(t, 2)
    signal = torch.fft.ifft2(torch.view_as_complex(t.contiguous()))
    return torch.view_as_real(signal)
|
423 |
+
|
424 |
+
|
425 |
+
def p2o(psf, shape):
    '''
    Convert a point-spread function to an optical transfer function.

    Args:
        psf: NxCxhxw
        shape: [H,W] (list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2

    The PSF is zero-padded to HxW, circularly shifted by half its size
    along both spatial axes, and transformed with a full 2-D FFT. Imaginary
    parts whose magnitude is below ``n_ops * 2.22e-16`` are set to zero.
    '''
    # tuple(...) on both sides so that `shape` may be a plain list
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    if callable(getattr(torch, 'rfft', None)):
        # legacy API (removed in PyTorch 1.8)
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf))
    dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(dims * torch.log2(dims))
    # zero imaginary parts that are only numerical noise
    imag = otf[..., 1]
    imag[torch.abs(imag) < n_ops*2.22e-16] = 0
    return otf
|
442 |
+
|
443 |
+
|
444 |
+
'''
|
445 |
+
# =================
|
446 |
+
PyTorch
|
447 |
+
# =================
|
448 |
+
'''
|
449 |
+
|
450 |
+
def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
    '''
    FB: NxCxWxHx2
    F2B: NxCxWxHx2

    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));

    The final inverse transform uses the legacy ``torch.irfft`` when it
    exists and the ``torch.fft`` module otherwise (the legacy function was
    removed in PyTorch 1.8).
    '''
    x1 = cmul(FB, FR)
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, tau))
    # tile the block-wise quotient back to the full spatial size
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR-FCBinvWBR)/tau
    if callable(getattr(torch, 'irfft', None)):
        return torch.irfft(FX, 2, onesided=False)
    spectrum = torch.view_as_complex(FX.contiguous())
    h, w = spectrum.shape[-2:]
    # complex-to-real transform on the one-sided half, sized back to (h, w)
    Xest = torch.fft.irfft2(spectrum[..., : w // 2 + 1], s=(h, w))
    return Xest
|
472 |
+
|
473 |
+
|
474 |
+
def real2complex(x):
    '''Append a trailing dim of size 2 holding (x, 0).'''
    zero_imag = torch.zeros_like(x)
    return torch.stack((x, zero_imag), dim=-1)
|
476 |
+
|
477 |
+
|
478 |
+
def modcrop(img, sf):
    '''
    Return a copy of img whose last two dimensions are cropped down to
    multiples of sf.

    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor
    '''
    size_w, size_h = img.shape[-2:]
    keep_w = size_w - size_w % sf
    keep_h = size_h - size_h % sf
    cropped = img.clone()
    return cropped[..., :keep_w, :keep_h]
|
486 |
+
|
487 |
+
|
488 |
+
def upsample(x, sf=3, center=False):
    '''
    Zero-insertion upsampling: every input pixel is written to one position
    of an sf-by-sf cell and the rest of the cell stays zero.

    x: tensor image, NxCxWxH
    '''
    if center:
        offset = (sf - 1) // 2
    else:
        offset = 0
    out_shape = (x.shape[0], x.shape[1], x.shape[2] * sf, x.shape[3] * sf)
    canvas = torch.zeros(out_shape).type_as(x)
    canvas[..., offset::sf, offset::sf].copy_(x)
    return canvas
|
496 |
+
|
497 |
+
|
498 |
+
def downsample(x, sf=3, center=False):
    '''
    Keep every sf-th sample along the last two dimensions, starting at the
    first sample or, when center is set, at offset (sf-1)//2.
    '''
    if center:
        offset = (sf - 1) // 2
    else:
        offset = 0
    return x[..., offset::sf, offset::sf]
|
501 |
+
|
502 |
+
|
503 |
+
def circular_pad(x, pad):
    '''
    Periodic padding of the two spatial axes:
    x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad]
    '''
    # Wrap the leading entries of each axis around to its end.
    leading_rows = x[:, :, 0:pad, :]
    x = torch.cat((x, leading_rows), dim=2)
    leading_cols = x[:, :, :, 0:pad]
    x = torch.cat((x, leading_cols), dim=3)
    # After the append, the original trailing entries sit at [-2*pad:-pad];
    # copy them in front.
    trailing_rows = x[:, :, -2 * pad:-pad, :]
    x = torch.cat((trailing_rows, x), dim=2)
    trailing_cols = x[:, :, :, -2 * pad:-pad]
    x = torch.cat((trailing_cols, x), dim=3)
    return x
|
512 |
+
|
513 |
+
|
514 |
+
def pad_circular(input, padding):
    # type: (Tensor, List[int]) -> Tensor
    """
    Circularly pad every dimension after the first two by delegating to
    dim_pad_circular, one dimension at a time.

    Arguments
    :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
    :param padding: (tuple): m-elem tuple where m is the degree of convolution
    Returns
    :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
    H + 2 * padding[1]], W + 2 * padding[2]))`
    """
    first = 3
    padded = input
    # dim_pad_circular concatenates along `dimension - 1`, so passing 3 pads
    # tensor axis 2, passing 4 pads axis 3, and so on.
    for dimension in range(first, input.dim() + 1):
        padded = dim_pad_circular(padded, padding[dimension - first], dimension)
    return padded
|
528 |
+
|
529 |
+
|
530 |
+
def dim_pad_circular(input, padding, dimension):
    # type: (Tensor, int, int) -> Tensor
    """
    Circularly pad one axis of a tensor by `padding` entries on each side.

    :param input: tensor to pad
    :param padding: number of entries wrapped around on each side
    :param dimension: 1-based position of the axis; the concatenation is done
        along tensor axis `dimension - 1`
    :return: tensor whose axis `dimension - 1` grew by 2 * padding
    """
    axis = dimension - 1
    # Index with an explicit tuple: indexing with a *list* of slices relies on
    # the deprecated non-tuple-sequence indexing behaviour.
    head = (slice(None),) * axis + (slice(0, padding),)
    input = torch.cat([input, input[head]], dim=axis)
    # After the append, the original trailing entries sit at [-2*padding:-padding].
    tail = (slice(None),) * axis + (slice(-2 * padding, -padding),)
    input = torch.cat([input[tail], input], dim=axis)
    return input
|
537 |
+
|
538 |
+
|
539 |
+
def imfilter(x, k):
    '''
    Filter each channel with its own kernel after circular padding
    (pad_circular followed by a grouped conv2d).

    x: image, NxcxHxW
    k: kernel, cx1xhxw
    '''
    pad_rows = (k.shape[-2] - 1) // 2
    pad_cols = (k.shape[-1] - 1) // 2
    padded = pad_circular(x, padding=(pad_rows, pad_cols))
    # groups == channel count: one kernel per channel.
    return torch.nn.functional.conv2d(padded, k, groups=padded.shape[1])
|
547 |
+
|
548 |
+
|
549 |
+
def G(x, k, sf=3, center=False):
    '''
    Blur with circular boundary handling, then downsample.

    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: the first one or the middle one

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    blurred = imfilter(x, k)
    return downsample(blurred, sf=sf, center=center)
|
562 |
+
|
563 |
+
|
564 |
+
def Gt(x, k, sf=3, center=False):
    '''
    Zero-insertion upsample, then blur with circular boundary handling.

    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: the first one or the middle one

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    expanded = upsample(x, sf=sf, center=center)
    return imfilter(expanded, k)
|
577 |
+
|
578 |
+
|
579 |
+
def interpolation_down(x, sf, center=False):
    '''
    Sample x on a regular grid with stride sf over the last two dimensions.

    Returns a triple (LR, y, mask):
        LR: the strided samples
        y: x with every non-sampled position set to zero (same shape as x)
        mask: 1 at the sampled positions, 0 elsewhere (same shape as x)
    '''
    one = torch.tensor(1).type_as(x)
    mask = torch.zeros_like(x)
    if not center:
        mask[..., ::sf, ::sf] = one
        LR = x[..., ::sf, ::sf]
    else:
        # Start in the middle of each sf-by-sf cell.
        start = torch.tensor((sf - 1) // 2)
        mask[..., start::sf, start::sf] = one
        LR = x[..., start::sf, start::sf]
    y = x.mul(mask)

    return LR, y, mask
|
591 |
+
|
592 |
+
|
593 |
+
'''
|
594 |
+
# =================
|
595 |
+
Numpy
|
596 |
+
# =================
|
597 |
+
'''
|
598 |
+
|
599 |
+
|
600 |
+
def blockproc(im, blocksize, fun):
    '''
    Apply fun to every blocksize[0] x blocksize[1] tile of im (row-major
    order) and stitch the results back together.
    '''
    row_cuts = range(blocksize[0], im.shape[0], blocksize[0])
    col_cuts = range(blocksize[1], im.shape[1], blocksize[1])
    stitched_strips = [
        np.concatenate([fun(tile) for tile in np.split(strip, col_cuts, axis=1)], axis=1)
        for strip in np.split(im, row_cuts, axis=0)
    ]
    return np.concatenate(stitched_strips, axis=0)
|
614 |
+
|
615 |
+
|
616 |
+
def fun_reshape(a):
    '''
    Flatten every dimension except the last into one column, in
    column-major (Fortran) order: result has shape (-1, 1, a.shape[-1]).
    '''
    target_shape = (-1, 1, a.shape[-1])
    return np.reshape(a, target_shape, order='F')
|
618 |
+
|
619 |
+
|
620 |
+
def fun_mul(a, b):
    '''Return the product a*b (element-wise for arrays).'''
    product = a * b
    return product
|
622 |
+
|
623 |
+
|
624 |
+
def BlockMM(nr, nc, Nb, m, x1):
    '''
    Sum the nr x nc blocks of x1 into a single nr x nc block.

    MATLAB reference:
    myfun = @(block_struct) reshape(block_struct.data,m,1);
    x1 = blockproc(x1,[nr nc],myfun);
    x1 = reshape(x1,m,Nb);
    x1 = sum(x1,2);
    x = reshape(x1,nr,nc);
    '''
    # Turn every block into a column, keeping the trailing axis.
    columns = blockproc(x1, blocksize=(nr, nc), fun=fun_reshape)
    per_block = np.reshape(columns, (m, Nb, columns.shape[-1]), order='F')
    # Accumulate over the Nb blocks.
    summed = np.sum(per_block, 1)
    return np.reshape(summed, (nr, nc, summed.shape[-1]), order='F')
|
638 |
+
|
639 |
+
|
640 |
+
def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
    '''
    Closed-form least-squares update in the Fourier domain (NumPy port).

    MATLAB reference:
    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    FBR = BlockMM(nr, nc, Nb, m, FB * FR)
    invW = BlockMM(nr, nc, Nb, m, F2B)
    invWBR = FBR / (invW + tau * Nb)

    def _scale_block(block):
        # every nr x nc block of FBC is multiplied by the same invWBR
        return fun_mul(block, invWBR)

    FCBinvWBR = blockproc(FBC, [nr, nc], _scale_block)
    FX = (FR - FCBinvWBR) / tau
    return np.real(np.fft.ifft2(FX, axes=(0, 1)))
|
659 |
+
|
660 |
+
|
661 |
+
def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.
    Compute the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in OUTSIZE, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches (1,1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : sequence of int, optional
        Output shape of the OTF array; defaults to ``psf.shape``
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array (an all-zero real array of the requested shape when the
        PSF is identically zero)
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if shape is None:
        shape = psf.shape
    shape = np.array(shape)
    if np.all(psf == 0):
        # return np.zeros_like(psf)
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    inshape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    # Circularly shift so that the PSF centre ends up at index (0, 0).
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the PSF imaginary part if within roundoff error
    # roundoff error = machine epsilon = sys.float_info.epsilon
    # or np.finfo().eps
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf
|
708 |
+
|
709 |
+
|
710 |
+
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros
    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image (the input itself when no padding is needed)
    Raises
    ------
    ValueError
        If the target shape is non-positive, smaller than the image, or (for
        'center') differs from the image shape by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    # np.all instead of np.alltrue: the alias was removed in NumPy 2.0.
    if np.all(imshape == shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    idx, idy = np.indices(imshape)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[idx + offx, idy + offy] = image
    return pad_img
|
750 |
+
|
751 |
+
|
752 |
+
def upsample_np(x, sf=3, center=False):
    '''
    Zero-insertion upsampling of the first two axes of a 3-D array; each
    input sample lands at one position of an sf-by-sf cell.
    '''
    if center:
        offset = (sf - 1) // 2
    else:
        offset = 0
    rows, cols, channels = x.shape[0] * sf, x.shape[1] * sf, x.shape[2]
    canvas = np.zeros((rows, cols, channels))
    canvas[offset::sf, offset::sf, ...] = x
    return canvas
|
757 |
+
|
758 |
+
|
759 |
+
def downsample_np(x, sf=3, center=False):
    '''
    Keep every sf-th sample along the first two axes, starting at index 0
    or, when center is set, at (sf-1)//2.
    '''
    if center:
        offset = (sf - 1) // 2
    else:
        offset = 0
    return x[offset::sf, offset::sf, ...]
|
762 |
+
|
763 |
+
|
764 |
+
def imfilter_np(x, k):
    '''
    Convolve an image with a 2-D kernel using wrap-around boundaries.

    x: image; the kernel is expanded to hxwx1 below, so x must be 3-D
       (presumably HxWxC - confirm against callers)
    k: kernel, hxw; the same kernel is applied to every slice along the
       last axis of x
    '''
    # ndimage.convolve is the public entry point; the ndimage.filters
    # namespace used previously is deprecated.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
|
771 |
+
|
772 |
+
|
773 |
+
def G_np(x, k, sf=3, center=False):
    '''
    Blur with wrap-around boundaries (imfilter_np), then downsample
    (downsample_np).

    x: image
    k: kernel
    sf: scale factor
    center: sample the first pixel of each cell or the middle one

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    blurred = imfilter_np(x, k)
    return downsample_np(blurred, sf=sf, center=center)
|
784 |
+
|
785 |
+
|
786 |
+
def Gt_np(x, k, sf=3, center=False):
    '''
    Zero-insertion upsample (upsample_np), then blur with wrap-around
    boundaries (imfilter_np).

    x: image
    k: kernel
    sf: scale factor
    center: fill the first pixel of each cell or the middle one

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    expanded = upsample_np(x, sf=sf, center=center)
    return imfilter_np(expanded, k)
|
797 |
+
|
798 |
+
|
799 |
+
if __name__ == '__main__':
    # Manual smoke test: load 'test.bmp', build a blur kernel and run the
    # degradation helpers for several scale factors, printing output shapes.
    img = util.imread_uint('test.bmp', 3)

    img = util.uint2single(img)
    k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
    util.imshow(k*10)


    # NOTE(review): loop extent reconstructed from the blank-line layout of
    # the source listing - confirm against the original file.
    for sf in [2, 3, 4]:

        # modcrop
        # img is overwritten here, so each pass crops the previous pass's result
        img = modcrop_np(img, sf=sf)

        # 1) bicubic degradation
        img_b = bicubic_degradation(img, sf=sf)
        print(img_b.shape)

        # 2) srmd degradation
        img_s = srmd_degradation(img, k, sf=sf)
        print(img_s.shape)

        # 3) dpsr degradation
        img_d = dpsr_degradation(img, k, sf=sf)
        print(img_d.shape)

        # 4) classical degradation
        img_d = classical_degradation(img, k, sf=sf)
        print(img_d.shape)

    # The kernels below are only built, not displayed (imshow calls are disabled).
    k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
    #print(k)
    # util.imshow(k*10)

    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
    # util.imshow(k*10)


    # PCA
    # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
    # print(pca_matrix.shape)
    # show_pca(pca_matrix)
    # run utils/utils_sisr.py
    # run utils_sisr.py
|
842 |
+
|
843 |
+
|
844 |
+
|
845 |
+
|
846 |
+
|
847 |
+
|
848 |
+
|
core/data/deg_kair_utils/utils_video.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from os import path as osp
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from abc import ABCMeta, abstractmethod
|
9 |
+
|
10 |
+
|
11 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with '.' are never yielded.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths
        (or paths prefixed with ``dir_path`` when ``full_path`` is True).

    Raises:
        TypeError: If ``suffix`` is neither None, a str nor a tuple.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager closes the directory iterator even when the
        # consumer abandons the generator early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into real directories: hidden files also
                    # reach this branch and would make os.scandir raise
                    # NotADirectoryError.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
51 |
+
|
52 |
+
|
53 |
+
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Load a sequence of images and stack them into one tensor.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is scanned non-recursively and its files are sorted.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname(bool): Whether return image names. Default False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list (file names without extension),
            only when ``return_imgname`` is True.
    """
    if not isinstance(path, list):
        img_paths = sorted(list(scandir(path, full_path=True)))
    else:
        img_paths = path

    frames = []
    for img_path in img_paths:
        # cv2.imread loads BGR data; scale it to [0, 1] as float32
        frames.append(cv2.imread(img_path).astype(np.float32) / 255.)

    if require_mod_crop:
        # NOTE(review): mod_crop is not defined in the visible part of this
        # module - confirm it exists before enabling this option.
        frames = [mod_crop(frame, scale) for frame in frames]
    tensors = img2tensor(frames, bgr2rgb=True, float32=True)
    stacked = torch.stack(tensors, dim=0)

    if not return_imgname:
        return stacked
    imgnames = [osp.splitext(osp.basename(img_path))[0] for img_path in img_paths]
    return stacked, imgnames
|
83 |
+
|
84 |
+
|
85 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb (applied only to
            3-channel images).
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. A list input gives a list of
            tensors; a single ndarray gives a single tensor.
    """

    def _convert(array):
        needs_swap = array.shape[2] == 3 and bgr2rgb
        if needs_swap:
            # float64 input is narrowed to float32 before the colour conversion
            if array.dtype == 'float64':
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(array.transpose(2, 0, 1))
        return chw.float() if float32 else chw

    if not isinstance(imgs, list):
        return _convert(imgs)
    return [_convert(single) for single in imgs]
|
112 |
+
|
113 |
+
|
114 |
+
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
|
115 |
+
"""Convert torch Tensors into image numpy arrays.
|
116 |
+
|
117 |
+
After clamping to [min, max], values will be normalized to [0, 1].
|
118 |
+
|
119 |
+
Args:
|
120 |
+
tensor (Tensor or list[Tensor]): Accept shapes:
|
121 |
+
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
|
122 |
+
2) 3D Tensor of shape (3/1 x H x W);
|
123 |
+
3) 2D Tensor of shape (H x W).
|
124 |
+
Tensor channel should be in RGB order.
|
125 |
+
rgb2bgr (bool): Whether to change rgb to bgr.
|
126 |
+
out_type (numpy type): output types. If ``np.uint8``, transform outputs
|
127 |
+
to uint8 type with range [0, 255]; otherwise, float type with
|
128 |
+
range [0, 1]. Default: ``np.uint8``.
|
129 |
+
min_max (tuple[int]): min and max values for clamp.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
|
133 |
+
shape (H x W). The channel order is BGR.
|
134 |
+
"""
|
135 |
+
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
|
136 |
+
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
|
137 |
+
|
138 |
+
if torch.is_tensor(tensor):
|
139 |
+
tensor = [tensor]
|
140 |
+
result = []
|
141 |
+
for _tensor in tensor:
|
142 |
+
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
143 |
+
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
|
144 |
+
|
145 |
+
n_dim = _tensor.dim()
|
146 |
+
if n_dim == 4:
|
147 |
+
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
|
148 |
+
img_np = img_np.transpose(1, 2, 0)
|
149 |
+
if rgb2bgr:
|
150 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
151 |
+
elif n_dim == 3:
|
152 |
+
img_np = _tensor.numpy()
|
153 |
+
img_np = img_np.transpose(1, 2, 0)
|
154 |
+
if img_np.shape[2] == 1: # gray image
|
155 |
+
img_np = np.squeeze(img_np, axis=2)
|
156 |
+
else:
|
157 |
+
if rgb2bgr:
|
158 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
159 |
+
elif n_dim == 2:
|
160 |
+
img_np = _tensor.numpy()
|
161 |
+
else:
|
162 |
+
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
|
163 |
+
if out_type == np.uint8:
|
164 |
+
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
|
165 |
+
img_np = (img_np * 255.0).round()
|
166 |
+
img_np = img_np.astype(out_type)
|
167 |
+
result.append(img_np)
|
168 |
+
if len(result) == 1:
|
169 |
+
result = result[0]
|
170 |
+
return result
|
171 |
+
|
172 |
+
|
173 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    Rotation is realised as a random vertical flip plus a random transpose.
    One random decision is drawn per call, so every image (and flow) in the
    call receives the same augmentation. Flips are applied in place by cv2.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]: Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Only honoured when ``flows`` is None. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _transform_image(image):
        if hflip:  # mirror left-right, in place
            cv2.flip(image, 1, image)
        if vflip:  # mirror top-bottom, in place
            cv2.flip(image, 0, image)
        return image.transpose(1, 0, 2) if rot90 else image

    def _transform_flow(field):
        if hflip:
            cv2.flip(field, 1, field)
            field[:, :, 0] *= -1  # horizontal component changes sign
        if vflip:
            cv2.flip(field, 0, field)
            field[:, :, 1] *= -1  # vertical component changes sign
        if rot90:
            field = field.transpose(1, 0, 2)
            field = field[:, :, [1, 0]]  # swap the two components
        return field

    def _apply(transform, items):
        # wrap a bare array, transform each item, unwrap single results
        if not isinstance(items, list):
            items = [items]
        items = [transform(item) for item in items]
        return items[0] if len(items) == 1 else items

    imgs = _apply(_transform_image, imgs)

    if flows is not None:
        flows = _apply(_transform_flow, flows)
        return imgs, flows
    if return_status:
        return imgs, (hflip, vflip, rot90)
    return imgs
|
238 |
+
|
239 |
+
|
240 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not ``scale`` times the LQ size, or the
            LQ image is smaller than the LQ patch.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # Implicit string concatenation: a comma here would hand ValueError
        # two arguments and render the message as a tuple.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
|
306 |
+
|
307 |
+
|
308 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
309 |
+
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class for storage backends.

    Concrete backends must provide two methods: ``get()``, which reads a
    file as a byte stream, and ``get_text()``, which reads a file as text.
    """

    @abstractmethod
    def get(self, filepath):
        """Return the content stored under ``filepath`` as bytes."""

    @abstractmethod
    def get_text(self, filepath):
        """Return the content stored under ``filepath`` as text."""
|
324 |
+
|
325 |
+
|
326 |
+
class MemcachedBackend(BaseStorageBackend):
    """Storage backend that reads values through a memcached client.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        # Extend the module search path first so that `mc` can be found there.
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # reusable mc.pyvector buffer that the client fills on every Get call
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch the value stored under ``filepath`` and return the converted buffer."""
        key = str(filepath)
        import mc
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
|
360 |
+
|
361 |
+
|
362 |
+
class HardDiskBackend(BaseStorageBackend):
    """Storage backend that reads files straight from the local file system."""

    def get(self, filepath):
        """Return the whole file as bytes."""
        with open(str(filepath), 'rb') as handle:
            return handle.read()

    def get_text(self, filepath):
        """Return the whole file as text, opened in mode 'r'."""
        with open(str(filepath), 'r') as handle:
            return handle.read()
|
376 |
+
|
377 |
+
|
378 |
+
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | Path | list | tuple): Lmdb database path(s). Every
            path is converted with ``str()``.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.
        **kwargs: Forwarded unchanged to ``lmdb.open``.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (dict): Maps each client key to its lmdb env.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, (list, tuple)):
            self.db_paths = [str(v) for v in db_paths]
        else:
            # A single path: str, pathlib.Path, ... Previously anything other
            # than list/str left self.db_paths unset and failed later with an
            # unrelated AttributeError.
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.

        Returns:
            The value returned by the read transaction for the ASCII-encoded key.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
|
434 |
+
|
435 |
+
|
436 |
+
class FileClient(object):
    """Front end that forwards file reads to a selected storage backend.

    The backend is chosen by name at construction time; every read is then
    delegated to the backend object created for it.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Registry of backend name -> backend class.
    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        """Create the backend registered under ``backend``.

        Args:
            backend (str): Key into the backend registry. Default: 'disk'.
            **kwargs: Passed unchanged to the backend constructor.

        Raises:
            ValueError: If ``backend`` is not a registered name.
        """
        if backend in self._backends:
            self.backend = backend
            self.client = self._backends[backend](**kwargs)
            return
        raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                         f' are {list(self._backends.keys())}')

    def get(self, filepath, client_key='default'):
        """Read ``filepath`` through the backend.

        ``client_key`` is only forwarded for the lmdb backend, where it
        selects which lmdb environment to read from.
        """
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Delegate a text read of ``filepath`` to the backend."""
        return self.client.get_text(filepath)
|
472 |
+
|
473 |
+
|
474 |
+
def imfrombytes(content, flag='color', float32=False):
    """Decode an encoded image held in a bytes object.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Colour mode of the decoded image; one of `color`,
            `grayscale` and `unchanged`.
        float32 (bool): If True, convert the result to float32 and divide
            it by 255. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    raw = np.frombuffer(content, np.uint8)
    decode_modes = {
        'color': cv2.IMREAD_COLOR,
        'grayscale': cv2.IMREAD_GRAYSCALE,
        'unchanged': cv2.IMREAD_UNCHANGED,
    }
    decoded = cv2.imdecode(raw, decode_modes[flag])
    if not float32:
        return decoded
    return decoded.astype(np.float32) / 255.
|
493 |
+
|
core/data/deg_kair_utils/utils_videoio.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from os import path as osp
|
7 |
+
from torchvision.utils import make_grid
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
import six
|
11 |
+
from collections import OrderedDict
|
12 |
+
import math
|
13 |
+
import glob
|
14 |
+
import av
|
15 |
+
import io
|
16 |
+
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
|
17 |
+
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
|
18 |
+
CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
|
19 |
+
|
20 |
+
# Interpreters at or below version (3, 3) fall back to IOError; everywhere
# else the built-in FileNotFoundError is re-bound as a module-level name.
FileNotFoundError = IOError if sys.version_info <= (3, 3) else FileNotFoundError
|
24 |
+
|
25 |
+
|
26 |
+
def is_str(x):
    """Return whether the input is a string instance.

    Args:
        x: Any object.

    Returns:
        bool: True if ``x`` is a ``str``.
    """
    # This file already requires Python 3 (f-strings, ``yield from``), where
    # ``six.string_types`` is just ``(str,)``; checking ``str`` directly
    # avoids needing the third-party ``six`` package at call time.
    return isinstance(x, str)
|
29 |
+
|
30 |
+
|
31 |
+
def is_filepath(x):
    """Return True when ``x`` is a string or a ``pathlib.Path``."""
    if is_str(x):
        return True
    return isinstance(x, Path)
|
33 |
+
|
34 |
+
|
35 |
+
def fopen(filepath, *args, **kwargs):
    """Open ``filepath`` whether it is given as a string or as a ``Path``.

    Args:
        filepath (str | Path): File to open.
        *args: Positional arguments forwarded to the open call.
        **kwargs: Keyword arguments forwarded to the open call.

    Returns:
        The opened file object.

    Raises:
        ValueError: If ``filepath`` is neither a string nor a ``Path``.
    """
    if is_str(filepath):
        handle = open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        handle = filepath.open(*args, **kwargs)
    else:
        raise ValueError('`filepath` should be a string or a Path')
    return handle
|
41 |
+
|
42 |
+
|
43 |
+
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise if ``filename`` is not an existing regular file.

    Args:
        filename (str): Path to check.
        msg_tmpl (str): ``str.format`` template for the error message; it
            receives ``filename`` as its only argument.

    Raises:
        FileNotFoundError: If ``filename`` is not a file.
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
|
46 |
+
|
47 |
+
|
48 |
+
def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` (and missing parents) unless it already exists.

    An empty string is ignored. A leading ``~`` is expanded first.

    Args:
        dir_name (str): Directory to create.
        mode (int): Permission bits passed to ``os.makedirs``.
    """
    if dir_name != '':
        os.makedirs(osp.expanduser(dir_name), mode=mode, exist_ok=True)
|
53 |
+
|
54 |
+
|
55 |
+
def symlink(src, dst, overwrite=True, **kwargs):
    """Create a symbolic link ``dst`` pointing at ``src``.

    Args:
        src (str): Link target.
        dst (str): Path of the link to create.
        overwrite (bool): If True, an existing entry at ``dst`` (including a
            dangling link) is removed first. Default: True.
        **kwargs: Forwarded to ``os.symlink``.
    """
    already_there = os.path.lexists(dst)
    if already_there and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)
|
59 |
+
|
60 |
+
|
61 |
+
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Lazily list the files under a directory.

    Files whose name starts with a dot are skipped.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): Only yield files whose relative
            path ends with this suffix (or any of them). Default: None.
        recursive (bool, optional): If True, descend into sub-directories.
            Default: False.
        case_sensitive (bool, optional): If False, suffixes are matched
            ignoring case. Default: True.

    Returns:
        A generator yielding matching file paths relative to ``dir_path``.

    Raises:
        TypeError: If ``dir_path`` or ``suffix`` has an unsupported type.
    """
    # Argument checks run eagerly, before the generator is created.
    if not isinstance(dir_path, (str, Path)):
        raise TypeError('"dir_path" must be a string or Path object')
    dir_path = str(dir_path)

    if suffix is not None:
        if not isinstance(suffix, (str, tuple)):
            raise TypeError('"suffix" must be a string or tuple of strings')
        if not case_sensitive:
            if isinstance(suffix, str):
                suffix = suffix.lower()
            else:
                suffix = tuple(item.lower() for item in suffix)

    root = dir_path

    def _walk(current):
        for entry in os.scandir(current):
            visible_file = not entry.name.startswith('.') and entry.is_file()
            if visible_file:
                rel_path = osp.relpath(entry.path, root)
                # Lower-case only the copy used for matching, not the result.
                candidate = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or candidate.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                yield from _walk(entry.path)

    return _walk(dir_path)
|
101 |
+
|
102 |
+
|
103 |
+
class Cache:
    """Fixed-capacity key/value store that evicts the oldest inserted entry.

    Args:
        capacity (int): Maximum number of entries; must be positive after
            conversion to ``int``.

    Raises:
        ValueError: If the integer capacity is not positive.
    """

    def __init__(self, capacity):
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        # Validate the converted value: a fractional input such as 0.5 used
        # to pass the check and leave a capacity of 0, which made the first
        # ``put`` fail while evicting from an empty dict.
        if self._capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        """int: Maximum number of entries held at once."""
        return self._capacity

    @property
    def size(self):
        """int: Number of entries currently stored."""
        return len(self._cache)

    def put(self, key, val):
        """Store ``val`` under ``key``; an existing key is left untouched.

        When the cache is full the entry inserted earliest is dropped first.
        """
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self._cache.get(key, default)
|
129 |
+
|
130 |
+
|
131 |
+
class VideoReader:
    """Video class with similar usage to a list object.

    This video wrapper class provides convenient APIs to access frames.
    OpenCV's VideoCapture may land on a different frame than requested when
    seeking, so the position is re-checked after every jump and corrected by
    reading forward. Decoded frames are put into a cache, so visiting a
    cached frame a second time does not decode it again.

    Args:
        filename (str | Path): Path of a local video file, or an
            ``http://`` / ``https://`` URL.
        cache_capacity (int): Maximum number of cached frames. Default: 10.

    Raises:
        FileNotFoundError: If ``filename`` is not a URL and is not a file.
    """

    def __init__(self, filename, cache_capacity=10):
        # ``startswith`` and ``cv2.VideoCapture`` need a plain string, so
        # Path objects are converted up front.
        filename = str(filename)
        # Check whether the video path is a url
        if not filename.startswith(('https://', 'http://')):
            # Use a placeholder instead of concatenating the name into the
            # template, so names containing braces do not break ``format``.
            check_file_exist(filename, 'Video file not found: {}')
        self._vcap = cv2.VideoCapture(filename)
        assert cache_capacity > 0
        self._cache = Cache(cache_capacity)
        self._position = 0
        # get basic info
        self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
        self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
        self._fps = self._vcap.get(CAP_PROP_FPS)
        self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
        self._fourcc = self._vcap.get(CAP_PROP_FOURCC)

    @property
    def vcap(self):
        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
        return self._vcap

    @property
    def opened(self):
        """bool: Indicate whether the video is opened."""
        return self._vcap.isOpened()

    @property
    def width(self):
        """int: Width of video frames."""
        return self._width

    @property
    def height(self):
        """int: Height of video frames."""
        return self._height

    @property
    def resolution(self):
        """tuple: Video resolution (width, height)."""
        return (self._width, self._height)

    @property
    def fps(self):
        """float: FPS of the video."""
        return self._fps

    @property
    def frame_cnt(self):
        """int: Total frames of the video."""
        return self._frame_cnt

    @property
    def fourcc(self):
        """The CAP_PROP_FOURCC value exactly as reported by the capture."""
        return self._fourcc

    @property
    def position(self):
        """int: Current cursor position, indicating frame decoded."""
        return self._position

    def _get_real_position(self):
        """Return the capture's own frame position, rounded to an int."""
        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))

    def _set_real_position(self, frame_id):
        """Seek the capture to ``frame_id`` and sync the cursor to it."""
        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
        pos = self._get_real_position()
        # If the seek stopped short, read forward to reach the target.
        for _ in range(frame_id - pos):
            self._vcap.read()
        self._position = frame_id

    def read(self):
        """Read the next frame.

        If the next frame has been decoded before and is in the cache, it is
        returned directly; otherwise it is decoded, cached and returned.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        if self._cache:
            img = self._cache.get(self._position)
            if img is not None:
                ret = True
            else:
                if self._position != self._get_real_position():
                    self._set_real_position(self._position)
                ret, img = self._vcap.read()
                if ret:
                    self._cache.put(self._position, img)
        else:
            ret, img = self._vcap.read()
        if ret:
            self._position += 1
        return img

    def get_frame(self, frame_id):
        """Get frame by index.

        Args:
            frame_id (int): Index of the expected frame, 0-based.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.

        Raises:
            IndexError: If ``frame_id`` is outside ``[0, frame_cnt)``.
        """
        if frame_id < 0 or frame_id >= self._frame_cnt:
            raise IndexError(
                f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
        if frame_id == self._position:
            return self.read()
        if self._cache:
            img = self._cache.get(frame_id)
            if img is not None:
                self._position = frame_id + 1
                return img
        self._set_real_position(frame_id)
        ret, img = self._vcap.read()
        if ret:
            if self._cache:
                self._cache.put(self._position, img)
            self._position += 1
        return img

    def current_frame(self):
        """Get the current frame (frame that is just visited).

        Returns:
            ndarray or None: If the video is fresh, return None, otherwise
            return the cached frame just before the cursor (None if it is
            no longer in the cache).
        """
        if self._position == 0:
            return None
        return self._cache.get(self._position - 1)

    def cvt2frames(self,
                   frame_dir,
                   file_start=0,
                   filename_tmpl='{:06d}.jpg',
                   start=0,
                   max_num=0,
                   show_progress=False):
        """Convert a video to frame images.

        Args:
            frame_dir (str): Output directory to store all the frame images.
            file_start (int): Filenames will start from the specified number.
            filename_tmpl (str): Filename template with the index as the
                placeholder.
            start (int): The starting frame index.
            max_num (int): Maximum number of frames to be written; 0 means
                all frames from ``start`` on.
            show_progress (bool): Accepted for compatibility. No progress
                bar is drawn, but frames are written either way.

        Raises:
            ValueError: If ``start`` is not below the total frame count.
        """
        mkdir_or_exist(frame_dir)
        if max_num == 0:
            task_num = self.frame_cnt - start
        else:
            task_num = min(self.frame_cnt - start, max_num)
        if task_num <= 0:
            raise ValueError('start must be less than total frame number')
        if start > 0:
            self._set_real_position(start)

        # Frames used to be written only when ``show_progress`` was False;
        # requesting progress silently produced no output at all.
        for offset in range(task_num):
            img = self.read()
            if img is None:
                continue
            filename = osp.join(frame_dir,
                                filename_tmpl.format(file_start + offset))
            cv2.imwrite(filename, img)

    def __len__(self):
        return self.frame_cnt

    def __getitem__(self, index):
        if isinstance(index, slice):
            return [
                self.get_frame(i)
                for i in range(*index.indices(self.frame_cnt))
            ]
        # support negative indexing
        if index < 0:
            index += self.frame_cnt
            if index < 0:
                raise IndexError('index out of range')
        return self.get_frame(index)

    def __iter__(self):
        # Iteration always restarts from the first frame.
        self._set_real_position(0)
        return self

    def __next__(self):
        img = self.read()
        if img is not None:
            return img
        else:
            raise StopIteration

    next = __next__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release the underlying capture when leaving a ``with`` block.
        self._vcap.release()
|
354 |
+
|
355 |
+
|
356 |
+
def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=False):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video, this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index (exclusive). 0 means the number of
            files in ``frame_dir`` matching the template's extension.
        show_progress (bool): Accepted for compatibility. No progress bar
            is drawn, but frames are written either way.

    Raises:
        FileNotFoundError: If the frame for index ``start`` does not exist.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = sum(1 for _ in scandir(frame_dir, ext))
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    # Placeholder template: safe even if the path itself contains braces.
    check_file_exist(first_file, 'The start frame not found: {}')
    # The first frame fixes the output resolution.
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)
    try:
        # Frames used to be written only when ``show_progress`` was False;
        # requesting progress silently produced an empty video.
        for file_idx in range(start, end):
            filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
            vwriter.write(cv2.imread(filename))
    finally:
        # Finalise the output file even if a frame fails to load or encode.
        vwriter.release()
|
400 |
+
|
401 |
+
|
402 |
+
def video2images(video_path, output_dir):
    """Dump the frames of a video as zero-padded, numbered PNG files.

    Files are named ``000000.png``, ``000001.png``, ... inside
    ``output_dir``, which is created if needed. The FPS, the reported frame
    count and a progress line every 100 frames are printed.

    Args:
        video_path (str): Video to read.
        output_dir (str): Directory receiving the PNG frames.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        in_fps = vidcap.get(cv2.CAP_PROP_FPS)
        print('video fps:', in_fps)
        os.makedirs(output_dir, exist_ok=True)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f'number of total frames is: {total_frames:06}')
        for i_frame in range(total_frames):
            loaded, frame = vidcap.read()
            if not loaded:
                # The reported frame count can exceed what actually decodes;
                # stop instead of handing an empty frame to ``imwrite``.
                break
            if i_frame % 100 == 0:
                print(f'{i_frame:06} / {total_frames:06}')
            frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
            cv2.imwrite(frame_name, frame)
    finally:
        # The capture was previously never released.
        vidcap.release()
|
417 |
+
|
418 |
+
|
419 |
+
def images2video(image_dir, video_path, fps=24, image_ext='png'):
    """Join the images of a directory, in sorted filename order, into a video.

    The output uses the MJPG fourcc and the resolution of the first image;
    every image is resized to that resolution before being written.

    Notes on some commonly used fourcc codes (alternatives to MJPG):
        - ``I420``: YUV colour encoding with 4:2:0 chroma subsampling;
          widely compatible but produces large files (.avi).
        - ``PIM1``: MPEG-1 encoding (.avi).
        - ``XVID``: MPEG-4 encoding, moderate file size (.avi).
        - ``THEO``: Ogg Vorbis (.ogv).
        - ``FLV1``: Flash video (.flv).
        - Others: ``AVC1``, ``YUV1``, ``MP42``, ``DIV3``, ``DIVX``, ``U263``,
          ``I263``, ``H264``, ``AYUV``, ``IUYV``, ``MP4V``.

    Args:
        image_dir (str): Directory containing the input images.
        video_path (str): Output video filename.
        fps (float): FPS of the output video. Default: 24.
        image_ext (str): Extension (without dot) of the images to collect.
            Default: 'png'.

    Raises:
        IndexError: If no ``*.image_ext`` file exists in ``image_dir``.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
    print(len(image_files))
    if not image_files:
        # Same exception type as the bare ``image_files[0]`` lookup raised
        # before, but with an explanation.
        raise IndexError(f'no *.{image_ext} images found in {image_dir}')
    height, width, _ = cv2.imread(image_files[0]).shape
    out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
    try:
        for image_file in image_files:
            img = cv2.imread(image_file)
            # INTER_AREA is the named constant for the former literal 3.
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
            out_video.write(img)
    finally:
        # Finalise the output file even if an image fails to load or encode.
        out_video.release()
|
458 |
+
|
459 |
+
|
460 |
+
def add_video_compression(imgs):
    """Round-trip frames through an in-memory MP4 encode/decode.

    A codec is drawn uniformly from libx264, h264 and mpeg4, and a bit rate
    is drawn from [1e4, 1e5]. The frames are encoded at rate 1 into an
    in-memory mp4 container and then decoded again.

    Args:
        imgs (list[ndarray]): Frames to compress. They are clipped to
            [0, 1], scaled to uint8 and fed to PyAV as 'rgb24', so they are
            presumably float RGB arrays of equal size -- TODO confirm
            against callers.

    Returns:
        list[ndarray]: Decoded frames as float32 RGB arrays divided by 255.
    """
    # Draw order (codec first, then bit rate) matters for seeded runs.
    codec_names = ['libx264', 'h264', 'mpeg4']
    weights = [1 / 3., 1 / 3., 1 / 3.]
    codec = random.choices(codec_names, weights)[0]
    low, high = 1e4, 1e5
    bitrate = np.random.randint(low, high + 1)

    buf = io.BytesIO()
    with av.open(buf, 'w', 'mp4') as container:
        stream = container.add_stream(codec, rate=1)
        stream.height = imgs[0].shape[0]
        stream.width = imgs[0].shape[1]
        stream.pix_fmt = 'yuv420p'
        stream.bit_rate = bitrate

        for source in imgs:
            as_uint8 = np.uint8((source.clip(0, 1)*255.).round())
            frame = av.VideoFrame.from_ndarray(as_uint8, format='rgb24')
            frame.pict_type = 'NONE'
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)

    decoded = []
    with av.open(buf, 'r', 'mp4') as container:
        if container.streams.video:
            decoded.extend(
                frame.to_rgb().to_ndarray().astype(np.float32) / 255.
                for frame in container.decode(video=0))
    return decoded
|
497 |
+
|
498 |
+
|
499 |
+
if __name__ == '__main__':
    # The usage examples below are all commented out. An ``if`` body made
    # only of comments is a syntax error, so an explicit no-op keeps the
    # module importable.
    pass

    # -----------------------------------
    # test VideoReader(filename, cache_capacity=10)
    # -----------------------------------
    # video_reader = VideoReader('utils/test.mp4')
    # from utils import utils_image as util
    # inputs = []
    # for frame in video_reader:
    #     print(frame.dtype)
    #     util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    #     #util.imshow(np.flip(frame, axis=2))

    # -----------------------------------
    # test video2images(video_path, output_dir)
    # -----------------------------------
    # video2images('utils/test.mp4', 'frames')

    # -----------------------------------
    # test images2video(image_dir, video_path, fps=24, image_ext='png')
    # -----------------------------------
    # images2video('frames', 'video_02.mp4', fps=30, image_ext='png')

    # -----------------------------------
    # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
    # -----------------------------------
    # frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')

    # -----------------------------------
    # test add_video_compression(imgs)
    # -----------------------------------
    # imgs = []
    # image_ext = 'png'
    # frames = 'frames'
    # from utils import utils_image as util
    # image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
    # for i, image_file in enumerate(image_files):
    #     if i < 7:
    #         img = util.imread_uint(image_file, 3)
    #         img = util.uint2single(img)
    #         imgs.append(img)
    #
    # results = add_video_compression(imgs)
    # for i, img in enumerate(results):
    #     util.imshow(util.single2uint(img))
    #     util.imsave(util.single2uint(img),f'{i:05}.png')

    # run utils/utils_video.py
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
|
555 |
+
|
core/scripts/__init__.py
ADDED
File without changes
|
core/scripts/cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
from .. import WarpCore
|
4 |
+
from .. import templates
|
5 |
+
|
6 |
+
|
7 |
+
def template_init(args):
    """Return the template text; currently a single apostrophe.

    ``args`` is not used.
    """
    # NOTE(review): the original literal opened with four quote characters,
    # so its stripped value is one apostrophe. This looks like a stray
    # quote around an intended empty template -- confirm before changing.
    return "'"
|
12 |
+
|
13 |
+
|
14 |
+
def init_template(args):
    """Resolve the template named on the command line and print it.

    The name ``WarpCore`` (the default) maps to the ``WarpCore`` class. Any
    other name is first imported as a module; if no such module exists it
    is looked up as an attribute of the ``templates`` package.

    Args:
        args (list[str]): Command-line arguments; ``-t/--template`` selects
            the template name.
    """
    parser = argparse.ArgumentParser(description='WarpCore template init tool')
    parser.add_argument('-t', '--template', type=str, default='WarpCore')
    parsed = parser.parse_args(args)

    name = parsed.template
    if name != 'WarpCore':
        try:
            # NOTE: this yields the imported module object, not a class.
            template_cls = __import__(name)
        except ModuleNotFoundError:
            template_cls = getattr(templates, name)
    else:
        template_cls = WarpCore
    print(template_cls)
|
27 |
+
|
28 |
+
|
29 |
+
def main():
    """Dispatch the ``core`` command line.

    Only the ``init`` command is known; a missing or unknown command prints
    a message and exits with status 1.
    """
    argv = sys.argv
    if len(argv) < 2:
        print('Usage: core <command>')
        sys.exit(1)
    command = argv[1]
    if command != 'init':
        print('Unknown command')
        sys.exit(1)
    init_template(argv[2:])


if __name__ == '__main__':
    main()
|
core/templates/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion import DiffusionCore
|
core/templates/diffusion.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .. import WarpCore
|
2 |
+
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
|
3 |
+
from abc import abstractmethod
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from gdf import GDF
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
import wandb
|
12 |
+
|
13 |
+
import webdataset as wds
|
14 |
+
from webdataset.handlers import warn_and_continue
|
15 |
+
from torch.distributed import barrier
|
16 |
+
from enum import Enum
|
17 |
+
|
18 |
+
class TargetReparametrization(Enum):
    """Names the quantity a prediction is converted to before the loss."""

    # Compare against the noise.
    EPSILON = 'epsilon'
    # Compare against the clean latents.
    X0 = 'x0'
|
21 |
+
|
22 |
+
class DiffusionCore(WarpCore):
|
23 |
+
@dataclass(frozen=True)
|
24 |
+
class Config(WarpCore.Config):
|
25 |
+
# TRAINING PARAMS
|
26 |
+
lr: float = EXPECTED_TRAIN
|
27 |
+
grad_accum_steps: int = EXPECTED_TRAIN
|
28 |
+
batch_size: int = EXPECTED_TRAIN
|
29 |
+
updates: int = EXPECTED_TRAIN
|
30 |
+
warmup_updates: int = EXPECTED_TRAIN
|
31 |
+
save_every: int = 500
|
32 |
+
backup_every: int = 20000
|
33 |
+
use_fsdp: bool = True
|
34 |
+
|
35 |
+
# EMA UPDATE
|
36 |
+
ema_start_iters: int = None
|
37 |
+
ema_iters: int = None
|
38 |
+
ema_beta: float = None
|
39 |
+
|
40 |
+
# GDF setting
|
41 |
+
gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
|
42 |
+
|
43 |
+
@dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
|
44 |
+
class Info(WarpCore.Info):
|
45 |
+
ema_loss: float = None
|
46 |
+
|
47 |
+
@dataclass(frozen=True)
|
48 |
+
class Models(WarpCore.Models):
|
49 |
+
generator : nn.Module = EXPECTED
|
50 |
+
generator_ema : nn.Module = None # optional
|
51 |
+
|
52 |
+
@dataclass(frozen=True)
|
53 |
+
class Optimizers(WarpCore.Optimizers):
|
54 |
+
generator : any = EXPECTED
|
55 |
+
|
56 |
+
@dataclass(frozen=True)
|
57 |
+
class Schedulers(WarpCore.Schedulers):
|
58 |
+
generator: any = None
|
59 |
+
|
60 |
+
@dataclass(frozen=True)
|
61 |
+
class Extras(WarpCore.Extras):
|
62 |
+
gdf: GDF = EXPECTED
|
63 |
+
sampling_configs: dict = EXPECTED
|
64 |
+
|
65 |
+
# --------------------------------------------
|
66 |
+
info: Info
|
67 |
+
config: Config
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
71 |
+
raise NotImplementedError("This method needs to be overriden")
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
75 |
+
raise NotImplementedError("This method needs to be overriden")
|
76 |
+
|
77 |
+
@abstractmethod
|
78 |
+
def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
|
79 |
+
raise NotImplementedError("This method needs to be overriden")
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def webdataset_path(self, extras: Extras):
|
83 |
+
raise NotImplementedError("This method needs to be overriden")
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def webdataset_filters(self, extras: Extras):
|
87 |
+
raise NotImplementedError("This method needs to be overriden")
|
88 |
+
|
89 |
+
@abstractmethod
|
90 |
+
def webdataset_preprocessors(self, extras: Extras):
|
91 |
+
raise NotImplementedError("This method needs to be overriden")
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
|
95 |
+
raise NotImplementedError("This method needs to be overriden")
|
96 |
+
# -------------
|
97 |
+
|
98 |
+
def setup_data(self, extras: Extras) -> WarpCore.Data:
|
99 |
+
# SETUP DATASET
|
100 |
+
dataset_path = self.webdataset_path(extras)
|
101 |
+
preprocessors = self.webdataset_preprocessors(extras)
|
102 |
+
filters = self.webdataset_filters(extras)
|
103 |
+
|
104 |
+
handler = warn_and_continue # None
|
105 |
+
# handler = None
|
106 |
+
dataset = wds.WebDataset(
|
107 |
+
dataset_path, resampled=True, handler=handler
|
108 |
+
).select(filters).shuffle(690, handler=handler).decode(
|
109 |
+
"pilrgb", handler=handler
|
110 |
+
).to_tuple(
|
111 |
+
*[p[0] for p in preprocessors], handler=handler
|
112 |
+
).map_tuple(
|
113 |
+
*[p[1] for p in preprocessors], handler=handler
|
114 |
+
).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
|
115 |
+
|
116 |
+
# SETUP DATALOADER
|
117 |
+
real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
|
118 |
+
dataloader = DataLoader(
|
119 |
+
dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
|
120 |
+
)
|
121 |
+
|
122 |
+
return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
|
123 |
+
|
124 |
+
def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
    """Draw one batch, diffuse its latents and compute the training loss.

    Returns:
        ``(loss, loss_adjusted)``: ``loss`` is the MSE averaged over dims
        1-3 (one value per sample); ``loss_adjusted`` is the weighted mean of
        ``loss`` divided by ``config.grad_accum_steps`` and is what ``train``
        backpropagates.
    """
    batch = next(data.iterator)

    # Conditioning, latent encoding and noising run without gradient tracking.
    with torch.no_grad():
        conditions = self.get_conditions(batch, models, extras)
        latents = self.encode_latents(batch, models, extras)
        noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

    # FORWARD PASS
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        pred = models.generator(noised, noise_cond, **conditions)
        # Optionally re-express prediction and target in epsilon or x0 space;
        # for any other setting the target returned by diffuse() is used as is.
        if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
            pred = extras.gdf.undiffuse(noised, logSNR, pred)[1]  # transform whatever prediction to epsilon to use in the loss
            target = noise
        elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
            pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]  # transform whatever prediction to x0 to use in the loss
            target = latents
        loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
        loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

    return loss, loss_adjusted
|
145 |
+
|
146 |
+
def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
    """Run the training loop from ``info.iter + 1`` up to the configured budget.

    One loop iteration is one forward/backward pass. Optimizers and
    schedulers step every ``config.grad_accum_steps`` iterations (and on the
    last iteration); the other iterations only accumulate gradients inside
    ``models.generator.no_sync()``. The loop also updates the EMA model,
    tracks a smoothed loss, logs on the main node, and periodically saves
    checkpoints and samples.

    Args:
        data: provides the batch iterator consumed by ``forward_pass``.
        extras: passed through to ``forward_pass`` and ``sample``.
        models: must expose ``generator`` and ``generator_ema``.
        optimizers: every optimizer in ``to_dict()`` is stepped; ``generator``
            is read for the logged learning rate.
        schedulers: every scheduler in ``to_dict()`` is stepped.
    """
    start_iter = self.info.iter+1
    max_iters = self.config.updates * self.config.grad_accum_steps
    if self.is_main_node:
        print(f"STARTING AT STEP: {start_iter}/{max_iters}")

    pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1)  # <--- DDP
    models.generator.train()
    # No gradient norm exists until the first optimizer step. With
    # grad_accum_steps > 1 the first iterations only accumulate, so the
    # norm must not be read unconditionally below.
    grad_norm = None
    for i in pbar:
        # FORWARD PASS
        loss, loss_adjusted = self.forward_pass(data, extras, models)

        # BACKWARD PASS
        if i % self.config.grad_accum_steps == 0 or i == max_iters:
            loss_adjusted.backward()
            grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
            optimizers_dict = optimizers.to_dict()
            for k in optimizers_dict:
                optimizers_dict[k].step()
            schedulers_dict = schedulers.to_dict()
            for k in schedulers_dict:
                schedulers_dict[k].step()
            models.generator.zero_grad(set_to_none=True)
            self.info.total_steps += 1
        else:
            # Accumulation step: gradients are accumulated, nothing is stepped.
            with models.generator.no_sync():
                loss_adjusted.backward()
        self.info.iter = i

        # UPDATE EMA
        if models.generator_ema is not None and i % self.config.ema_iters == 0:
            update_weights_ema(
                models.generator_ema, models.generator,
                # beta=0 copies the live weights until ema_start_iters is passed.
                beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
            )

        # UPDATE LOSS METRICS
        loss_value = loss.mean().item()
        grad_norm_value = None if grad_norm is None else grad_norm.item()
        self.info.ema_loss = loss_value if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss_value * 0.01

        loss_is_nan = np.isnan(loss_value)
        grad_is_nan = grad_norm_value is not None and np.isnan(grad_norm_value)
        # Parenthesised so that only the main node of a wandb-enabled run
        # alerts; the bare `a and b and c or d` form alerted on every rank.
        if self.is_main_node and self.config.wandb_project is not None and (loss_is_nan or grad_is_nan):
            wandb.alert(
                title=f"NaN value encountered in training run {self.info.wandb_run_id}",
                text=f"Loss {loss_value} - Grad Norm {grad_norm_value}. Run {self.info.wandb_run_id}",
                wait_duration=60*30
            )

        if self.is_main_node:
            logs = {
                'loss': self.info.ema_loss,
                'raw_loss': loss_value,
                'grad_norm': grad_norm_value,
                'lr': optimizers.generator.param_groups[0]['lr'],
                'total_steps': self.info.total_steps,
            }
            if grad_norm_value is None:
                # Nothing to report before the first optimizer step.
                del logs['grad_norm']

            pbar.set_postfix(logs)
            if self.config.wandb_project is not None:
                wandb.log(logs)

        if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
            # SAVE AND CHECKPOINT STUFF
            if loss_is_nan:
                if self.is_main_node and self.config.wandb_project is not None:
                    tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                    wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
            else:
                self.save_checkpoints(models, optimizers)
                if self.is_main_node:
                    create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                self.sample(models, data, extras)
|
216 |
+
|
217 |
+
def models_to_save(self):
    """Return the model names that ``save_checkpoints`` looks up and saves."""
    saved_names = ['generator', 'generator_ema']
    return saved_names
|
219 |
+
|
220 |
+
def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
    """Save the run info, the selected models and all optimizers.

    Every file name carries ``suffix`` (empty when ``None``). When called
    without a suffix and ``info.total_steps`` is a multiple of
    ``config.backup_every`` (and greater than 1), a second copy is written
    with a ``_<total_steps//1000>k`` suffix.
    """
    barrier()
    if suffix is None:
        suffix = ''
    self.save_info(self.info, suffix=suffix)

    all_models = models.to_dict()
    all_optimizers = optimizers.to_dict()

    for name in self.models_to_save():
        net = all_models[name]
        if net is None:
            continue
        self.save_model(net, f"{name}{suffix}", is_fsdp=self.config.use_fsdp)

    for name in all_optimizers:
        optim = all_optimizers[name]
        if optim is None:
            continue
        fsdp_model = models.generator if self.config.use_fsdp else None
        self.save_optimizer(optim, f'{name}_optim{suffix}', fsdp_model=fsdp_model)

    # Only the un-suffixed call may trigger a backup, so the recursion is one level deep.
    if suffix == '':
        if self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
    torch.cuda.empty_cache()
|
core/utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
|
2 |
+
from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
|
3 |
+
|
4 |
+
# MOVE IT SOMEWHERE ELSE
|
5 |
+
def update_weights_ema(tgt_model, src_model, beta=0.999):
    """Blend ``src_model`` into ``tgt_model`` as an exponential moving average.

    Every parameter and then every buffer of ``tgt_model`` is replaced by
    ``target * beta + source * (1 - beta)``; the source tensor is moved to
    the target's device first. ``beta=0`` therefore copies the source values.
    Tensors are paired positionally via ``zip``.
    """
    tensor_groups = (
        (tgt_model.parameters(), src_model.parameters()),
        (tgt_model.buffers(), src_model.buffers()),
    )
    for tgt_tensors, src_tensors in tensor_groups:
        for target, source in zip(tgt_tensors, src_tensors):
            incoming = source.data.clone().to(target.device)
            target.data = target.data * beta + incoming * (1-beta)
|
core/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (763 Bytes). View file
|
|
core/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (804 Bytes). View file
|
|
core/utils/__pycache__/base_dto.cpython-310.pyc
ADDED
Binary file (3.09 kB). View file
|
|
core/utils/__pycache__/base_dto.cpython-39.pyc
ADDED
Binary file (3.11 kB). View file
|
|
core/utils/__pycache__/save_and_load.cpython-310.pyc
ADDED
Binary file (2.19 kB). View file
|
|
core/utils/__pycache__/save_and_load.cpython-39.pyc
ADDED
Binary file (2.2 kB). View file
|
|
core/utils/base_dto.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from dataclasses import dataclass, _MISSING_TYPE
|
3 |
+
from munch import Munch
|
4 |
+
|
5 |
+
EXPECTED = "___REQUIRED___"
|
6 |
+
EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
|
7 |
+
|
8 |
+
# pylint: disable=invalid-field-call
|
9 |
+
def nested_dto(x, raw=False):
    """Return a dataclass field whose default factory yields ``x``.

    With ``raw=True`` the factory returns ``x`` itself; otherwise it returns
    ``Munch.fromDict(x)``. The conversion runs each time the factory is called.
    """
    def make_default():
        if raw:
            return x
        return Munch.fromDict(x)

    return dataclasses.field(default_factory=make_default)
|
11 |
+
|
12 |
+
@dataclass(frozen=True)
class Base:
    """Frozen dataclass base that validates keyword arguments on construction.

    A field can be set through the constructor only when its declared default
    is ``None``, missing, ``EXPECTED`` or ``EXPECTED_TRAIN``. Fields declared
    with ``EXPECTED`` (or ``EXPECTED_TRAIN`` unless ``training=False``) must
    be supplied.
    """
    training: bool = None

    def __new__(cls, **kwargs):
        """Validate ``kwargs`` against the declared fields, then allocate.

        Raises:
            AssertionError: if a kwarg targets a non-setteable field or still
                carries a sentinel value, or if a mandatory field is missing.
        """
        training = kwargs.get('training', True)
        setteable_fields = cls.setteable_fields(**kwargs)
        mandatory_fields = cls.mandatory_fields(**kwargs)
        # A kwarg is invalid if its field is not setteable or its value is
        # still one of the "must be provided" sentinels.
        invalid_kwargs = [
            {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
        ]
        assert (
            len(invalid_kwargs) == 0
        ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
        missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
        assert (
            len(missing_kwargs) == 0
        ), f"Required fields missing initializing this DTO: {missing_kwargs}."
        return object.__new__(cls)

    @classmethod
    def setteable_fields(cls, **kwargs):
        """Return names of fields whose default is None, missing, EXPECTED or EXPECTED_TRAIN."""
        return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]

    @classmethod
    def mandatory_fields(cls, **kwargs):
        """Return names of fields that must be passed to the constructor.

        A field is mandatory when it has neither a default nor a default
        factory, when its default is ``EXPECTED``, or when its default is
        ``EXPECTED_TRAIN`` and ``kwargs['training']`` is not ``False``.
        """
        training = kwargs.get('training', True)
        return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]

    @classmethod
    def from_dict(cls, kwargs):
        """Build an instance from a mapping, wrapping dict/list/tuple values with ``Munch.fromDict``.

        The mapping passed in is left untouched; conversion happens on a copy.
        """
        converted = {
            k: Munch.fromDict(v) if isinstance(v, (dict, list, tuple)) else v
            for k, v in kwargs.items()
        }
        return cls(**converted)

    def to_dict(self):
        """Return a shallow ``{field name: value}`` dict; ``Munch`` values are converted with ``toDict()``."""
        # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
        selfdict = {}
        for k in dataclasses.fields(self):
            selfdict[k.name] = getattr(self, k.name)
            if isinstance(selfdict[k.name], Munch):
                selfdict[k.name] = selfdict[k.name].toDict()
        return selfdict
|
core/utils/save_and_load.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from pathlib import Path
|
5 |
+
import safetensors
|
6 |
+
import wandb
|
7 |
+
|
8 |
+
|
9 |
+
def create_folder_if_necessary(path):
    """Create the directory part of ``path`` (everything before the last ``/``).

    Missing parents are created and an existing directory is not an error.
    The final path component is never created, so pass a trailing ``/`` to
    create the full path as a directory.
    """
    folder, _, _ = path.rpartition("/")
    Path(folder).mkdir(parents=True, exist_ok=True)
|
12 |
+
|
13 |
+
|
14 |
+
def safe_save(ckpt, path):
    """Write ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The format is chosen from the extension: ``.pt``/``.ckpt`` use
    ``torch.save``, ``.json`` uses ``json.dump`` and ``.safetensors`` uses
    ``safetensors.torch.save_file``.

    Raises:
        ValueError: if the extension is not supported. The check runs before
            anything on disk is touched, so a rejected call leaves an
            existing file (and its backup) in place.
    """
    if not path.endswith((".pt", ".ckpt", ".json", ".safetensors")):
        raise ValueError(f"File extension not supported: {path}")

    # Rotate the backup. Both steps are best-effort: there may be no previous
    # backup to remove and no previous file to rename.
    try:
        os.remove(f"{path}.bak")
    except OSError:
        pass
    try:
        os.rename(path, f"{path}.bak")
    except OSError:
        pass

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # `import safetensors` alone does not load the `torch` submodule, so
        # import the function explicitly where it is needed.
        from safetensors.torch import save_file
        save_file(ckpt, path)
|
32 |
+
|
33 |
+
|
34 |
+
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint from ``path`` according to its extension.

    Returns ``None`` when the file does not exist. ``.pt``/``.ckpt`` files
    are read with ``torch.load`` onto the CPU, ``.json`` with ``json.load``
    and ``.safetensors`` into a ``{key: tensor}`` dict on the CPU.

    Any exception (including the assertion on unsupported extensions) is
    re-raised; when ``wandb_run_id`` is given a wandb alert is sent first.
    """
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    try:
        assert path.endswith(
            accepted_extensions
        ), f"Automatic loading not supported for this extension: {path}"
        if not os.path.exists(path):
            checkpoint = None
        elif path.endswith((".pt", ".ckpt")):
            checkpoint = torch.load(path, map_location="cpu")
        elif path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as handle:
                checkpoint = json.load(handle)
        elif path.endswith(".safetensors"):
            with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
                checkpoint = {name: handle.get_tensor(name) for name in handle.keys()}
        return checkpoint
    except Exception:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        raise
|
figures/California_000490.jpg
ADDED
Git LFS Details
|
figures/example_dataset/000008.jpg
ADDED
Git LFS Details
|
figures/example_dataset/000008.json
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
{ "caption": "The image captures the iconic Shard, a modern skyscraper that stands as the tallest building in the United Kingdom. The Shard, with its glass and steel structure, pierces the sky, its pointed top reaching towards the heavens. The photograph is taken from a low angle, which emphasizes the height and grandeur of the building. The sky forms a beautiful backdrop, painted in hues of pinkish-orange, suggesting that the photo was taken at sunset. The Shard is nestled between two other buildings, their presence subtly hinted at in the foreground. The image does not contain any discernible text or countable objects, and there are no visible actions taking place. The relative positions of the objects confirm that the Shard is the central focus of the image, with the other buildings and the sky providing context to its location. The image is devoid of any aesthetic descriptions, focusing solely on the factual representation of the scene."
|
2 |
+
}
|
figures/example_dataset/000012.jpg
ADDED
Git LFS Details
|
figures/example_dataset/000012.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"caption": "cars in a road during daytime"}
|
figures/teaser.jpg
ADDED
Git LFS Details
|
gdf/__init__.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .scalers import *
|
3 |
+
from .targets import *
|
4 |
+
from .schedulers import *
|
5 |
+
from .noise_conditions import *
|
6 |
+
from .loss_weights import *
|
7 |
+
from .samplers import *
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import math
|
10 |
+
class GDF():
    """Combine a schedule, input scaler, target, noise conditioning and loss weight.

    The five components are stored as given and called as shown in the
    methods below; ``diffuse`` is the training-side entry point and
    ``sample`` is a generator that runs the reverse process.
    """
    def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
        self.schedule = schedule
        self.input_scaler = input_scaler
        self.target = target
        self.noise_cond = noise_cond
        self.loss_weight = loss_weight
        # Scale of the extra per-channel noise added in diffuse(); 0 disables it.
        self.offset_noise = offset_noise

    def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
        """Delegate to ``input_scaler.setup_limits`` and return its result."""
        stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
        return stretched_limits

    def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
        """Noise ``x0`` and return everything needed for one training step.

        Returns:
            ``(noised, epsilon, target, logSNR, noise_cond, loss_weight)``
            where ``noised = x0 * a + epsilon * b`` and ``a, b`` come from
            the input scaler evaluated at ``logSNR``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
            if self.offset_noise > 0:
                if offset is None:
                    # One random value per (sample, channel), broadcast over the remaining dims.
                    offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device)
                epsilon = epsilon + offset * self.offset_noise
        # Without an explicit t the schedule is called with the batch size.
        logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
        a, b = self.input_scaler(logSNR)  # B
        if len(a.shape) == 1:
            # Add trailing singleton dims so a and b broadcast against x0.
            a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))  # BxCxHxW
        target = self.target(x0, epsilon, logSNR, a, b)

        # noised, noise, target, logSNR, noise_cond, loss_weight
        return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)

    def undiffuse(self, x, logSNR, pred):
        """Convert a model prediction into ``(x0, epsilon)`` estimates via the target object."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1))
        return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)

    def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
        """Generator that runs ``timesteps`` sampling steps and yields after each one.

        Yields:
            ``(x0, x, pred)`` per step. A dict sent back into the generator
            may override ``cfg``, ``cfg_rho``, ``sampler``, ``model_inputs``,
            ``x`` and ``x_init`` for the following steps.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps+1)
        schedule = self.schedule if schedule is None else schedule
        # One row per step boundary, expanded across the batch.
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
            -1, shape[0] if x_init is None else x_init.size(0)
        ).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()

        if cfg is not None:
            if unconditional_inputs is None:
                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
            # Stack conditional and unconditional inputs along the batch dim so
            # one model call serves both halves. Tensors, lists of tensors and
            # dicts of tensors are handled; anything else becomes None.
            model_inputs = {
                k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
                else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
                else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
                else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
            }

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    # Interpolate between the two scales according to the current r.
                    cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())

                pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)

                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    # Blend the guided prediction with a copy rescaled to the
                    # standard deviation of the conditional prediction.
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                # NOTE(review): when cfg is not None, model_inputs was already
                # concatenated with the unconditional inputs above while x is
                # not duplicated here - confirm the batch sizes agree when
                # cfg_t_stop / cfg_t_start exclude a step.
                pred = model(x, noise_cond, **model_inputs)
            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
            altered_vars = yield (x0, x, pred)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get('cfg', cfg)
                cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
                sampler = altered_vars.get('sampler', sampler)
                model_inputs = altered_vars.get('model_inputs', model_inputs)
                x = altered_vars.get('x', x)
                x_init = altered_vars.get('x_init', x_init)
|
100 |
+
|
101 |
+
class GDF_dual_fixlrt(GDF):
    """GDF variant whose ``sample`` runs a full low-resolution pass before the main pass.

    Features captured from the model on the last low-resolution step are
    passed to the model as ``lr_guide`` during the second pass.
    """
    def ref_noise(self, noised, x0, logSNR):
        """Return ``target.noise_givenx0_noised(x0, noised, logSNR, a, b)`` with broadcast-ready scaler outputs."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
        return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b)

    def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None,
               schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None,
               cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
        """Generator: sample at ``shape_lr`` first, then at ``shape`` with guidance.

        Yields:
            ``(x0, x, pred, x_lr)`` after each step of the second loop, where
            ``x_lr`` is the finished low-resolution sample. A dict sent back
            into the generator may override ``cfg``, ``cfg_rho``, ``sampler``,
            ``model_inputs``, ``x`` and ``x_init``.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps+1)
        schedule = self.schedule if schedule is None else schedule
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
            -1, shape[0] if x_init is None else x_init.size(0)
        ).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
        # NOTE(review): with x_init given, x_lr starts from the same tensor as
        # x and shape_lr is ignored - confirm this is intended.
        x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone()
        if cfg is not None:
            if unconditional_inputs is None:
                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
            # Stack conditional and unconditional inputs along the batch dim.
            model_inputs = {
                k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
                else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
                else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
                else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
            }

        ###############################################lr sampling

        # Only the entry for the last step is ever filled in (see below).
        guide_feas = [None] * timesteps

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())

                if i == timesteps -1 :
                    # Last step: keep the encoder/decoder features, taking the
                    # first (conditional) half and repeating it for both halves.
                    # NOTE(review): `reuire_f` is presumably the model's actual
                    # keyword spelling - verify against the model signature.
                    output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)
                    guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec])
                else:
                    output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)

                pred, pred_unconditional = output.chunk(2)

                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                pred = model(x_lr, noise_cond, **model_inputs)
            x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred)
            x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params)

        ###############################################hr HR sampling
        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())

                # NOTE(review): the `i <=19` cut-off is hard-coded and equals
                # the default timesteps=20; confirm the behaviour wanted for
                # other step counts. guide_weight falls linearly from 1.
                out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \
                    lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps)
                pred, pred_unconditional = out_pred.chunk(2)
                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
                else:
                    pred = pred_cfg
            else:
                # NOTE(review): this branch uses the keyword `guide_lr` while
                # the branch above uses `lr_guide`, and guide_lr_enc/dec are
                # only bound if the guided branch ran on the last low-res
                # step - verify before relying on this path.
                pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs)
            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)

            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
            altered_vars = yield (x0, x, pred, x_lr)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get('cfg', cfg)
                cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
                sampler = altered_vars.get('sampler', sampler)
                model_inputs = altered_vars.get('model_inputs', model_inputs)
                x = altered_vars.get('x', x)
                x_init = altered_vars.get('x_init', x_init)
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
gdf/loss_weights.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
# --- Loss Weighting
|
5 |
+
class BaseLossWeight():
    """Base class for loss weights computed from logSNR.

    Subclasses implement ``weight``; calling the instance applies the
    optional shift, evaluates ``weight`` and clamps the result.
    """
    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        """Return ``weight(logSNR)`` clamped to ``clamp_range``.

        When ``shift != 1`` a copy of ``logSNR`` offset by ``2 * log(shift)``
        is used. ``clamp_range`` defaults to ``[-1e9, 1e9]``.
        """
        if clamp_range is None:
            clamp_range = [-1e9, 1e9]
        shifted = logSNR if shift == 1 else logSNR.clone() + 2 * np.log(shift)
        weights = self.weight(shifted, *args, **kwargs)
        return weights.clamp(*clamp_range)
|
14 |
+
|
15 |
+
class ComposedLossWeight(BaseLossWeight):
    """Quotient of two products of loss weights.

    ``mul`` and ``div`` may each be a single ``BaseLossWeight`` or a
    collection of them; a single instance is wrapped in a list.
    """
    def __init__(self, div, mul):
        if isinstance(mul, BaseLossWeight):
            mul = [mul]
        self.mul = mul
        if isinstance(div, BaseLossWeight):
            div = [div]
        self.div = div

    @staticmethod
    def _product(weights, logSNR):
        # Running product of every weight evaluated at logSNR, starting from 1.
        total = 1
        for item in weights:
            total *= item.weight(logSNR)
        return total

    def weight(self, logSNR):
        numerator = self._product(self.mul, logSNR)
        denominator = self._product(self.div, logSNR)
        return numerator/denominator
|
27 |
+
|
28 |
+
class ConstantLossWeight(BaseLossWeight):
    """Weight every element by the same constant ``v``."""
    def __init__(self, v=1):
        self.v = v

    def weight(self, logSNR):
        ones = torch.ones_like(logSNR)
        return ones * self.v
|
34 |
+
|
35 |
+
class SNRLossWeight(BaseLossWeight):
    """Weight equal to ``exp(logSNR)``."""
    def weight(self, logSNR):
        snr = logSNR.exp()
        return snr
|
38 |
+
|
39 |
+
class P2LossWeight(BaseLossWeight):
    """Weight ``(k + exp(logSNR * s)) ** -gamma``."""
    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        scaled_snr = (logSNR * self.s).exp()
        return (self.k + scaled_snr) ** -self.gamma
|
45 |
+
|
46 |
+
class SNRPlusOneLossWeight(BaseLossWeight):
    """Weight equal to ``exp(logSNR) + 1``."""
    def weight(self, logSNR):
        snr = logSNR.exp()
        return snr + 1
|
49 |
+
|
50 |
+
class MinSNRLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR)`` capped from above at ``max_snr``."""
    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        snr = logSNR.exp()
        return snr.clamp(max=self.max_snr)
|
56 |
+
|
57 |
+
class MinSNRPlusOneLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR) + 1`` capped from above at ``max_snr``."""
    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        snr_plus_one = logSNR.exp() + 1
        return snr_plus_one.clamp(max=self.max_snr)
|
63 |
+
|
64 |
+
class TruncatedSNRLossWeight(BaseLossWeight):
    """Weight ``exp(logSNR)`` bounded from below by ``min_snr``."""
    def __init__(self, min_snr=1):
        self.min_snr = min_snr

    def weight(self, logSNR):
        snr = logSNR.exp()
        return snr.clamp(min=self.min_snr)
|
70 |
+
|
71 |
+
class SechLossWeight(BaseLossWeight):
    """Weight ``1 / cosh(logSNR / div)``."""
    def __init__(self, div=2):
        self.div = div

    def weight(self, logSNR):
        scaled = logSNR/self.div
        return 1/scaled.cosh()
|
77 |
+
|
78 |
+
class DebiasedLossWeight(BaseLossWeight):
    """Weight ``1 / sqrt(exp(logSNR))``."""
    def weight(self, logSNR):
        snr_root = logSNR.exp().sqrt()
        return 1/snr_root
|
81 |
+
|
82 |
+
class SigmoidLossWeight(BaseLossWeight):
    """Weight ``sigmoid(logSNR * s)``."""
    def __init__(self, s=1):
        self.s = s

    def weight(self, logSNR):
        scaled = logSNR * self.s
        return scaled.sigmoid()
|
88 |
+
|
89 |
+
class AdaptiveLossWeight(BaseLossWeight):
    """Loss weight that is the inverse of a running per-bucket loss estimate.

    logSNR values are assigned to buckets with ``torch.searchsorted`` over
    ``buckets - 1`` evenly spaced boundaries; each bucket holds a running
    loss that starts at 1 and is refreshed through ``update_buckets``.

    Args:
        logsnr_range: ``[low, high]`` span of the bucket boundaries;
            defaults to ``[-10, 10]``.
        buckets: number of buckets.
        weight_range: ``[min, max]`` clamp applied to the returned weight;
            defaults to ``[1e-7, 1e7]``.
    """
    def __init__(self, logsnr_range=None, buckets=300, weight_range=None):
        # None sentinels instead of list defaults: weight_range is stored on
        # the instance, so a shared default list would leak between instances.
        logsnr_range = [-10, 10] if logsnr_range is None else logsnr_range
        weight_range = [1e-7, 1e7] if weight_range is None else weight_range
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        """Return ``1 / bucket_loss`` for the bucket of each logSNR, clamped to ``weight_range``."""
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Blend ``loss`` into the matching buckets: ``old * beta + loss * (1 - beta)``.

        The bucket table lives on the CPU, so indices and loss are moved there.
        """
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
|
gdf/noise_conditions.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class BaseNoiseCond():
    """Base class mapping logSNR to the conditioning value fed to the model.

    Subclasses implement ``cond`` and may override ``setup``, which receives
    every constructor argument other than ``shift`` and ``clamp_range``.
    """
    def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
        if clamp_range is None:
            clamp_range = [-1e9, 1e9]
        self.shift = shift
        self.clamp_range = clamp_range
        self.setup(*args, **kwargs)

    def setup(self, *args, **kwargs):
        pass  # this method is optional, override it if required

    def cond(self, logSNR):
        raise NotImplementedError("this method needs to be overriden")

    def __call__(self, logSNR):
        """Return ``cond(logSNR)`` clamped to ``clamp_range``.

        When ``shift != 1`` a copy of ``logSNR`` offset by ``2 * log(shift)``
        is used.
        """
        shifted = logSNR if self.shift == 1 else logSNR.clone() + 2 * np.log(self.shift)
        value = self.cond(shifted)
        return value.clamp(*self.clamp_range)
|
21 |
+
|
22 |
+
class CosineTNoiseCond(BaseNoiseCond):
    """Map logSNR to a cosine-schedule style ``t`` value.

    Args (forwarded to ``setup``):
        s: offset of the cosine schedule.
        clamp_range: ``[min, max]`` applied to the variance in ``cond``;
            defaults to ``[0, 1]``. ``setup`` stores it in
            ``self.clamp_range``, the same attribute ``__call__`` uses to
            clamp the final output.
    """
    def setup(self, s=0.008, clamp_range=None):  # [0.0001, 0.9999]
        # None sentinel instead of a list default: the value is stored on the
        # instance, so a shared default list would leak between instances.
        clamp_range = [0, 1] if clamp_range is None else clamp_range
        self.s = torch.tensor([s])
        self.clamp_range = clamp_range
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def cond(self, logSNR):
        var = logSNR.sigmoid()
        var = var.clamp(*self.clamp_range)
        s, min_var = self.s.to(var.device), self.min_var.to(var.device)
        t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return t
|
34 |
+
|
35 |
+
class EDMNoiseCond(BaseNoiseCond):
    """Condition equal to ``-logSNR / 8``."""
    def cond(self, logSNR):
        negated = -logSNR
        return negated/8
|
38 |
+
|
39 |
+
class SigmoidNoiseCond(BaseNoiseCond):
    """Condition equal to ``sigmoid(-logSNR)``."""
    def cond(self, logSNR):
        negated = -logSNR
        return negated.sigmoid()
|
42 |
+
|
43 |
+
class LogSNRNoiseCond(BaseNoiseCond):
    """Condition on logSNR itself (returned unchanged)."""
    def cond(self, logSNR):
        condition = logSNR
        return condition
|
46 |
+
|
47 |
+
class EDMSigmaNoiseCond(BaseNoiseCond):
    """Noise conditioning equal to ``exp(-logSNR / 2)`` scaled by ``sigma_data``."""

    def setup(self, sigma_data=1):
        """Remember the multiplicative factor applied in ``cond``."""
        self.sigma_data = sigma_data

    def cond(self, logSNR):
        """Return ``exp(-logSNR / 2) * sigma_data``."""
        half_neg_log_snr = -logSNR / 2
        return torch.exp(half_neg_log_snr) * self.sigma_data
|
53 |
+
|
54 |
+
class RectifiedFlowsNoiseCond(BaseNoiseCond):
    """Noise conditioning derived in closed form from ``exp(logSNR) - 1``."""

    def cond(self, logSNR):
        """Return ``1 + (2 - sqrt(4 + 4 * d)) / (2 * d)`` with ``d = exp(logSNR) - 1``.

        Entries where ``d`` is exactly zero are replaced by ``1e-3`` before the
        division.
        """
        delta = logSNR.exp() - 1
        zero_mask = delta == 0
        delta[zero_mask] = 1e-3  # Avoid division by zero
        numerator = 2 - (2**2 + 4 * delta) ** 0.5
        return 1 + numerator / (2 * delta)
|
60 |
+
|
61 |
+
# Any NoiseCond that cannot be described easily as a continuous function of t
|
62 |
+
# It needs to define self.x and self.y in the setup() method
|
63 |
+
class PiecewiseLinearNoiseCond(BaseNoiseCond):
    """Noise conditioning obtained by linear interpolation over tabulated points.

    Subclasses are expected to fill ``self.x`` and ``self.y`` with tensors in
    their own ``setup``; this base implementation leaves both as ``None``.
    """

    def setup(self):
        """Initialise the lookup tables to ``None``."""
        self.x = None
        self.y = None

    def piecewise_linear(self, y, xs, ys):
        """Interpolate the ``x`` value for each entry of ``y`` from ``(xs, ys)``.

        ``ys`` is flipped before ``torch.searchsorted`` is applied, so it is
        presumably expected to be in descending order -- TODO confirm with
        the subclasses that build the tables.
        """
        position = torch.searchsorted(ys.flip(dims=(-1,))[:-2], y)
        # Translate the position in the flipped table back to a segment start.
        lower = (len(xs) - 2) - position
        upper = lower + 1
        x_lo, x_hi = xs[lower], xs[upper]
        y_lo, y_hi = ys[lower], ys[upper]
        return x_lo + (x_hi - x_lo) * (y - y_lo) / (y_hi - y_lo)

    def cond(self, logSNR):
        """Interpolate ``sigmoid(logSNR)`` through the stored ``(x, y)`` tables."""
        var = logSNR.sigmoid()
        table_x = self.x.to(var.device)
        table_y = self.y.to(var.device)
        return self.piecewise_linear(var, table_x, table_y)
|
79 |
+
|
80 |
+
class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond):
    """Piecewise-linear conditioning built from a cumulative product of alphas."""

    def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
        """Build the ``x`` / ``y`` lookup tables.

        ``x`` is ``total_steps + 1`` evenly spaced points in ``[0, 1]``. ``y`` is
        the cumulative product of ``1 - b ** 2``, where ``b`` interpolates
        linearly between the square roots of the two ``linear_range`` entries.
        """
        self.total_steps = total_steps
        roots = [value ** 0.5 for value in linear_range]
        self.x = torch.linspace(0, 1, total_steps + 1)

        interpolated = roots[0] * (1 - self.x) + roots[1] * self.x
        alphas = 1 - interpolated ** 2
        self.y = alphas.cumprod(dim=-1)

    def cond(self, logSNR):
        """Return the parent's interpolated value clamped to ``[0, 1]``."""
        interpolated_t = super().cond(logSNR)
        return interpolated_t.clamp(0, 1)
|
91 |
+
|
92 |
+
class DiscreteNoiseCond(BaseNoiseCond):
    """Wrap another noise conditioning and convert its output to integer steps."""

    def setup(self, noise_cond, steps=1000, continuous_range=None):
        """Store the wrapped conditioning and the discretisation parameters.

        Args:
            noise_cond: callable applied to the logSNR in ``cond``.
            steps: factor the normalised value is multiplied by.
            continuous_range: two values ``[low, high]`` describing the range
                of ``noise_cond``'s output; ``None`` selects ``[0, 1]``. A
                fresh list is created per instance so that instances never
                share one mutable default.
        """
        self.noise_cond = noise_cond
        self.steps = steps
        self.continuous_range = [0, 1] if continuous_range is None else continuous_range

    def cond(self, logSNR):
        """Return the wrapped value rescaled to ``[0, 1]``, times ``steps``, as long."""
        low, high = self.continuous_range[0], self.continuous_range[1]
        value = self.noise_cond(logSNR)
        normalised = (value - low) / (high - low)
        # .long() truncates toward zero rather than rounding.
        return normalised.mul(self.steps).long()
|
102 |
+
|
gdf/readme.md
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generic Diffusion Framework (GDF)
|
2 |
+
|
3 |
+
# Basic usage
|
4 |
+
GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM
|
5 |
+
, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different
|
6 |
+
frameworks.
|
7 |
+
|
8 |
+
Using GDF is very straightforward. First, define an instance of the GDF class:
|
9 |
+
|
10 |
+
```python
|
11 |
+
from gdf import GDF
|
12 |
+
from gdf import CosineSchedule
|
13 |
+
from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight
|
14 |
+
|
15 |
+
gdf = GDF(
|
16 |
+
schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
17 |
+
input_scaler=VPScaler(), target=EpsilonTarget(),
|
18 |
+
noise_cond=CosineTNoiseCond(),
|
19 |
+
loss_weight=P2LossWeight(),
|
20 |
+
)
|
21 |
+
```
|
22 |
+
|
23 |
+
You need to define the following components:
|
24 |
+
* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution.
|
25 |
+
* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule.
|
26 |
+
* **Input Scaler**: Selects how the input and the noise are scaled, for example Variance Preserving (VP) or LERP (rectified flows).
|
27 |
+
* **Target**: What the target is during training, usually: epsilon, x0 or v
|
28 |
+
* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8`
|
29 |
+
* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use
|
30 |
+
|
31 |
+
All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just:
|
32 |
+
```python
|
33 |
+
class VPScaler():
|
34 |
+
def __call__(self, logSNR):
|
35 |
+
a_squared = logSNR.sigmoid()
|
36 |
+
a = a_squared.sqrt()
|
37 |
+
b = (1-a_squared).sqrt()
|
38 |
+
return a, b
|
39 |
+
|
40 |
+
```
|
41 |
+
|
42 |
+
So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc...
|
43 |
+
|
44 |
+
### Training
|
45 |
+
|
46 |
+
When you define your training loop you can get all you need by just doing:
|
47 |
+
```python
|
48 |
+
shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper suggested for high resolution
|
49 |
+
for inputs, extra_conditions in dataloader_iterator:
|
50 |
+
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
|
51 |
+
pred = diffusion_model(noised, noise_cond, extra_conditions)
|
52 |
+
|
53 |
+
loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
|
54 |
+
loss_adjusted = (loss * loss_weight).mean()
|
55 |
+
|
56 |
+
loss_adjusted.backward()
|
57 |
+
optimizer.step()
|
58 |
+
optimizer.zero_grad(set_to_none=True)
|
59 |
+
```
|
60 |
+
|
61 |
+
And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the
|
62 |
+
training from the GDF class.
|
63 |
+
|
64 |
+
### Sampling
|
65 |
+
|
66 |
+
The other important part is sampling, when you want to use this framework to sample you can just do the following:
|
67 |
+
|
68 |
+
```python
|
69 |
+
from gdf import DDPMSampler
|
70 |
+
|
71 |
+
shift = 1
|
72 |
+
sampling_configs = {
|
73 |
+
"timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
|
74 |
+
"schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
|
75 |
+
}
|
76 |
+
|
77 |
+
*_, (sampled, _, _) = gdf.sample(
|
78 |
+
diffusion_model, {"cond": extra_conditions}, latents.shape,
|
79 |
+
unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)},
|
80 |
+
device=device, **sampling_configs
|
81 |
+
)
|
82 |
+
```
|
83 |
+
|
84 |
+
# Available modules
|
85 |
+
|
86 |
+
TODO
|