"""Minimal inference example for ``BrainEncodingModel``.

Builds the model from ``./config.yaml``, loads weights from ``./ckpt.pth``,
normalizes a dummy image batch with ImageNet channel statistics, and prints
the shape of the model's prediction for one subject.
"""

import torch

from config_utils import load_from_yaml
from model import BrainEncodingModel

# ImageNet per-channel statistics; they assume RGB pixel values in [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def transform_image(x):
    """Normalize an image batch channel-wise with ImageNet mean/std.

    The statistics are created on ``x``'s device, so the function works for
    both CPU and GPU tensors.

    Args:
        x: Tensor viewed as ``(N, 3, H, W)`` by the broadcasting below; values
            are expected to be in ``[0, 1]``.

    Returns:
        ``(x - mean) / std`` with the same shape and device as ``x``.
    """
    mean = torch.tensor(IMAGENET_MEAN, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=x.device).view(1, 3, 1, 1)
    return (x - mean) / std


cfg = load_from_yaml("./config.yaml")
model = BrainEncodingModel(cfg)

sd_path = './ckpt.pth'
# map_location="cpu" lets the checkpoint load regardless of which device it was
# saved from; the weights are moved to `device` below.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust, or pass
# weights_only=True on PyTorch versions that support it.
sd = torch.load(sd_path, map_location="cpu")
model.load_state_dict(sd)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)

# Dummy RGB batch in [0, 1] (the range the ImageNet statistics assume);
# replace with a real image tensor of shape (1, 3, 224, 224).
x = torch.rand(1, 3, 224, 224)
x = transform_image(x)
x = x.to(device)

subject = 'subj01'  # could be 1 of 8 subjects
with torch.no_grad():
    out = model(x, subject)
print(out.shape)  # torch.Size([1, 327684])