import h5py
import numpy as np
import torch
from tqdm import tqdm

from knowledge import TextDB


class ImageCropsIdx:
    """Per-image retrieval index loaded from an HDF5 file.

    The file must carry ``fdim`` and ``file_hash`` attributes and top-level
    groups named ``"0"``, ``"1"``, ..., ``str(len(f) - 1)``. Each group holds an
    ``image_ids`` dataset plus, for every enabled crop type, ``index``,
    ``score`` and ``query`` datasets. The first axis of each of these datasets
    is aligned with ``image_ids``.

    Args:
        knowledge_idx: Path (or anything ``h5py.File`` accepts) of the index file.
        topk_w: Number of top-k entries kept for the ``"whole"`` crop type.
        topk_f: Number of top-k entries kept for the ``"five"`` crop type.
        topk_n: Number of top-k entries kept for the ``"nine"`` crop type.

    A crop type whose top-k is not positive is not loaded at all.
    """

    def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
        topk = {"whole": topk_w, "five": topk_f, "nine": topk_n}
        # Only crop types with a positive top-k are loaded and served.
        self.topk = {k: v for k, v in topk.items() if v > 0}
        self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)

    def load(self, knowledge_idx, topk):
        """Read the HDF5 index into a dict keyed by image id.

        Args:
            knowledge_idx: Path of the HDF5 index file.
            topk: Mapping ``crop_type -> k``, the number of entries to keep.

        Returns:
            A tuple ``(index, fdim, file_hash)``.
            ``index[image_id][crop_type]`` is a dict with ``"index"``,
            ``"score"`` and ``"query"`` arrays for that image. ``fdim`` and
            ``file_hash`` are the file's attributes. If an image id appears
            more than once, the last occurrence wins.
        """
        index = {}
        with h5py.File(knowledge_idx, "r") as f:
            fdim = f.attrs["fdim"]
            file_hash = f.attrs["file_hash"]
            for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
                image_ids = f[f"{i}/image_ids"][:]
                # Read each crop type's arrays once per group. "index" and
                # "score" are truncated to the first v entries along axis 2
                # (the top-k axis). "query" is read whole.
                arrays = {
                    k: {
                        "index": f[f"{i}/{k}/index"][:, :, :v],
                        "score": f[f"{i}/{k}/score"][:, :, :v],
                        "query": f[f"{i}/{k}/query"][:],
                    }
                    for k, v in topk.items()
                }
                # Row j of every array belongs to image_ids[j]; store per-image rows directly.
                for j, image_id in enumerate(image_ids):
                    index[image_id] = {
                        k: {name: arr[j] for name, arr in crop.items()}
                        for k, crop in arrays.items()
                    }
        return index, fdim, file_hash

    def __getitem__(self, image_id):
        """Return the per-crop-type entries stored for ``image_id`` (KeyError if absent)."""
        return self.knowledge_idx[image_id]


class KnowAugImageCrops:
    """Turn an image's retrieved indices and scores into tensors.

    Args:
        knowledge_db: Database indexed with a flat integer array of retrieved
            indices. The result is unpacked as ``(embeddings, texts)``.
        knowledge_idx: Per-image index, presumably built against the same
            database (checked via ``file_hash``).
        return_txt: When true, also return the retrieved texts under ``"text"``.

    Raises:
        ValueError: If ``knowledge_db.file_hash != knowledge_idx.file_hash``.
    """

    def __init__(self, knowledge_db: TextDB, knowledge_idx: ImageCropsIdx, return_txt=False):
        self.knowledge_db = knowledge_db
        self.knowledge_idx = knowledge_idx
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if knowledge_db.file_hash != knowledge_idx.file_hash:
            raise ValueError(
                f"knowledge_db.file_hash ({knowledge_db.file_hash!r}) does not match "
                f"knowledge_idx.file_hash ({knowledge_idx.file_hash!r})"
            )
        # Nominal crop counts per crop type. This attribute is kept for
        # callers; __call__ derives the actual counts from the stored arrays.
        self.ncrop = {"whole": 1, "five": 5, "nine": 9}
        self.topk = knowledge_idx.topk
        self.fdim = knowledge_idx.fdim
        self.return_txt = return_txt

    def __call__(self, image_id):
        """Build tensors for every enabled crop type of ``image_id``.

        Returns:
            ``{crop_type: {"embed", "query", "pos", "score"[, "text"]}}``, where:

            - ``embed`` holds the database embeddings for the flattened indices.
            - ``pos`` gives the crop number each flattened entry belongs to.
            - ``score`` holds the flattened scores, in the same order.
            - ``query`` holds the stored query array for that crop type.
            - ``text`` holds the retrieved texts; it is present only if
              ``return_txt`` is true.
        """
        entry = self.knowledge_idx[image_id]
        ret = {}
        for k in self.topk:
            crop = entry[k]
            index = crop["index"]
            ke, kt = self.knowledge_db[index.flatten()]
            # Crop id for each flattened entry. A row-major flatten of a
            # (crops, k) array visits all of crop 0 first, then crop 1, ...
            # The ids come from the stored shape rather than self.ncrop /
            # self.topk, so "pos" always lines up with "embed" and "score",
            # even when fewer than k entries were stored.
            kp = np.repeat(np.arange(index.shape[0]), index.shape[1])
            ret[k] = {
                "embed": torch.FloatTensor(ke),
                "query": torch.FloatTensor(crop["query"]),
                "pos": torch.LongTensor(kp),
                "score": torch.FloatTensor(crop["score"].flatten()),
            }
            if self.return_txt:
                ret[k]["text"] = kt
        return ret


class KnowAugImageCropsCombined:
    """Run the object, attribute and action augmenters on one image and group the results by crop type."""

    def __init__(
        self, knwl_aug_obj: KnowAugImageCrops, knwl_aug_attr: KnowAugImageCrops, knwl_aug_act: KnowAugImageCrops
    ):
        self.knwl_aug_obj = knwl_aug_obj
        self.knwl_aug_act = knwl_aug_act
        self.knwl_aug_attr = knwl_aug_attr
        # fdim is taken from the object augmenter only; the other two are not checked.
        self.fdim = knwl_aug_obj.fdim

    def __call__(self, image_id):
        """Return ``{crop_type: {"obj": ..., "attr": ..., "act": ...}}`` for ``image_id``.

        Crop types are those produced by the object augmenter. The attribute and
        action augmenters must produce the same crop types, otherwise a
        KeyError is raised.
        """
        knwl_obj = self.knwl_aug_obj(image_id)
        knwl_attr = self.knwl_aug_attr(image_id)
        knwl_act = self.knwl_aug_act(image_id)
        return {
            k: {"obj": knwl_obj[k], "attr": knwl_attr[k], "act": knwl_act[k]}
            for k in knwl_obj
        }