wmpscc
add
7d1312d
raw
history blame
807 Bytes
import numpy as np
from PIL import Image
from torchvision import transforms
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object exposing a ``parameters()`` iterator, e.g. a
            ``torch.nn.Module``.
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for param in params:
        param.requires_grad = flag
def get_keys(d, name):
    """Extract the entries belonging to sub-module ``name`` from a state dict.

    If ``d`` contains a ``'state_dict'`` key, that nested mapping is used
    instead of ``d`` itself. Only keys of the form ``'<name>.<rest>'`` are
    kept, and each is re-keyed as ``'<rest>'``.

    Args:
        d: Mapping of parameter names to values, or a checkpoint mapping
            holding one under ``'state_dict'``.
        name: Sub-module prefix to select (without the trailing dot).

    Returns:
        A new dict with the prefix and its separating dot stripped from each
        selected key. ``d`` is not modified.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Require the separator so that e.g. 'encoder2.w' is not treated as
    # belonging to 'encoder' (which previously produced the key '.w').
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
def load_img(path_img, img_size=(256, 256)):
    """Load an image and turn it into a normalized, batched tensor.

    Args:
        path_img: A ``numpy.ndarray`` of pixel data (passed to
            ``Image.fromarray``), an already-opened ``PIL.Image.Image``, or
            anything ``Image.open`` accepts (e.g. a file path).
        img_size: Target size passed to ``transforms.Resize``.

    Returns:
        A float tensor of shape ``(1, 3, H, W)``, produced by ``ToTensor``
        followed by ``Normalize`` with mean 0.5 / std 0.5 per channel.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if isinstance(path_img, np.ndarray):
        # Force 3 channels so grayscale / RGBA arrays work with the
        # 3-channel Normalize, matching the file-path branch.
        img = Image.fromarray(path_img).convert('RGB')
    elif isinstance(path_img, Image.Image):
        img = path_img.convert('RGB')
    else:
        # convert() loads the pixel data into a new image, so the file handle
        # can be closed by the context manager right after.
        with Image.open(path_img) as src:
            img = src.convert('RGB')
    return transform(img).unsqueeze(0)