import time import os import torch import numpy as np import torchvision import torch.nn.functional as F from torchvision.datasets import ImageFolder import torchvision.transforms as transforms from tqdm import tqdm import pickle import argparse from PIL import Image concat = lambda x: np.concatenate(x, axis=0) to_np = lambda x: x.data.to("cpu").numpy() class Wrapper(torch.nn.Module): def __init__(self, model): super(Wrapper, self).__init__() self.model = model self.avgpool_output = None self.query = None self.cossim_value = {} def fw_hook(module, input, output): self.avgpool_output = output.squeeze() self.model.avgpool.register_forward_hook(fw_hook) def forward(self, input): _ = self.model(input) return self.avgpool_output def __repr__(self): return "Wrappper" def QueryToEmbedding(query_pil): dataset_transform = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) model = torchvision.models.resnet50(pretrained=True) model.eval() myw = Wrapper(model) # query_pil = Image.open(query_path) query_pt = dataset_transform(query_pil) with torch.no_grad(): embedding = to_np(myw(query_pt.unsqueeze(0))) return np.asarray([embedding])