import dataclasses
import importlib
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import PIL.Image  # `import PIL` alone does not expose the `Image` submodule
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int, Num
from omegaconf import DictConfig, OmegaConf
from torch import Tensor


class BaseModule(nn.Module):
    """nn.Module with a structured, dataclass-backed configuration.

    Subclasses declare their options in a nested ``Config`` dataclass and do
    their setup in ``configure`` instead of ``__init__``.
    """

    @dataclass
    class Config:
        pass

    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        # Merge the user-provided config over the dataclass defaults.
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        raise NotImplementedError

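# Usage sketch: a subclass declares its options in `Config` and builds its
# layers in `configure`. `MyEncoder` and `hidden_dim` are hypothetical names
# used only for illustration.
#
#   class MyEncoder(BaseModule):
#       @dataclass
#       class Config(BaseModule.Config):
#           hidden_dim: int = 64
#
#       def configure(self) -> None:
#           self.proj = nn.Linear(3, self.cfg.hidden_dim)
#
#   encoder = MyEncoder({"hidden_dim": 128})  # encoder.cfg.hidden_dim == 128
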
def find_class(cls_string):
    """Resolve a fully qualified class name like "pkg.module.ClassName"."""
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls

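# Example: resolving a class from its dotted path.
#
#   linear_cls = find_class("torch.nn.Linear")
#   layer = linear_cls(4, 8)
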
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    # `cfg` is optional; fall back to an empty dict so `.copy()` below does
    # not fail on None.
    if cfg is None:
        cfg = {}
    cfg_ = cfg.copy()
    keys = list(cfg_.keys())

    # Drop keys the dataclass does not declare; OmegaConf.merge would
    # otherwise raise on them.
    field_names = {f.name for f in dataclasses.fields(fields)}
    for key in keys:
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
    return scfg

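# Usage sketch: unknown keys are dropped, declared fields keep their defaults
# unless overridden. `DemoConfig`, `lr`, and `steps` are hypothetical names.
#
#   @dataclass
#   class DemoConfig:
#       lr: float = 1e-3
#       steps: int = 100
#
#   cfg = parse_structured(DemoConfig, {"lr": 3e-4, "typo_key": 1})
#   # prints "Ignoring typo_key ...", then cfg.lr == 3e-4, cfg.steps == 100
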
# Smallest epsilon that is safely representable at each dtype's precision.
EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}


def dot(x, y, dim=-1):
    return torch.sum(x * y, dim, keepdim=True)


def reflect(x, n):
    # Reflect x about the (unit) normal n: r = x - 2 (x . n) n.
    return x - 2 * dot(x, n) * n


def normalize(x, dim=-1, eps=None):
    if eps is None:
        # Pick an epsilon matched to the tensor's precision.
        eps = EPS_DTYPE[x.dtype]
    return F.normalize(x, dim=dim, p=2, eps=eps)

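# Example: reflecting a direction about the z-axis normal flips its
# z-component.
#
#   d = normalize(torch.tensor([1.0, 0.0, -1.0]))
#   n = torch.tensor([0.0, 0.0, 1.0])
#   reflect(d, n)  # ~[0.7071, 0.0, 0.7071]
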
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]


def scale_tensor(
    dat: Num[Tensor, "... D"],
    inp_scale: Optional[ValidScale],
    tgt_scale: Optional[ValidScale],
):
    # Both scales default to the unit interval when not given.
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # Map dat from [inp_scale[0], inp_scale[1]] to [tgt_scale[0], tgt_scale[1]].
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat

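# Example: remapping coordinates from [0, 1] to [-1, 1].
#
#   pts = torch.rand(8, 3)
#   pts_ndc = scale_tensor(pts, (0, 1), (-1, 1))
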
def dilate_fill(img, mask, iterations=10):
    """Grow the valid region of ``img`` outward by ``iterations`` pixels,
    filling newly covered pixels with the local mean color.

    Expects ``img`` of shape (1, 3, H, W) with invalid pixels zeroed, and
    ``mask`` of shape (1, 1, H, W) with 1 marking valid pixels.
    """
    oldMask = mask.float()
    oldImg = img

    mask_kernel = torch.ones(
        (1, 1, 3, 3),
        dtype=oldMask.dtype,
        device=oldMask.device,
    )

    for _ in range(iterations):
        # Dilate the mask by one pixel.
        newMask = F.max_pool2d(oldMask, 3, 1, 1)

        # Gather 3x3 neighborhoods of the image and both masks.
        img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
        new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Mean color over the valid pixels in each neighborhood.
        mean_color = (
            img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)
        ).unsqueeze(2)

        # Broadcast the mean color to every position valid in the new mask.
        fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)

        # Normalization: how many overlapping patches contribute per pixel.
        mask_conv = F.conv2d(newMask, mask_kernel, padding=1)
        newImg = F.fold(
            fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
        ) / mask_conv.clamp(1)

        # Only the freshly dilated pixels take the filled color.
        diffMask = newMask - oldMask

        oldMask = newMask
        oldImg = torch.lerp(oldImg, newImg, diffMask)

    return oldImg

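# Usage sketch: pad an RGB image outward by ~10 px so downstream bilinear
# sampling near the mask boundary does not bleed in background. Shapes follow
# the (1, 3, H, W) / (1, 1, H, W) layout the unfold views above hard-code.
#
#   img = torch.zeros(1, 3, 64, 64)
#   mask = torch.zeros(1, 1, 64, 64)
#   img[..., 24:40, 24:40] = 0.5
#   mask[..., 24:40, 24:40] = 1.0
#   filled = dilate_fill(img, mask, iterations=10)
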
def float32_to_uint8_np(
    x: Float[np.ndarray, "*B H W C"],
    dither: bool = True,
    dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
    dither_strength: float = 1.0,
) -> Int[np.ndarray, "*B H W C"]:
    if dither:
        # Uniform noise (in [-0.5, 0.5) at the default strength) breaks up
        # banding when quantizing.
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32)
            - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        return np.clip(np.floor(256.0 * x + noise), 0, 255).astype(np.uint8)
    # Note: np.uint8, not torch.uint8 -- this is a NumPy code path.
    return np.clip(np.floor(256.0 * x), 0, 255).astype(np.uint8)

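# Example: quantizing a float image before writing it out as PNG.
#
#   rgb = np.random.rand(256, 256, 3).astype(np.float32)
#   rgb_u8 = float32_to_uint8_np(rgb)  # dithered uint8 in [0, 255]
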
def convert_data(data):
    """Recursively convert tensors (and containers of tensors) to NumPy."""
    if data is None:
        return None
    elif isinstance(data, np.ndarray):
        return data
    elif isinstance(data, torch.Tensor):
        # Half-precision tensors are upcast first; NumPy has no bfloat16.
        if data.dtype in [torch.float16, torch.bfloat16]:
            data = data.float()
        return data.detach().cpu().numpy()
    elif isinstance(data, list):
        return [convert_data(d) for d in data]
    elif isinstance(data, dict):
        return {k: convert_data(v) for k, v in data.items()}
    else:
        raise TypeError(
            f"Data must be of type numpy.ndarray, torch.Tensor, list or dict, "
            f"got {type(data)}"
        )

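# Example: nested containers are handled recursively.
#
#   out = convert_data({"rgb": torch.rand(2, 3), "aux": [torch.ones(1)]})
#   # out["rgb"] and out["aux"][0] are numpy arrays
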
class ImageProcessor:
    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        # Convert to a float tensor in [0, 1], channels-last.
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim >= 4
        view_batch = image.ndim >= 5

        if view_batch:
            # Flatten the leading batch/view dims for interpolation and
            # remember them so they can be restored afterwards.
            lead_shape = image.shape[:-3]
            image = image.view(-1, *image.shape[-3:])
        elif not batched:
            image = image[None, ...]

        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if view_batch:
            image = image.view(*lead_shape, *image.shape[-3:])
        elif not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        size: int,
    ) -> Any:
        # torch.Tensor rather than torch.FloatTensor, so CUDA tensors also
        # take the batched fast path.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image

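# Usage sketch: a single PIL image or a list of images comes back as a float
# tensor in [0, 1] with channels last.
#
#   processor = ImageProcessor()
#   batch = processor([PIL.Image.new("RGB", (512, 512))], size=256)
#   # batch.shape == (1, 256, 256, 3)
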
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a 3x3 pinhole intrinsic matrix from a vertical FOV in radians.

    If ``bs > 0``, the matrix is tiled to a batch of that size.
    """
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)
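# Example: a 60-degree vertical FOV at 512x512 gives f ~= 443.4 px.
#
#   K = get_intrinsic_from_fov(np.deg2rad(60.0), H=512, W=512)
#   # K[0, 0] == K[1, 1] ~= 443.4, principal point at (256, 256)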