File size: 3,658 Bytes
7962ed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import h5py
import numpy as np
from tqdm import tqdm
import torch
from knowledge import TextDB


class ImageCropsIdx:
    """Per-image view of a crop-level knowledge retrieval index stored in HDF5.

    After construction, ``self[image_id][scheme]`` is a dict with the keys
    ``"index"``, ``"score"`` and ``"query"`` holding that image's arrays for one
    crop scheme (``"whole"``, ``"five"`` or ``"nine"``). Only schemes whose
    requested top-k is positive are loaded.
    """

    def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
        """Load the index file and keep the requested number of entries per scheme.

        Args:
            knowledge_idx: path (or anything ``h5py.File`` accepts) of the HDF5 index.
            topk_w: entries to keep along the last axis for the "whole" scheme.
            topk_f: same, for the "five" scheme.
            topk_n: same, for the "nine" scheme.
        """
        topk = {"whole": topk_w, "five": topk_f, "nine": topk_n}
        # A non-positive top-k disables the scheme entirely (it is never read).
        self.topk = {k: v for k, v in topk.items() if v > 0}

        self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)

    def load(self, knowledge_idx, topk):
        """Read the HDF5 index and regroup it by image id.

        Layout read here: file attributes ``fdim`` and ``file_hash``; top-level
        groups ``"0"``, ``"1"``, ... each holding an ``image_ids`` dataset and,
        for every scheme ``k`` in ``topk``, datasets ``k/index`` and ``k/score``
        (3-D, truncated to the first ``topk[k]`` entries along the last axis)
        plus ``k/query`` (read in full).

        Args:
            knowledge_idx: path of the HDF5 index file.
            topk: mapping ``scheme -> number of entries to keep``.

        Returns:
            ``(per_image_index, fdim, file_hash)``, where ``per_image_index`` is
            the dict produced by ``_regroup``.
        """
        with h5py.File(knowledge_idx, "r") as f:
            fdim = f.attrs["fdim"]
            file_hash = f.attrs["file_hash"]

            chunks = []
            # NOTE(review): assumes every top-level member is a group named by its
            # position ("0" .. str(len(f) - 1)); any other top-level member breaks this.
            for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
                chunk = {"image_ids": f[f"{i}/image_ids"][:]}
                for k, v in topk.items():
                    chunk[k] = {
                        # Slice on read so only the requested top-k columns are loaded.
                        "index": f[f"{i}/{k}/index"][:, :, :v],
                        "score": f[f"{i}/{k}/score"][:, :, :v],
                        "query": f[f"{i}/{k}/query"][:],
                    }
                chunks.append(chunk)

        return self._regroup(chunks, topk), fdim, file_hash

    @staticmethod
    def _regroup(chunks, topk):
        """Turn per-chunk arrays into ``{image_id: {scheme: {"index", "score", "query"}}}``.

        Row ``j`` of every array in a chunk belongs to ``chunk["image_ids"][j]``.
        Chunks are processed in order; if an image id appears more than once,
        the last occurrence wins.
        """
        per_image = {}
        for chunk in chunks:
            for row, image_id in enumerate(chunk["image_ids"]):
                per_image[image_id] = {
                    k: {
                        "index": chunk[k]["index"][row],
                        "score": chunk[k]["score"][row],
                        "query": chunk[k]["query"][row],
                    }
                    for k in topk
                }
        return per_image

    def __getitem__(self, image_id):
        """Return ``{scheme: {"index", "score", "query"}}`` for ``image_id``.

        Raises ``KeyError`` if the image is not in the index.
        NOTE(review): keys are the raw values read from the ``image_ids``
        datasets; if those are stored as HDF5 strings they may come back as
        ``bytes`` — confirm callers look them up with the same type.
        """
        return self.knowledge_idx[image_id]
        

class KnowAugImageCrops:
    """Assemble per-crop knowledge tensors for an image.

    For every crop scheme kept in ``knowledge_idx.topk``, the retrieved entry
    ids for the image are looked up in ``knowledge_db``. They are returned
    together with the stored query features, the crop position of each
    retrieved entry and its retrieval score.
    """

    def __init__(self, knowledge_db: TextDB, knowledge_idx: ImageCropsIdx, return_txt=False):
        """
        Args:
            knowledge_db: text knowledge store; indexing it with an array of
                entry ids must return ``(embeddings, texts)``.
            knowledge_idx: per-image retrieval index built against the same store.
            return_txt: if True, also put the retrieved texts under ``"text"``.

        Raises:
            ValueError: if ``knowledge_db.file_hash`` and ``knowledge_idx.file_hash``
                differ, i.e. the index was not built for this store.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if knowledge_db.file_hash != knowledge_idx.file_hash:
            raise ValueError(
                "knowledge_db and knowledge_idx were built from different files: "
                f"file_hash {knowledge_db.file_hash!r} != {knowledge_idx.file_hash!r}"
            )
        self.knowledge_db = knowledge_db
        self.knowledge_idx = knowledge_idx

        # Number of crops per scheme; the first axis of each per-image index array
        # is expected to have this length.
        self.ncrop = {"whole": 1, "five": 5, "nine": 9}
        self.topk = knowledge_idx.topk
        self.fdim = knowledge_idx.fdim

        self.return_txt = return_txt

    def __call__(self, image_id):
        """Return ``{scheme: {"embed", "query", "pos", "score"[, "text"]}}`` for ``image_id``.

        ``embed``, ``pos`` and ``score`` are aligned with the row-major flattening
        of the image's ``index`` array. ``pos[i]`` is the crop (row) number that
        flattened entry ``i`` came from.
        """
        entry = self.knowledge_idx[image_id]
        ret = {}
        for k in self.topk:
            index = entry[k]["index"]
            ke, kt = self.knowledge_db[index.flatten()]
            # Use the actual number of stored entries per crop (last axis) rather than
            # the requested top-k: the index is truncated with ``[..., :topk]`` on load,
            # so it can be shorter. Using the requested value would then misalign
            # ``pos`` with ``embed`` / ``score``.
            kp = np.repeat(np.arange(self.ncrop[k]), index.shape[-1])

            ret[k] = {
                "embed": torch.FloatTensor(ke),
                "query": torch.FloatTensor(entry[k]["query"]),
                "pos": torch.LongTensor(kp),
                "score": torch.FloatTensor(entry[k]["score"].flatten()),
            }
            if self.return_txt:
                ret[k]["text"] = kt

        return ret


class KnowAugImageCropsCombined:
    """Run three knowledge augmenters (object, attribute, action) on one image.

    The result is keyed first by crop scheme, then by knowledge source:
    ``{scheme: {"obj": ..., "attr": ..., "act": ...}}``. The crop schemes are
    taken from the object augmenter's output, and ``fdim`` is copied from it.
    """

    def __init__(
        self,
        knwl_aug_obj: KnowAugImageCrops,
        knwl_aug_attr: KnowAugImageCrops,
        knwl_aug_act: KnowAugImageCrops
    ):
        self.knwl_aug_obj = knwl_aug_obj
        self.knwl_aug_attr = knwl_aug_attr
        self.knwl_aug_act = knwl_aug_act
        self.fdim = knwl_aug_obj.fdim

    def __call__(self, image_id):
        # Evaluated in order obj -> attr -> act, matching the key order below.
        by_source = {
            "obj": self.knwl_aug_obj(image_id),
            "attr": self.knwl_aug_attr(image_id),
            "act": self.knwl_aug_act(image_id),
        }
        # Every scheme produced by the object augmenter must also exist in the
        # other two outputs; a missing one raises KeyError.
        return {
            scheme: {source: result[scheme] for source, result in by_source.items()}
            for scheme in by_source["obj"]
        }