import numpy as np
from PIL import Image
from torchvision import transforms


def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields objects with a
            writable ``requires_grad`` attribute.
        flag: Value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = flag


def get_keys(d, name):
    """Return the entries of ``d`` that live under the ``name`` sub-module.

    If ``d`` contains a ``'state_dict'`` entry, that nested mapping is
    searched instead of ``d`` itself.

    Args:
        d: Mapping of string keys to values, optionally wrapping the real
            mapping under the ``'state_dict'`` key.
        name: Prefix to select, without the trailing dot.

    Returns:
        A new dict holding every entry whose key starts with ``name + '.'``,
        with that prefix removed from the key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the full "<name>." prefix rather than on ``name`` alone, so
    # that e.g. name='encoder' does not also capture 'encoder_ema.weight'
    # (which would then be stored under the mangled key 'ema.weight').
    prefix = name + '.'
    return {
        key[len(prefix):]: value
        for key, value in d.items()
        if key.startswith(prefix)
    }


def load_img(path_img, img_size=(256, 256)):
    """Load an image and turn it into a normalized single-image batch.

    Args:
        path_img: Either a NumPy array holding the image data, or anything
            accepted by ``PIL.Image.open`` (for example a file path).
        img_size: Target size handed to ``transforms.Resize``.

    Returns:
        The transformed image with a leading batch dimension of size 1.
        Each of the three RGB channels is normalized with mean 0.5 and
        std 0.5 after ``ToTensor``.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if isinstance(path_img, np.ndarray):
        img = Image.fromarray(path_img)
    else:
        img = Image.open(path_img)
    # Normalize above is configured for exactly three channels, so force RGB
    # for both input kinds (array inputs may be grayscale or carry alpha).
    img = img.convert('RGB')

    return transform(img).unsqueeze(0)