import enum |
import logging |
import os |
import cv2 |
import torch |
import numpy as np |
from PIL import ExifTags |
from PIL import Image |
import collections |
import random |
from internal import vis |
from matplotlib import cm |
class Timing:
    """Context manager that measures CUDA time for the enclosed code.

    Usage::

        with Timing("message") as t:
            your_commands_here()
        t.elapsed_ms  # also available programmatically

    On exit it prints ``<name> elapsed <t> ms``. ``t`` is measured between two
    ``torch.cuda.Event`` markers, and the same value is stored on
    ``self.elapsed_ms``. A CUDA device is required: the events and
    ``torch.cuda.synchronize`` are used unconditionally.
    """

    def __init__(self, name):
        self.name = name
        # Filled in by __exit__; None until the block has finished.
        self.elapsed_ms = None

    def __enter__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        # Returning self lets callers read ``elapsed_ms`` via ``with ... as t``.
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.end.record()
        # Events are recorded asynchronously; wait for the GPU before reading them.
        torch.cuda.synchronize()
        self.elapsed_ms = self.start.elapsed_time(self.end)
        print(self.name, "elapsed", self.elapsed_ms, "ms")
        # Falsy return: any exception raised inside the block propagates.
        return False
def handle_exception(exc_type, exc_value, exc_traceback):
    """Log an exception, with its traceback, at ERROR level.

    The (type, value, traceback) signature matches ``sys.excepthook``. The
    record goes to the root logger with the message "Error!".
    """
    exc_info = (exc_type, exc_value, exc_traceback)
    logging.error("Error!", exc_info=exc_info)
def nan_sum(x):
    """Count the entries of tensor ``x`` that are NaN or +/-inf.

    Returns the 0-dim tensor produced by summing the boolean non-finite mask.
    """
    non_finite = ~torch.isfinite(x)
    return non_finite.sum()
def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested mapping into a single-level dict.

    Nested keys are joined with ``sep``, e.g. ``{'a': {'b': 1}}`` becomes
    ``{'a_b': 1}``. Only ``MutableMapping`` values (such as ``dict``) are
    recursed into; every other value is kept as a leaf. An empty nested
    mapping contributes no entries.

    Args:
      d: the mapping to flatten.
      parent_key: prefix for every key produced at this level ('' for none).
      sep: separator inserted between a parent key and a child key.

    Returns:
      A new flat ``dict``. Keys at the top level (no prefix) keep their
      original type. Joined keys are strings, so nested non-string keys such
      as ints work as well.
    """
    # Explicit import: ``import collections`` alone does not guarantee that the
    # ``collections.abc`` attribute exists.
    from collections.abc import MutableMapping

    flat = {}
    for key, value in d.items():
        # Format instead of '+' so that non-string keys don't raise TypeError.
        new_key = f'{parent_key}{sep}{key}' if parent_key else key
        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, new_key, sep=sep))
        else:
            flat[new_key] = value
    return flat
class DataSplit(enum.Enum):
    """Dataset split, identified by its lowercase string value."""

    TRAIN = "train"
    TEST = "test"
class BatchingMethod(enum.Enum):
    """How each batch draws its rays: randomly from all images, or from a single image."""

    ALL_IMAGES = "all_images"
    SINGLE_IMAGE = "single_image"
def open_file(pth, mode='r', **kwargs):
    """Open ``pth`` with the built-in ``open`` and return the file object.

    Extra keyword arguments (e.g. ``encoding='utf-8'`` for text mode, or
    ``newline=''``) are forwarded to ``open``. Without them the call is
    exactly ``open(pth, mode=mode)``, as before.
    """
    return open(pth, mode=mode, **kwargs)
def file_exists(pth):
    """Return True if ``pth`` refers to an existing path (file or directory)."""
    exists = os.path.exists(pth)
    return exists
def listdir(pth):
    """Return the entry names inside directory ``pth``, as given by ``os.listdir``."""
    entries = os.listdir(pth)
    return entries
def isdir(pth):
    """Return True if ``pth`` is an existing directory."""
    is_directory = os.path.isdir(pth)
    return is_directory
def makedirs(pth):
    """Create directory ``pth`` and any missing parents; do nothing if it already exists."""
    os.makedirs(name=pth, exist_ok=True)
def load_img(pth):
    """Load an image file into a float32 numpy array.

    Pixel values are cast without rescaling, so an 8-bit image gives values
    in [0, 255].
    """
    # Image.open is lazy and keeps the file handle open. The context manager
    # closes it once np.array has pulled the pixel data.
    with Image.open(pth) as image:
        return np.array(image, dtype=np.float32)
def load_exif(pth):
    """Load EXIF metadata for an image as a dict keyed by tag name.

    Numeric EXIF tag ids are translated to names with ``ExifTags.TAGS``.
    Tags without a known name are dropped. Returns an empty dict when the
    image has no EXIF data, or when its Pillow image class has no
    ``_getexif`` reader (e.g. PNG); previously the latter raised
    ``AttributeError``.
    """
    with open_file(pth, 'rb') as f:
        image_pil = Image.open(f)
        # ``_getexif`` is a private Pillow method that only some formats
        # (e.g. JPEG) define, so look it up defensively.
        # NOTE(review): it is kept instead of the public ``getexif()`` because
        # its flattened dict also includes sub-IFD tags; confirm before switching.
        get_exif = getattr(image_pil, '_getexif', None)
        exif_pil = get_exif() if get_exif is not None else None
    if exif_pil is None:
        return {}
    return {
        ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS
    }
def save_img_u8(img, pth):
    """Write an image with values in [0, 1] (probably RGB) to ``pth`` as an 8-bit PNG.

    Non-finite values are first replaced by ``np.nan_to_num``. The result is
    clipped to [0, 1], scaled by 255, and truncated (not rounded) to uint8.
    """
    finite = np.nan_to_num(img)
    scaled = np.clip(finite, 0., 1.) * 255.
    Image.fromarray(scaled.astype(np.uint8)).save(pth, 'PNG')
def save_img_f32(depthmap, pth, p=0.5):
    """Write a float image (probably a depthmap) to ``pth`` as a float32 TIFF.

    Non-finite values are replaced by ``np.nan_to_num`` before the cast to
    float32. ``p`` is accepted for interface compatibility but is not used.
    """
    pixels = np.nan_to_num(depthmap).astype(np.float32)
    Image.fromarray(pixels).save(pth, 'TIFF')