eliphatfs
commited on
Commit
·
c064d59
1
Parent(s):
35b52f1
Add filtering support.
Browse files- openshape/demo/retrieval.py +20 -11
openshape/demo/retrieval.py
CHANGED
@@ -8,13 +8,19 @@ meta = json.load(
|
|
8 |
open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
|
9 |
)
|
10 |
# {
|
11 |
-
#
|
12 |
-
#
|
13 |
-
#
|
14 |
-
#
|
15 |
-
#
|
16 |
-
#
|
17 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# }
|
19 |
meta = {x['u']: x for x in meta['entries']}
|
20 |
deser = torch.load(
|
@@ -24,17 +30,20 @@ us = deser['us']
|
|
24 |
feats = deser['feats']
|
25 |
|
26 |
|
27 |
-
def retrieve(embedding, top):
|
28 |
sims = []
|
29 |
embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
|
30 |
for chunk in torch.split(feats, 10240):
|
31 |
sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
|
32 |
sims = torch.cat(sims)
|
33 |
-
sims, idx = torch.
|
34 |
results = []
|
35 |
for i, sim in zip(idx, sims):
|
36 |
if us[i] in meta:
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
break
|
40 |
return results
|
|
|
8 |
open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
|
9 |
)
|
10 |
# {
|
11 |
+
# "u": "94db219c315742909fee67deeeacae15",
|
12 |
+
# "name": "knife",
|
13 |
+
# "like": 0,
|
14 |
+
# "view": 35,
|
15 |
+
# "anims": 0,
|
16 |
+
# "tags": ["game-ready"],
|
17 |
+
# "cats": ["weapons-military"],
|
18 |
+
# "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
|
19 |
+
# "desc": "",
|
20 |
+
# "faces": 1724,
|
21 |
+
# "size": 11955,
|
22 |
+
# "lic": "by",
|
23 |
+
# "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
|
24 |
# }
|
25 |
meta = {x['u']: x for x in meta['entries']}
|
26 |
deser = torch.load(
|
|
|
30 |
feats = deser['feats']
|
31 |
|
32 |
|
33 |
+
def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
|
34 |
sims = []
|
35 |
embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
|
36 |
for chunk in torch.split(feats, 10240):
|
37 |
sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
|
38 |
sims = torch.cat(sims)
|
39 |
+
sims, idx = torch.sort(sims, descending=True)
|
40 |
results = []
|
41 |
for i, sim in zip(idx, sims):
|
42 |
if us[i] in meta:
|
43 |
+
if filter_fn is None or filter_fn(meta[us[i]]):
|
44 |
+
results.append(dict(meta[us[i]], sim=sim))
|
45 |
+
if len(results) >= top:
|
46 |
+
break
|
47 |
+
if sim < sim_th:
|
48 |
break
|
49 |
return results
|