|
import argparse |
|
import warnings |
|
import numpy as np |
|
import torch |
|
|
|
|
|
# Suppress UserWarning messages attributed to torch.nn.functional so they
# do not flood the output.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="torch.nn.functional",
)
|
|
|
|
|
def str2bool(v: str, strict=True) -> bool:
    """Convert a truthy/falsy spelling such as ``"yes"`` or ``"0"`` to a bool.

    Raises ``argparse.ArgumentTypeError`` on bad input, which is the error
    argparse expects from a ``type=`` callable.

    Args:
        v: Value to convert. A ``bool`` is returned unchanged; a ``str`` is
            compared case-insensitively against the recognised spellings
            (true/yes/on/t/y/1 and false/no/off/f/n/0).
        strict: If true, an unrecognised value raises. If false, an
            unrecognised value falls back to ``True``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``strict`` is true and ``v`` is
            neither a bool nor one of the recognised strings.
    """
    if isinstance(v, bool):
        return v

    if isinstance(v, str):
        lowered = v.lower()
        # Every spelling needs its own comma: adjacent string literals are
        # silently concatenated ("on" "t" -> "ont"), which previously made
        # both "on" and "t" unrecognised.
        if lowered in ("true", "yes", "on", "t", "y", "1"):
            return True
        if lowered in ("false", "no", "off", "f", "n", "0"):
            return False

    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    # Non-strict mode treats anything unrecognised as True.
    return True
|
|
|
|
|
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move every tensor found in ``data`` onto ``device``.

    Args:
        data: A tensor, a tuple/list/set, a dict, or any other object.
        device: Target passed to ``Tensor.to``.
        exclude_keys: Keys of a top-level dict whose values are left
            untouched. Not forwarded to nested containers.

    Returns:
        The moved tensor; a new ``list`` for tuple/list/set input; the same
        dict object (updated in place) for dict input; any other object
        unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, (tuple, list, set)):
        # Sequences and sets always come back as a plain list.
        moved = []
        for item in data:
            moved.append(to_cuda(item, device))
        return moved

    if not isinstance(data, dict):
        # Anything else (numbers, strings, None, ...) passes through as-is.
        return data

    skipped = [] if exclude_keys is None else exclude_keys
    for key in data.keys():
        if key in skipped:
            continue
        # Values are replaced in place; the key set never changes.
        data[key] = to_cuda(data[key], device)
    return data
|
|
|
|
|
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an array at the end of its first two axes until they match.

    Args:
        img: Array with at least two dimensions. The first two axes are
            read as ``H`` and ``W``; any trailing axes are left unpadded.

    Returns:
        ``img`` itself when ``H == W``; otherwise a new array whose first
        two axes both have length ``max(H, W)``, with the added region
        filled by ``np.pad``'s default constant (zero).
    """
    H, W = img.shape[:2]
    if H == W:
        return img

    new_size = max(H, W)
    # One (0, 0) entry per trailing axis, so 2-D (grayscale) and >3-D inputs
    # work as well as the H x W x C case the pad spec was hard-coded for.
    pad_width = [(0, new_size - H), (0, new_size - W)] + [(0, 0)] * (img.ndim - 2)
    img = np.pad(img, pad_width, mode="constant")
    assert img.shape[0] == img.shape[1] == new_size
    return img
|
|