|
import readline
|
|
import rlcompleter
|
|
readline.parse_and_bind("tab: complete")
|
|
import code
|
|
import pdb
|
|
|
|
import time
|
|
import argparse
|
|
import os
|
|
import imageio
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
|
|
|
|
def interact(local=None):
    """Open an interactive console with tab completion; handy for debugging.

    Typical call from the code being inspected::

        interact(locals())

    Args:
        local: Namespace dict exposed to the console. When None, a merge of
            this module's globals and this function's own locals is used.
    """
    # locals() is evaluated before `namespace` is bound, so the merged dict
    # holds the module globals plus the single entry 'local'.
    namespace = local if local is not None else {**globals(), **locals()}

    completer = rlcompleter.Completer(namespace)
    readline.set_completer(completer.complete)
    code.interact(local=namespace)
|
|
|
|
def set_trace(local=None):
    """Enter pdb with tab completion backed by the given namespace.

    Args:
        local: Namespace dict used for completion. When None, a merge of this
            module's globals and this function's own locals is used.
    """
    local = {**globals(), **locals()} if local is None else local

    # Replaces the completer on the Pdb class itself, so it applies to the
    # debugger instance created by pdb.set_trace() below.
    pdb.Pdb.complete = rlcompleter.Completer(local).complete
    pdb.set_trace()
|
|
|
|
|
|
class Timer():
    """Accumulating stopwatch.

    Brought from https://github.com/thstkdgus35/EDSR-PyTorch

    Usage: call ``tic()`` to mark a start point, ``toc()`` to read the seconds
    elapsed since that mark, ``hold()`` to add the elapsed time to the running
    total, and ``release()`` to read and clear that total.

    Attributes:
        acc: Accumulated seconds collected through ``hold()``.
        t0: Clock reading taken at the most recent ``tic()``.
    """

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Mark the current instant as the start of the measured interval."""
        # perf_counter() is monotonic and high-resolution, so measured
        # intervals are not affected by system clock adjustments the way
        # wall-clock time.time() readings are.
        self.t0 = time.perf_counter()

    def toc(self):
        """Return the seconds elapsed since the last ``tic()``."""
        return time.perf_counter() - self.t0

    def hold(self):
        """Add the time elapsed since the last ``tic()`` to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated seconds and reset the accumulator to 0."""
        ret, self.acc = self.acc, 0
        return ret

    def reset(self):
        """Clear the accumulator; the ``tic()`` mark is left untouched."""
        self.acc = 0
|
|
|
|
|
|
|
|
def str2bool(val):
    """Parse a boolean argument value; enables default-constant-true arguments.

    Args:
        val: Either an actual bool (returned unchanged) or the string
            'true' / 'false' in any letter case.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``val`` is neither a bool nor one of
            the accepted strings. This includes non-string inputs such as
            None or numbers, so argparse can report a clean usage error.
    """
    if isinstance(val, bool):
        return val

    # Only strings have .lower(); anything else falls through to the error.
    if isinstance(val, str):
        lowered = val.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False

    raise argparse.ArgumentTypeError('Boolean value expected')
|
|
|
|
def int2str(val):
    """Return ``val`` as a string, for environment-variable related arguments.

    Args:
        val: An int (converted with ``str``) or a str (returned unchanged).

    Returns:
        str: The string form of ``val``.

    Raises:
        argparse.ArgumentTypeError: If ``val`` is neither an int nor a str.
    """
    if isinstance(val, str):
        return val
    if not isinstance(val, int):
        raise argparse.ArgumentTypeError('number value expected')
    return str(val)
|
|
|
|
|
|
|
|
def _png_writer_worker(task_queue):
    """Worker loop: write queued ``(image, path)`` pairs to disk as PNG files.

    Runs until it receives an item whose path is falsy (the ``(None, None)``
    sentinel). Defined at module level so it can be pickled when the
    multiprocessing start method is spawn or forkserver.
    """
    while True:
        # Blocking get: the worker sleeps until work arrives instead of
        # spinning on queue.empty().
        img, name = task_queue.get()
        if not name:
            return

        try:
            basename, ext = os.path.splitext(name)
            if ext != '.png':
                name = '{}.png'.format(basename)
            imageio.imwrite(name, img)
        except Exception as e:
            # Best effort: one failed image must not stop the worker.
            print(e)


class MultiSaver():
    """Save image tensors to PNG files using background worker processes.

    ``save_image`` converts tensors to uint8 arrays and enqueues them; worker
    processes started by ``begin_background`` write them to disk. Call
    ``end_background`` and then ``join_background`` to shut the workers down.

    Attributes:
        queue: Multiprocessing queue of ``(image, path)`` items, or None when
            no background workers are active.
        process: List of worker processes, or None before the first start.
        result_dir: Default output directory. When set, it takes precedence
            over the ``result_dir`` argument of ``save_image``.
    """

    def __init__(self, result_dir=None):
        self.queue = None
        self.process = None
        self.result_dir = result_dir

    def begin_background(self):
        """Create the queue and start the worker processes."""
        self.queue = mp.Queue()

        # Leave one core for the main process and cap at 8 workers, but always
        # start at least one so queued images get written on single-core hosts.
        n_workers = max(1, min(8, mp.cpu_count() - 1))
        self.process = [
            mp.Process(target=_png_writer_worker, args=(self.queue,), daemon=False)
            for _ in range(n_workers)
        ]
        for p in self.process:
            p.start()

    def end_background(self):
        """Send one stop sentinel per worker; no-op when nothing is running."""
        if self.queue is None:
            return

        for _ in self.process:
            self.queue.put((None, None))

    def join_background(self):
        """Wait for all workers to exit, then drop the queue.

        ``end_background`` must have been called first, otherwise the workers
        never receive a sentinel and this call blocks. No-op when nothing is
        running.
        """
        if self.queue is None:
            return

        # The queue is FIFO and sentinels are enqueued last, so a worker only
        # exits after all earlier items were taken; joining is sufficient and
        # no polling of queue.empty() is needed.
        for p in self.process:
            p.join()

        self.queue = None

    def save_image(self, output, save_names, result_dir=None):
        """Queue images for saving as PNG files.

        Args:
            output: torch tensor holding one image (2-D or 3-D) or a batch
                (4-D, iterated along dim 0). Each image is permuted with
                ``(1, 2, 0)`` before saving — presumably channels-first
                input; confirm against callers. Values are rounded by adding
                0.5, clamped to [0, 255], and cast to uint8. The input
                tensor is not modified.
            save_names: File names relative to the result directory, one per
                image. A single string is treated as a one-element list.
            result_dir: Output directory, used only when the instance has no
                ``result_dir`` of its own.

        Raises:
            ValueError: If no result directory is available.
        """
        result_dir = result_dir if self.result_dir is None else self.result_dir
        if result_dir is None:
            raise ValueError('no result dir specified!')

        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                # Best effort: report the failure and skip saving.
                print(e)
                return

        # zip() over a bare string would pair images with single characters.
        if isinstance(save_names, str):
            save_names = [save_names]

        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            # Out-of-place add() makes a fresh tensor, so the caller's data is
            # left untouched; the in-place clamp_ then only affects that copy.
            output_img = (
                output_img.detach()
                .add(0.5)
                .clamp_(0, 255)
                .permute(1, 2, 0)
                .to('cpu', torch.uint8)
                .numpy()
            )

            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            if save_dir:  # os.makedirs('') would raise
                os.makedirs(save_dir, exist_ok=True)

            self.queue.put((output_img, save_name))

        return
|
|
|
|
class Map(dict):
    """Dictionary whose entries can also be read and written as attributes.

    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary

    Example:
        m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
        m.first_name      # 'Eduardo'
        m.unknown         # None (missing keys read as None)

    Entries set through item or attribute assignment are mirrored into the
    instance ``__dict__``.
    """

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        # Mirror every initial entry, whether it came from a dict, an
        # iterable of pairs, or keyword arguments.
        self.__dict__.update(self)

    def __getattr__(self, attr):
        """Return the entry named ``attr``, or None when there is none.

        Only called when normal attribute lookup fails. Missing dunder names
        raise AttributeError so protocol probes such as
        ``hasattr(obj, '__setstate__')`` (used by copy and pickle) behave
        correctly.
        """
        try:
            return self[attr]
        except KeyError:
            if attr.startswith('__') and attr.endswith('__'):
                raise AttributeError(attr) from None
            return None

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        # The mirror can lack the key (e.g. it was added via dict.update()),
        # so do not fail after the entry itself is already removed.
        self.__dict__.pop(key, None)

    def toDict(self):
        """Return a plain ``dict`` shallow copy of the current entries."""
        return dict(self)