File size: 1,482 Bytes
bbd199b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import time
import os
import torch
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
import argparse
from PIL import Image
def concat(x):
    """Concatenate a sequence of numpy arrays along the first axis."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Detach a torch tensor and return it as a CPU numpy array.

    Uses ``detach()`` instead of the deprecated ``.data`` attribute;
    the result is identical for read-only conversion.
    """
    return x.detach().to("cpu").numpy()
class Wrapper(torch.nn.Module):
    """Wrap a classifier so forward() returns its ``avgpool`` activation.

    A forward hook captures the output of the wrapped model's ``avgpool``
    submodule; calling the wrapper runs the full model but returns the
    captured (squeezed) pooled features instead of the logits.

    NOTE(review): assumes the wrapped model exposes an ``avgpool``
    submodule (true for torchvision ResNets) — confirm for other models.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.avgpool_output = None  # populated by the forward hook below
        self.query = None
        self.cossim_value = {}

        def fw_hook(module, input, output):
            # squeeze() drops the trailing 1x1 spatial dims; with a
            # batch of size 1 it also drops the batch dimension.
            self.avgpool_output = output.squeeze()

        # BUG FIX: keep the handle so the hook can be removed if needed
        # (the original discarded it, making the hook irrevocable).
        self._hook_handle = self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        # The model's own output is discarded; we only run it for the
        # side effect of the hook capturing the avgpool activation.
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        # BUG FIX: original returned the misspelled "Wrappper".
        return "Wrapper"
def QueryToEmbedding(query_pil):
    """Compute a ResNet-50 avgpool embedding for a query image.

    Args:
        query_pil: an already-opened PIL image (the caller does the I/O).

    Returns:
        numpy array with an extra leading axis wrapping the pooled
        features — presumably shape (1, 2048) for ResNet-50; the extra
        axis matches the original return shape exactly.
    """
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Standard ImageNet normalization constants.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # PERF: instantiate (and potentially download) the pretrained model
    # only once and reuse it — the original rebuilt ResNet-50 on every
    # call, which dominates the cost of a single query.
    wrapper = getattr(QueryToEmbedding, "_wrapper", None)
    if wrapper is None:
        model = torchvision.models.resnet50(pretrained=True)
        model.eval()
        wrapper = Wrapper(model)
        QueryToEmbedding._wrapper = wrapper

    query_pt = dataset_transform(query_pil)
    with torch.no_grad():  # inference only; no autograd graph needed
        embedding = to_np(wrapper(query_pt.unsqueeze(0)))
    return np.asarray([embedding])
|