|
import enum |
|
import logging |
|
import os |
|
|
|
import cv2 |
|
import torch |
|
import numpy as np |
|
from PIL import ExifTags |
|
from PIL import Image |
|
import collections |
|
import random |
|
from internal import vis |
|
from matplotlib import cm |
|
|
|
|
|
class Timing:
    """Context manager that times the enclosed block and prints the result.

    Usage:
        with Timing("message") as timer:
            your commands here

    On exit it prints ``<name> elapsed <milliseconds> ms``. When CUDA is
    available the measurement uses CUDA events (GPU runtime). Otherwise it
    falls back to a wall-clock measurement, so the helper also works on
    CPU-only machines instead of raising.

    Attributes:
        name: Label printed in front of the measurement.
        start: CUDA event recorded on entry (CUDA path only).
        end: CUDA event recorded on exit (CUDA path only).
        elapsed_ms: Duration of the most recent timed block in milliseconds,
            or None before the first block has finished.
    """

    def __init__(self, name):
        self.name = name
        self.elapsed_ms = None

    def __enter__(self):
        """Start the measurement and return this instance for ``as`` binding."""
        self._use_cuda = torch.cuda.is_available()
        if self._use_cuda:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)
            self.start.record()
        else:
            import time
            self._cpu_start = time.perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        """Stop the measurement, store it in ``elapsed_ms`` and print it.

        Returns None, so an exception raised inside the block propagates.
        """
        if self._use_cuda:
            self.end.record()
            # Wait for queued GPU work so the end event has actually completed
            # before the elapsed time is queried.
            torch.cuda.synchronize()
            self.elapsed_ms = self.start.elapsed_time(self.end)
        else:
            import time
            self.elapsed_ms = (time.perf_counter() - self._cpu_start) * 1000.0
        print(self.name, "elapsed", self.elapsed_ms, "ms")
|
|
|
|
|
def handle_exception(exc_type, exc_value, exc_traceback):
    """Log an exception triple at ERROR level via the ``logging`` module.

    Args:
        exc_type: Exception class.
        exc_value: Exception instance.
        exc_traceback: Traceback object associated with the exception.

    NOTE(review): the (type, value, traceback) signature matches what
    ``sys.excepthook`` receives; presumably this is installed as the hook
    elsewhere. Confirm at the call site.
    """
    exc_triple = (exc_type, exc_value, exc_traceback)
    logging.error("Error!", exc_info=exc_triple)
|
|
|
|
|
def nan_sum(x):
    """Count the entries of tensor ``x`` that are NaN or infinite.

    Returns:
        The count as the result of ``Tensor.sum()`` over the boolean mask.
    """
    is_nan = torch.isnan(x)
    is_inf = torch.isinf(x)
    non_finite = torch.logical_or(is_nan, is_inf)
    return non_finite.sum()
|
|
|
|
|
def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested mapping into a single-level dict.

    Keys of nested mutable mappings are joined to their parent key with
    ``sep``, e.g. ``{'a': {'b': 1}}`` becomes ``{'a_b': 1}``.

    Args:
        d: Mapping to flatten.
        parent_key: Prefix for every produced key. A falsy prefix (the
            default ``''``) means "no prefix", so keys are used as-is.
        sep: Separator placed between a parent key and a child key.

    Returns:
        A new dict whose values are the non-mapping leaves of ``d``. Nested
        mappings that are empty contribute no entries.
    """
    # Import the submodule explicitly: a bare ``import collections`` does not
    # guarantee that ``collections.abc`` has been loaded as an attribute.
    from collections.abc import MutableMapping

    flat = {}
    for k, v in d.items():
        # An f-string (rather than ``+``) lets non-string nested keys be
        # joined too; for string keys the result is identical.
        new_key = f'{parent_key}{sep}{k}' if parent_key else k
        if isinstance(v, MutableMapping):
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
|
|
|
|
|
@enum.unique
class DataSplit(enum.Enum):
    """Which portion of a dataset is being used."""

    # Member values are the lowercase split names.
    TRAIN = "train"
    TEST = "test"
|
|
|
|
|
@enum.unique
class BatchingMethod(enum.Enum):
    """How the rays of each batch are drawn.

    Rays are sampled randomly either across all images or from a single
    image per batch.
    """

    ALL_IMAGES = "all_images"
    SINGLE_IMAGE = "single_image"
|
|
|
|
|
def open_file(pth, mode='r'):
    """Open ``pth`` with the built-in ``open`` and return the file object.

    Args:
        pth: Path of the file to open.
        mode: Mode string forwarded to ``open`` (default ``'r'``).

    Returns:
        The opened file object; the caller is responsible for closing it.
    """
    file_obj = open(pth, mode=mode)
    return file_obj
|
|
|
|
|
def file_exists(pth):
    """Return whether ``os.path.exists`` reports that ``pth`` exists."""
    exists = os.path.exists(pth)
    return exists
|
|
|
|
|
def listdir(pth):
    """Return the entry names of directory ``pth`` as given by ``os.listdir``."""
    entries = os.listdir(pth)
    return entries
|
|
|
|
|
def isdir(pth):
    """Return whether ``os.path.isdir`` reports ``pth`` as a directory."""
    is_directory = os.path.isdir(pth)
    return is_directory
|
|
|
|
|
def makedirs(pth):
    """Create directory ``pth`` together with any missing parent directories.

    ``exist_ok=True`` is passed to ``os.makedirs``, so an already existing
    directory is not an error.
    """
    os.makedirs(pth, exist_ok=True)
|
|
|
|
|
def load_img(pth):
    """Load an image and cast to float32.

    Args:
        pth: Path (or other source accepted by ``PIL.Image.open``) of the image.

    Returns:
        A NumPy array with dtype ``float32`` holding the image data. Only the
        dtype is converted; values are not rescaled.
    """
    # Use the image as a context manager so the underlying file is released
    # deterministically once the pixel data has been copied into the array.
    with Image.open(pth) as pil_image:
        return np.array(pil_image, dtype=np.float32)
|
|
|
|
|
def load_exif(pth):
    """Load EXIF data for an image.

    Args:
        pth: Path of the image file; it is opened in binary mode.

    Returns:
        A dict mapping EXIF tag names (looked up in ``ExifTags.TAGS``) to
        their values. Tags whose id is not in ``ExifTags.TAGS`` are dropped.
        The dict is empty when the image reports no EXIF data.
    """
    with open_file(pth, 'rb') as f:
        # NOTE(review): ``_getexif`` is a private Pillow method; presumably
        # not every image plugin provides it. Confirm for the formats used.
        raw_exif = Image.open(f)._getexif()
        if raw_exif is None:
            return {}
        named_exif = {}
        for tag_id, tag_value in raw_exif.items():
            if tag_id in ExifTags.TAGS:
                named_exif[ExifTags.TAGS[tag_id]] = tag_value
        return named_exif
|
|
|
|
|
def save_img_u8(img, pth):
    """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.

    Args:
        img: Array-like image expected to lie in [0, 1].
        pth: Destination passed to ``PIL.Image.save``.

    Non-finite values are replaced by ``np.nan_to_num`` and the result is
    clipped to [0, 1] before scaling by 255. The cast to ``uint8`` uses
    ``astype``, not explicit rounding.
    """
    finite = np.nan_to_num(img)
    unit_range = np.clip(finite, 0., 1.)
    as_uint8 = (unit_range * 255.).astype(np.uint8)
    png_image = Image.fromarray(as_uint8)
    png_image.save(pth, 'PNG')
|
|
|
|
|
def save_img_f32(depthmap, pth, p=0.5):
    """Save an image (probably a depthmap) to disk as a float32 TIFF.

    Args:
        depthmap: Array-like image; non-finite values are replaced by
            ``np.nan_to_num`` before the cast to float32.
        pth: Destination passed to ``PIL.Image.save``.
        p: Accepted for interface compatibility; not used by this function.
    """
    cleaned = np.nan_to_num(depthmap)
    tiff_image = Image.fromarray(cleaned.astype(np.float32))
    tiff_image.save(pth, 'TIFF')
|
|