import numpy as np
from PIL import Image
from torchvision import transforms


def requires_grad(model, flag=True):
    """Enable or disable gradient computation for all parameters of a model."""
    for p in model.parameters():
        p.requires_grad = flag


def get_keys(d, name):
    """Extract the sub-state-dict whose keys start with `name.` and strip that prefix."""
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


def load_img(path_img, img_size=(256, 256)):
    """Load an image from a file path or a NumPy array and return a normalized 4D tensor."""
    transform = transforms.Compose(
        [transforms.Resize(img_size),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),
                              (0.5, 0.5, 0.5))])
    if isinstance(path_img, np.ndarray):
        img = Image.fromarray(path_img)
    else:
        img = Image.open(path_img).convert('RGB')
    img = transform(img)
    img.unsqueeze_(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)
    return img
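
A minimal usage sketch of these helpers follows; the checkpoint filename, the 'encoder' key prefix, and the image filename are placeholder assumptions, not values taken from this repository.

# Usage sketch (hypothetical paths and prefix; adjust to the actual checkpoint layout).
import torch

img = load_img('example_face.jpg')           # tensor of shape (1, 3, 256, 256), values in [-1, 1]

ckpt = torch.load('model.pt', map_location='cpu')
encoder_weights = get_keys(ckpt, 'encoder')  # 'encoder.layer1.weight' -> 'layer1.weight', etc.
# encoder.load_state_dict(encoder_weights)   # load into the matching submodule

# Freeze a submodule before running inference or training another component:
# requires_grad(encoder, False)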