# nsd_model / example.py
# Uploaded by huzey via huggingface_hub ("Upload ./example.py with huggingface_hub"),
# commit 978005b (verified), 667 bytes.
from model import BrainEncodingModel
from config_utils import load_from_yaml
import torch
# Build the encoding model from its YAML config and load the pretrained weights.
cfg = load_from_yaml("./config.yaml")
model = BrainEncodingModel(cfg)
sd_path = './ckpt.pth'
# map_location="cpu": deserialize on the CPU regardless of which device the
# checkpoint was saved from; the model is moved to its target device once below.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# If ckpt.pth is a plain state dict, consider weights_only=True (torch >= 1.13).
sd = torch.load(sd_path, map_location="cpu")
model.load_state_dict(sd)
# Prefer the GPU, but fall back to the CPU instead of crashing inside .cuda()
# on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval().to(device)
# Placeholder input batch: one 3-channel 224x224 image. Values are drawn
# uniformly from [0, 1), the pixel range that the ImageNet mean/std
# normalization applied below assumes. Replace with a real image tensor.
x = torch.rand(1, 3, 224, 224)
def transform_image(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize an image batch channel-wise: ``(x - mean) / std``.

    The defaults are the ImageNet per-channel statistics used by torchvision,
    which assume pixel values already scaled to [0, 1].

    Args:
        x: Image tensor that broadcasts against shape ``(1, 3, 1, 1)``,
            e.g. ``(N, 3, H, W)``. May live on any device.
        mean: Three per-channel means subtracted from ``x``.
        std: Three per-channel standard deviations ``x`` is divided by.

    Returns:
        The normalized tensor, on the same device as ``x``.
    """
    # Create the statistics on x's device so CUDA inputs don't hit a
    # device-mismatch error. dtype is deliberately left to PyTorch type
    # promotion (as before), so integer/half inputs still yield float32.
    mean_t = torch.tensor(mean, device=x.device).view(1, 3, 1, 1)
    std_t = torch.tensor(std, device=x.device).view(1, 3, 1, 1)
    return (x - mean_t) / std_t
# Normalize the input, then move it to whichever device holds the model's
# weights, so the example works whether the model was placed on GPU or CPU.
x = transform_image(x)
x = x.to(next(model.parameters()).device)
subject = 'subj01'  # one of the 8 NSD subjects (presumably 'subj01'..'subj08' -- confirm)
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    out = model(x, subject)
print(out.shape)
# torch.Size([1, 327684])
# NOTE(review): 327684 == 2 * 163842, presumably one prediction per fsaverage
# vertex across both hemispheres -- confirm against the model definition.