eliphatfs commited on
Commit
c064d59
·
1 Parent(s): 35b52f1

Add filtering support.

Browse files
Files changed (1) hide show
  1. openshape/demo/retrieval.py +20 -11
openshape/demo/retrieval.py CHANGED
@@ -8,13 +8,19 @@ meta = json.load(
8
  open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
9
  )
10
  # {
11
- # "u": "94db219c315742909fee67deeeacae15",
12
- # "name": "knife", "like": 0, "view": 35,
13
- # "tags": ["game-ready", "damascus", "damascus_steel", "kabar-knife", "knife", "blender", "blender3d", "gameready"],
14
- # "cats": ["weapons-military"],
15
- # "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
16
- # "desc": "", "faces": 1724, "size": 11955, "lic": "by",
17
- # "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
 
 
 
 
 
 
18
  # }
19
  meta = {x['u']: x for x in meta['entries']}
20
  deser = torch.load(
@@ -24,17 +30,20 @@ us = deser['us']
24
  feats = deser['feats']
25
 
26
 
27
- def retrieve(embedding, top):
28
  sims = []
29
  embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
30
  for chunk in torch.split(feats, 10240):
31
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
32
  sims = torch.cat(sims)
33
- sims, idx = torch.topk(sims, top * 2)
34
  results = []
35
  for i, sim in zip(idx, sims):
36
  if us[i] in meta:
37
- results.append(dict(meta[us[i]], sim=sim))
38
- if len(results) >= top:
 
 
 
39
  break
40
  return results
 
8
  open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
9
  )
10
  # {
11
+ # "u": "94db219c315742909fee67deeeacae15",
12
+ # "name": "knife",
13
+ # "like": 0,
14
+ # "view": 35,
15
+ # "anims": 0,
16
+ # "tags": ["game-ready"],
17
+ # "cats": ["weapons-military"],
18
+ # "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
19
+ # "desc": "",
20
+ # "faces": 1724,
21
+ # "size": 11955,
22
+ # "lic": "by",
23
+ # "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
24
  # }
25
  meta = {x['u']: x for x in meta['entries']}
26
  deser = torch.load(
 
30
  feats = deser['feats']
31
 
32
 
33
+ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
34
  sims = []
35
  embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
36
  for chunk in torch.split(feats, 10240):
37
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
38
  sims = torch.cat(sims)
39
+ sims, idx = torch.sort(sims, descending=True)
40
  results = []
41
  for i, sim in zip(idx, sims):
42
  if us[i] in meta:
43
+ if filter_fn is None or filter_fn(meta[us[i]]):
44
+ results.append(dict(meta[us[i]], sim=sim))
45
+ if len(results) >= top:
46
+ break
47
+ if sim < sim_th:
48
  break
49
  return results