from model import BrainEncodingModel
from config_utils import load_from_yaml
import torch

# Build the model from the YAML config and restore its trained weights.
cfg = load_from_yaml("./config.yaml")
model = BrainEncodingModel(cfg)

sd_path = './ckpt.pth'
# Load the tensors onto the CPU first so the checkpoint can be restored no
# matter which device it was saved from; the model is moved to the GPU below.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
# or pass weights_only=True if the installed torch version supports it.
sd = torch.load(sd_path, map_location="cpu")
model.load_state_dict(sd)
# Switch to evaluation mode (affects layers such as dropout/batch-norm, if the
# model has any) and move it to the GPU. Both calls return the module itself.
model.eval().cuda()

# Placeholder input of shape (batch=1, channels=3, 224, 224) standing in for a
# real image.
x = torch.randn(1, 3, 224, 224)
def transform_image(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize an image batch channel-wise, returning ``(x - mean) / std``.

    The default statistics are the per-channel ImageNet mean/std values used
    by torchvision's pretrained models.

    Args:
        x: Image tensor, assumed to be laid out as (N, C, H, W). The statistics
            are broadcast along dimension 1. ``x`` may live on any device.
        mean: Per-channel means, one entry per channel of ``x``.
        std: Per-channel standard deviations, one entry per channel of ``x``.

    Returns:
        A new normalized tensor; ``x`` itself is not modified.
    """
    # Create the statistics on x's device. CPU-only constants would raise a
    # device-mismatch error when x has already been moved to the GPU.
    # The dtype is left at torch's default float, so integer inputs
    # (e.g. uint8) are promoted instead of the statistics being truncated.
    mean_t = torch.as_tensor(mean, device=x.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, device=x.device).view(1, -1, 1, 1)
    return (x - mean_t) / std_t
# Normalize the input on the CPU, then move it to the GPU, where the model
# was placed above.
x = transform_image(x)
x = x.cuda()

subject = 'subj01'  # could be 1 of 8 subjects
# Forward pass without autograd bookkeeping, since no gradients are needed.
with torch.no_grad():
    out = model(x, subject)
print(out.shape)
# Expected output: torch.Size([1, 327684])