# distildire/utils/utils.py
# Yewon Lim
# first
# 424919d
import argparse
import warnings
import numpy as np
import torch
# Silence UserWarning messages attributed to torch.nn.functional for the whole
# process; this runs at import time, so it affects every importer of this module.
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
def str2bool(v: str, strict=True) -> bool:
    """Convert a command-line style truthy/falsy value to a ``bool``.

    Args:
        v: Value to convert. A ``bool`` is returned unchanged; a string is
            compared case-insensitively against the accepted spellings.
        strict: When true, an unrecognised value raises; when false, an
            unrecognised value is treated as ``True``.

    Returns:
        ``True`` for "true", "yes", "on", "t", "y", "1";
        ``False`` for "false", "no", "off", "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: If ``strict`` is true and ``v`` is neither
            a bool nor one of the accepted strings.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        lowered = v.lower()
        # NOTE: every literal needs its own comma; adjacent string literals
        # are silently concatenated ("on" "t" would become "ont").
        if lowered in ("true", "yes", "on", "t", "y", "1"):
            return True
        if lowered in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    return True
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move every tensor found in ``data`` onto ``device``.

    Args:
        data: A tensor, a tuple/list/set, a dict, or any other object.
        device: Target passed to ``Tensor.to``.
        exclude_keys: Keys of a top-level dict whose values are left
            untouched. It is not forwarded to recursive calls, so it has no
            effect on nested dicts.

    Returns:
        * tensor -> the result of ``data.to(device)``
        * tuple / list / set -> a new ``list`` of converted items
        * dict -> the same dict object, updated in place
        * anything else -> returned unchanged
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, (tuple, list, set)):
        # The container type is not preserved: the result is always a list.
        return [to_cuda(item, device) for item in data]

    if isinstance(data, dict):
        skipped = [] if exclude_keys is None else exclude_keys
        # Only existing keys are reassigned, so the dict's size never changes
        # while it is being iterated.
        for key, value in data.items():
            if key not in skipped:
                data[key] = to_cuda(value, device)
        return data

    # Unsupported types pass through untouched instead of raising.
    return data
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an image along its bottom/right edges until height == width.

    Only the first two axes are padded; any trailing axes (e.g. channels) are
    left as they are, so both (H, W) and (H, W, C) arrays are accepted.

    Args:
        img: Array whose first two axes are treated as height and width.

    Returns:
        ``img`` itself when it is already square, otherwise a new array of
        shape ``(S, S, ...)`` with ``S = max(H, W)``; the original content
        sits in the top-left corner and the added region is filled with zeros.
    """
    H, W = img.shape[:2]
    if H == W:
        return img

    new_size = max(H, W)
    # Build the pad spec from the actual rank instead of assuming 3 axes, so
    # grayscale (H, W) inputs do not make np.pad raise.
    pad_width = [(0, new_size - H), (0, new_size - W)]
    pad_width += [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="constant")