eliphatfs
commited on
Commit
•
7a4df11
1
Parent(s):
7ee7303
Publish.
Browse files- .gitignore +2 -0
- openshape/__init__.py +47 -0
- openshape/demo/__init__.py +0 -0
- openshape/demo/caption.py +163 -0
- openshape/demo/classification.py +13 -0
- openshape/demo/lvis.py +1162 -0
- openshape/demo/lvis_cats.pt +3 -0
- openshape/demo/misc_utils.py +153 -0
- openshape/demo/retrieval.py +40 -0
- openshape/demo/sd_pc2img.py +38 -0
- openshape/pointnet_util.py +323 -0
- openshape/ppat_rgb.py +118 -0
- setup.py +21 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.egg-info
|
openshape/__init__.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
from .ppat_rgb import Projected, PointPatchTransformer
|
5 |
+
|
6 |
+
|
7 |
+
def module(state_dict: dict, name):
|
8 |
+
return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items() if k.startswith(name + '.')}
|
9 |
+
|
10 |
+
|
11 |
+
def G14(s):
|
12 |
+
model = Projected(
|
13 |
+
PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
|
14 |
+
nn.Linear(512, 1280)
|
15 |
+
)
|
16 |
+
model.load_state_dict(module(s['state_dict'], 'module'))
|
17 |
+
return model
|
18 |
+
|
19 |
+
|
20 |
+
def L14(s):
|
21 |
+
model = Projected(
|
22 |
+
PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6),
|
23 |
+
nn.Linear(512, 768)
|
24 |
+
)
|
25 |
+
model.load_state_dict(module(s, 'pc_encoder'))
|
26 |
+
return model
|
27 |
+
|
28 |
+
|
29 |
+
def B32(s):
|
30 |
+
model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
|
31 |
+
model.load_state_dict(module(s, 'pc_encoder'))
|
32 |
+
return model
|
33 |
+
|
34 |
+
|
35 |
+
model_list = {
|
36 |
+
"openshape-pointbert-vitb32-rgb": B32,
|
37 |
+
"openshape-pointbert-vitl14-rgb": L14,
|
38 |
+
"openshape-pointbert-vitg14-rgb": G14,
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
def load_pc_encoder(name):
|
43 |
+
s = torch.load(hf_hub_download("OpenShape/" + name, "model.pt", token=True), map_location='cpu')
|
44 |
+
model = model_list[name](s).eval()
|
45 |
+
if torch.cuda.is_available():
|
46 |
+
model.cuda()
|
47 |
+
return model
|
openshape/demo/__init__.py
ADDED
File without changes
|
openshape/demo/caption.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import Tuple, List, Union, Optional
|
5 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
|
9 |
+
N = type(None)
|
10 |
+
V = np.array
|
11 |
+
ARRAY = np.ndarray
|
12 |
+
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
|
13 |
+
VS = Union[Tuple[V, ...], List[V]]
|
14 |
+
VN = Union[V, N]
|
15 |
+
VNS = Union[VS, N]
|
16 |
+
T = torch.Tensor
|
17 |
+
TS = Union[Tuple[T, ...], List[T]]
|
18 |
+
TN = Optional[T]
|
19 |
+
TNS = Union[Tuple[TN, ...], List[TN]]
|
20 |
+
TSN = Optional[TS]
|
21 |
+
TA = Union[T, ARRAY]
|
22 |
+
|
23 |
+
|
24 |
+
D = torch.device
|
25 |
+
|
26 |
+
|
27 |
+
class MLP(nn.Module):
|
28 |
+
|
29 |
+
def forward(self, x: T) -> T:
|
30 |
+
return self.model(x)
|
31 |
+
|
32 |
+
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
33 |
+
super(MLP, self).__init__()
|
34 |
+
layers = []
|
35 |
+
for i in range(len(sizes) -1):
|
36 |
+
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
37 |
+
if i < len(sizes) - 2:
|
38 |
+
layers.append(act())
|
39 |
+
self.model = nn.Sequential(*layers)
|
40 |
+
|
41 |
+
|
42 |
+
class ClipCaptionModel(nn.Module):
|
43 |
+
|
44 |
+
#@functools.lru_cache #FIXME
|
45 |
+
def get_dummy_token(self, batch_size: int, device: D) -> T:
|
46 |
+
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
|
47 |
+
|
48 |
+
def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
|
49 |
+
embedding_text = self.gpt.transformer.wte(tokens)
|
50 |
+
prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
|
51 |
+
#print(embedding_text.size()) #torch.Size([5, 67, 768])
|
52 |
+
#print(prefix_projections.size()) #torch.Size([5, 1, 768])
|
53 |
+
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
|
54 |
+
if labels is not None:
|
55 |
+
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
|
56 |
+
labels = torch.cat((dummy_token, tokens), dim=1)
|
57 |
+
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
|
58 |
+
return out
|
59 |
+
|
60 |
+
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
61 |
+
super(ClipCaptionModel, self).__init__()
|
62 |
+
self.prefix_length = prefix_length
|
63 |
+
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
|
64 |
+
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
65 |
+
if prefix_length > 10: # not enough memory
|
66 |
+
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
|
67 |
+
else:
|
68 |
+
self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
|
69 |
+
|
70 |
+
|
71 |
+
class ClipCaptionPrefix(ClipCaptionModel):
|
72 |
+
|
73 |
+
def parameters(self, recurse: bool = True):
|
74 |
+
return self.clip_project.parameters()
|
75 |
+
|
76 |
+
def train(self, mode: bool = True):
|
77 |
+
super(ClipCaptionPrefix, self).train(mode)
|
78 |
+
self.gpt.eval()
|
79 |
+
return self
|
80 |
+
|
81 |
+
|
82 |
+
def generate2(
|
83 |
+
model,
|
84 |
+
tokenizer,
|
85 |
+
tokens=None,
|
86 |
+
prompt=None,
|
87 |
+
embed=None,
|
88 |
+
entry_count=1,
|
89 |
+
entry_length=67, # maximum number of words
|
90 |
+
top_p=0.8,
|
91 |
+
temperature=1.,
|
92 |
+
stop_token: str = '.',
|
93 |
+
):
|
94 |
+
model.eval()
|
95 |
+
generated_num = 0
|
96 |
+
generated_list = []
|
97 |
+
stop_token_index = tokenizer.encode(stop_token)[0]
|
98 |
+
filter_value = -float("Inf")
|
99 |
+
device = next(model.parameters()).device
|
100 |
+
score_col = []
|
101 |
+
with torch.no_grad():
|
102 |
+
|
103 |
+
for entry_idx in range(entry_count):
|
104 |
+
if embed is not None:
|
105 |
+
generated = embed
|
106 |
+
else:
|
107 |
+
if tokens is None:
|
108 |
+
tokens = torch.tensor(tokenizer.encode(prompt))
|
109 |
+
tokens = tokens.unsqueeze(0).to(device)
|
110 |
+
|
111 |
+
generated = model.gpt.transformer.wte(tokens)
|
112 |
+
|
113 |
+
for i in range(entry_length):
|
114 |
+
|
115 |
+
outputs = model.gpt(inputs_embeds=generated)
|
116 |
+
logits = outputs.logits
|
117 |
+
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
118 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
119 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
120 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
121 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
122 |
+
..., :-1
|
123 |
+
].clone()
|
124 |
+
sorted_indices_to_remove[..., 0] = 0
|
125 |
+
|
126 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
127 |
+
logits[:, indices_to_remove] = filter_value
|
128 |
+
next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
|
129 |
+
score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
|
130 |
+
score_col.append(score)
|
131 |
+
next_token_embed = model.gpt.transformer.wte(next_token)
|
132 |
+
if tokens is None:
|
133 |
+
tokens = next_token
|
134 |
+
else:
|
135 |
+
tokens = torch.cat((tokens, next_token), dim=1)
|
136 |
+
generated = torch.cat((generated, next_token_embed), dim=1)
|
137 |
+
if stop_token_index == next_token.item():
|
138 |
+
break
|
139 |
+
|
140 |
+
output_list = list(tokens.squeeze(0).cpu().numpy())
|
141 |
+
output_text = tokenizer.decode(output_list)
|
142 |
+
generated_list.append(output_text)
|
143 |
+
return generated_list[0]
|
144 |
+
|
145 |
+
|
146 |
+
@torch.no_grad()
|
147 |
+
def pc_caption(pc_encoder: torch.nn.Module, pc, cond_scale):
|
148 |
+
ref_dev = next(pc_encoder.parameters()).device
|
149 |
+
prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
|
150 |
+
prefix = prefix.float() * cond_scale
|
151 |
+
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
152 |
+
text = generate2(model, tokenizer, embed=prefix_embed)
|
153 |
+
return text
|
154 |
+
|
155 |
+
|
156 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
157 |
+
prefix_length = 10
|
158 |
+
model = ClipCaptionModel(prefix_length)
|
159 |
+
# print(model.gpt_embedding_size)
|
160 |
+
model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt', token=True), map_location='cpu'))
|
161 |
+
model.eval()
|
162 |
+
if torch.cuda.is_available():
|
163 |
+
model = model.cuda()
|
openshape/demo/classification.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from collections import OrderedDict
|
4 |
+
from . import lvis
|
5 |
+
|
6 |
+
|
7 |
+
@torch.no_grad()
|
8 |
+
def pred_lvis_sims(pc_encoder: torch.nn.Module, pc):
|
9 |
+
ref_dev = next(pc_encoder.parameters()).device
|
10 |
+
enc = pc_encoder(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
|
11 |
+
sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
|
12 |
+
argsort = torch.argsort(sim, descending=True)
|
13 |
+
return OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
openshape/demo/lvis.py
ADDED
@@ -0,0 +1,1162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
feats = torch.load(os.path.join(os.path.dirname(__file__), 'lvis_cats.pt'))
|
6 |
+
categories = [
|
7 |
+
'Band_Aid',
|
8 |
+
'Bible',
|
9 |
+
'CD_player',
|
10 |
+
'Christmas_tree',
|
11 |
+
'Dixie_cup',
|
12 |
+
'Ferris_wheel',
|
13 |
+
'Lego',
|
14 |
+
'Rollerblade',
|
15 |
+
'Sharpie',
|
16 |
+
'Tabasco_sauce',
|
17 |
+
'aerosol_can',
|
18 |
+
'air_conditioner',
|
19 |
+
'airplane',
|
20 |
+
'alarm_clock',
|
21 |
+
'alcohol',
|
22 |
+
'alligator',
|
23 |
+
'almond',
|
24 |
+
'ambulance',
|
25 |
+
'amplifier',
|
26 |
+
'anklet',
|
27 |
+
'antenna',
|
28 |
+
'apple',
|
29 |
+
'apricot',
|
30 |
+
'apron',
|
31 |
+
'aquarium',
|
32 |
+
'arctic_(type_of_shoe)',
|
33 |
+
'armband',
|
34 |
+
'armchair',
|
35 |
+
'armoire',
|
36 |
+
'armor',
|
37 |
+
'army_tank',
|
38 |
+
'artichoke',
|
39 |
+
'ashtray',
|
40 |
+
'asparagus',
|
41 |
+
'atomizer',
|
42 |
+
'automatic_washer',
|
43 |
+
'avocado',
|
44 |
+
'award',
|
45 |
+
'awning',
|
46 |
+
'ax',
|
47 |
+
'baboon',
|
48 |
+
'baby_buggy',
|
49 |
+
'backpack',
|
50 |
+
'bagel',
|
51 |
+
'baguet',
|
52 |
+
'bait',
|
53 |
+
'ball',
|
54 |
+
'ballet_skirt',
|
55 |
+
'balloon',
|
56 |
+
'bamboo',
|
57 |
+
'banana',
|
58 |
+
'bandage',
|
59 |
+
'bandanna',
|
60 |
+
'banjo',
|
61 |
+
'banner',
|
62 |
+
'barbell',
|
63 |
+
'barge',
|
64 |
+
'barrel',
|
65 |
+
'barrow',
|
66 |
+
'baseball',
|
67 |
+
'baseball_bat',
|
68 |
+
'baseball_cap',
|
69 |
+
'baseball_glove',
|
70 |
+
'basket',
|
71 |
+
'basketball',
|
72 |
+
'basketball_backboard',
|
73 |
+
'bass_horn',
|
74 |
+
'bat_(animal)',
|
75 |
+
'bath_mat',
|
76 |
+
'bath_towel',
|
77 |
+
'bathrobe',
|
78 |
+
'bathtub',
|
79 |
+
'battery',
|
80 |
+
'beachball',
|
81 |
+
'bead',
|
82 |
+
'beanbag',
|
83 |
+
'beanie',
|
84 |
+
'bear',
|
85 |
+
'bed',
|
86 |
+
'bedpan',
|
87 |
+
'bedspread',
|
88 |
+
'beef_(food)',
|
89 |
+
'beeper',
|
90 |
+
'beer_bottle',
|
91 |
+
'beer_can',
|
92 |
+
'beetle',
|
93 |
+
'bell',
|
94 |
+
'bell_pepper',
|
95 |
+
'belt',
|
96 |
+
'belt_buckle',
|
97 |
+
'bench',
|
98 |
+
'beret',
|
99 |
+
'bicycle',
|
100 |
+
'billboard',
|
101 |
+
'binder',
|
102 |
+
'binoculars',
|
103 |
+
'bird',
|
104 |
+
'birdbath',
|
105 |
+
'birdcage',
|
106 |
+
'birdfeeder',
|
107 |
+
'birdhouse',
|
108 |
+
'birthday_cake',
|
109 |
+
'birthday_card',
|
110 |
+
'blackberry',
|
111 |
+
'blackboard',
|
112 |
+
'blanket',
|
113 |
+
'blazer',
|
114 |
+
'blender',
|
115 |
+
'blimp',
|
116 |
+
'blouse',
|
117 |
+
'blueberry',
|
118 |
+
'boat',
|
119 |
+
'bob',
|
120 |
+
'bobbin',
|
121 |
+
'boiled_egg',
|
122 |
+
'bolo_tie',
|
123 |
+
'bolt',
|
124 |
+
'bonnet',
|
125 |
+
'book',
|
126 |
+
'bookcase',
|
127 |
+
'booklet',
|
128 |
+
'bookmark',
|
129 |
+
'boom_microphone',
|
130 |
+
'boot',
|
131 |
+
'bottle',
|
132 |
+
'bottle_cap',
|
133 |
+
'bottle_opener',
|
134 |
+
'bouquet',
|
135 |
+
'bow-tie',
|
136 |
+
'bow_(decorative_ribbons)',
|
137 |
+
'bow_(weapon)',
|
138 |
+
'bowl',
|
139 |
+
'bowler_hat',
|
140 |
+
'bowling_ball',
|
141 |
+
'box',
|
142 |
+
'boxing_glove',
|
143 |
+
'bracelet',
|
144 |
+
'brass_plaque',
|
145 |
+
'brassiere',
|
146 |
+
'bread',
|
147 |
+
'bread-bin',
|
148 |
+
'breechcloth',
|
149 |
+
'bridal_gown',
|
150 |
+
'briefcase',
|
151 |
+
'broach',
|
152 |
+
'broccoli',
|
153 |
+
'broom',
|
154 |
+
'brownie',
|
155 |
+
'brussels_sprouts',
|
156 |
+
'bubble_gum',
|
157 |
+
'bucket',
|
158 |
+
'bulldog',
|
159 |
+
'bulldozer',
|
160 |
+
'bullet_train',
|
161 |
+
'bulletin_board',
|
162 |
+
'bulletproof_vest',
|
163 |
+
'bullhorn',
|
164 |
+
'bun',
|
165 |
+
'bunk_bed',
|
166 |
+
'buoy',
|
167 |
+
'burrito',
|
168 |
+
'bus_(vehicle)',
|
169 |
+
'business_card',
|
170 |
+
'butter',
|
171 |
+
'butterfly',
|
172 |
+
'button',
|
173 |
+
'cab_(taxi)',
|
174 |
+
'cabana',
|
175 |
+
'cabin_car',
|
176 |
+
'cabinet',
|
177 |
+
'cake',
|
178 |
+
'calculator',
|
179 |
+
'calendar',
|
180 |
+
'calf',
|
181 |
+
'camcorder',
|
182 |
+
'camel',
|
183 |
+
'camera',
|
184 |
+
'camera_lens',
|
185 |
+
'camper_(vehicle)',
|
186 |
+
'can',
|
187 |
+
'can_opener',
|
188 |
+
'candle',
|
189 |
+
'candle_holder',
|
190 |
+
'candy_bar',
|
191 |
+
'candy_cane',
|
192 |
+
'canister',
|
193 |
+
'canoe',
|
194 |
+
'cantaloup',
|
195 |
+
'canteen',
|
196 |
+
'cap_(headwear)',
|
197 |
+
'cape',
|
198 |
+
'cappuccino',
|
199 |
+
'car_(automobile)',
|
200 |
+
'car_battery',
|
201 |
+
'card',
|
202 |
+
'cardigan',
|
203 |
+
'cargo_ship',
|
204 |
+
'carnation',
|
205 |
+
'carrot',
|
206 |
+
'cart',
|
207 |
+
'carton',
|
208 |
+
'cash_register',
|
209 |
+
'casserole',
|
210 |
+
'cassette',
|
211 |
+
'cast',
|
212 |
+
'cat',
|
213 |
+
'cauliflower',
|
214 |
+
'cayenne_(spice)',
|
215 |
+
'celery',
|
216 |
+
'cellular_telephone',
|
217 |
+
'chair',
|
218 |
+
'chaise_longue',
|
219 |
+
'chalice',
|
220 |
+
'chandelier',
|
221 |
+
'checkbook',
|
222 |
+
'checkerboard',
|
223 |
+
'cherry',
|
224 |
+
'chessboard',
|
225 |
+
'chicken_(animal)',
|
226 |
+
'chili_(vegetable)',
|
227 |
+
'chime',
|
228 |
+
'chinaware',
|
229 |
+
'chocolate_bar',
|
230 |
+
'chocolate_cake',
|
231 |
+
'chocolate_milk',
|
232 |
+
'chocolate_mousse',
|
233 |
+
'choker',
|
234 |
+
'chopping_board',
|
235 |
+
'chopstick',
|
236 |
+
'cider',
|
237 |
+
'cigar_box',
|
238 |
+
'cigarette',
|
239 |
+
'cigarette_case',
|
240 |
+
'cincture',
|
241 |
+
'cistern',
|
242 |
+
'clarinet',
|
243 |
+
'clasp',
|
244 |
+
'cleansing_agent',
|
245 |
+
'cleat_(for_securing_rope)',
|
246 |
+
'clementine',
|
247 |
+
'clip',
|
248 |
+
'clipboard',
|
249 |
+
'clippers_(for_plants)',
|
250 |
+
'cloak',
|
251 |
+
'clock',
|
252 |
+
'clock_tower',
|
253 |
+
'clothes_hamper',
|
254 |
+
'clothespin',
|
255 |
+
'clutch_bag',
|
256 |
+
'coaster',
|
257 |
+
'coat',
|
258 |
+
'coat_hanger',
|
259 |
+
'coatrack',
|
260 |
+
'cock',
|
261 |
+
'cockroach',
|
262 |
+
'cocoa_(beverage)',
|
263 |
+
'coconut',
|
264 |
+
'coffee_maker',
|
265 |
+
'coffee_table',
|
266 |
+
'coffeepot',
|
267 |
+
'coil',
|
268 |
+
'coin',
|
269 |
+
'colander',
|
270 |
+
'coloring_material',
|
271 |
+
'combination_lock',
|
272 |
+
'comic_book',
|
273 |
+
'compass',
|
274 |
+
'computer_keyboard',
|
275 |
+
'condiment',
|
276 |
+
'cone',
|
277 |
+
'control',
|
278 |
+
'convertible_(automobile)',
|
279 |
+
'cooker',
|
280 |
+
'cookie',
|
281 |
+
'cooking_utensil',
|
282 |
+
'cooler_(for_food)',
|
283 |
+
'cork_(bottle_plug)',
|
284 |
+
'corkboard',
|
285 |
+
'corkscrew',
|
286 |
+
'cornbread',
|
287 |
+
'cornet',
|
288 |
+
'cornice',
|
289 |
+
'cornmeal',
|
290 |
+
'corset',
|
291 |
+
'costume',
|
292 |
+
'cougar',
|
293 |
+
'cover',
|
294 |
+
'coverall',
|
295 |
+
'cow',
|
296 |
+
'cowbell',
|
297 |
+
'cowboy_hat',
|
298 |
+
'crab_(animal)',
|
299 |
+
'crabmeat',
|
300 |
+
'cracker',
|
301 |
+
'crape',
|
302 |
+
'crate',
|
303 |
+
'crawfish',
|
304 |
+
'crayon',
|
305 |
+
'cream_pitcher',
|
306 |
+
'crescent_roll',
|
307 |
+
'crib',
|
308 |
+
'crisp_(potato_chip)',
|
309 |
+
'crossbar',
|
310 |
+
'crouton',
|
311 |
+
'crow',
|
312 |
+
'crowbar',
|
313 |
+
'crown',
|
314 |
+
'crucifix',
|
315 |
+
'cruise_ship',
|
316 |
+
'crutch',
|
317 |
+
'cub_(animal)',
|
318 |
+
'cube',
|
319 |
+
'cucumber',
|
320 |
+
'cufflink',
|
321 |
+
'cup',
|
322 |
+
'cupboard',
|
323 |
+
'cupcake',
|
324 |
+
'curtain',
|
325 |
+
'cushion',
|
326 |
+
'cylinder',
|
327 |
+
'cymbal',
|
328 |
+
'dagger',
|
329 |
+
'dalmatian',
|
330 |
+
'dartboard',
|
331 |
+
'date_(fruit)',
|
332 |
+
'deadbolt',
|
333 |
+
'deck_chair',
|
334 |
+
'deer',
|
335 |
+
'desk',
|
336 |
+
'detergent',
|
337 |
+
'diaper',
|
338 |
+
'diary',
|
339 |
+
'die',
|
340 |
+
'dinghy',
|
341 |
+
'dining_table',
|
342 |
+
'dirt_bike',
|
343 |
+
'dish',
|
344 |
+
'dish_antenna',
|
345 |
+
'dishrag',
|
346 |
+
'dishtowel',
|
347 |
+
'dishwasher',
|
348 |
+
'dishwasher_detergent',
|
349 |
+
'dispenser',
|
350 |
+
'dog',
|
351 |
+
'dog_collar',
|
352 |
+
'doll',
|
353 |
+
'dollar',
|
354 |
+
'dollhouse',
|
355 |
+
'dolphin',
|
356 |
+
'domestic_ass',
|
357 |
+
'doorknob',
|
358 |
+
'doormat',
|
359 |
+
'doughnut',
|
360 |
+
'dove',
|
361 |
+
'dragonfly',
|
362 |
+
'drawer',
|
363 |
+
'dress',
|
364 |
+
'dress_hat',
|
365 |
+
'dress_suit',
|
366 |
+
'dresser',
|
367 |
+
'drill',
|
368 |
+
'drone',
|
369 |
+
'drum_(musical_instrument)',
|
370 |
+
'drumstick',
|
371 |
+
'duck',
|
372 |
+
'duckling',
|
373 |
+
'duct_tape',
|
374 |
+
'duffel_bag',
|
375 |
+
'dumbbell',
|
376 |
+
'dumpster',
|
377 |
+
'dustpan',
|
378 |
+
'eagle',
|
379 |
+
'earphone',
|
380 |
+
'earplug',
|
381 |
+
'earring',
|
382 |
+
'easel',
|
383 |
+
'eclair',
|
384 |
+
'edible_corn',
|
385 |
+
'eel',
|
386 |
+
'egg',
|
387 |
+
'egg_roll',
|
388 |
+
'egg_yolk',
|
389 |
+
'eggbeater',
|
390 |
+
'eggplant',
|
391 |
+
'elephant',
|
392 |
+
'elevator_car',
|
393 |
+
'elk',
|
394 |
+
'envelope',
|
395 |
+
'eraser',
|
396 |
+
'escargot',
|
397 |
+
'eyepatch',
|
398 |
+
'falcon',
|
399 |
+
'fan',
|
400 |
+
'faucet',
|
401 |
+
'fedora',
|
402 |
+
'ferret',
|
403 |
+
'ferry',
|
404 |
+
'fig_(fruit)',
|
405 |
+
'fighter_jet',
|
406 |
+
'figurine',
|
407 |
+
'file_(tool)',
|
408 |
+
'file_cabinet',
|
409 |
+
'fire_alarm',
|
410 |
+
'fire_engine',
|
411 |
+
'fire_extinguisher',
|
412 |
+
'fire_hose',
|
413 |
+
'fireplace',
|
414 |
+
'fireplug',
|
415 |
+
'first-aid_kit',
|
416 |
+
'fish',
|
417 |
+
'fish_(food)',
|
418 |
+
'fishbowl',
|
419 |
+
'fishing_rod',
|
420 |
+
'flag',
|
421 |
+
'flagpole',
|
422 |
+
'flamingo',
|
423 |
+
'flannel',
|
424 |
+
'flap',
|
425 |
+
'flash',
|
426 |
+
'flashlight',
|
427 |
+
'fleece',
|
428 |
+
'flip-flop_(sandal)',
|
429 |
+
'flipper_(footwear)',
|
430 |
+
'flower_arrangement',
|
431 |
+
'flowerpot',
|
432 |
+
'flute_glass',
|
433 |
+
'foal',
|
434 |
+
'folding_chair',
|
435 |
+
'food_processor',
|
436 |
+
'football_(American)',
|
437 |
+
'football_helmet',
|
438 |
+
'footstool',
|
439 |
+
'fork',
|
440 |
+
'forklift',
|
441 |
+
'freight_car',
|
442 |
+
'freshener',
|
443 |
+
'frisbee',
|
444 |
+
'frog',
|
445 |
+
'fruit_juice',
|
446 |
+
'frying_pan',
|
447 |
+
'fume_hood',
|
448 |
+
'funnel',
|
449 |
+
'futon',
|
450 |
+
'gameboard',
|
451 |
+
'garbage',
|
452 |
+
'garbage_truck',
|
453 |
+
'garden_hose',
|
454 |
+
'gargle',
|
455 |
+
'gargoyle',
|
456 |
+
'garlic',
|
457 |
+
'gasmask',
|
458 |
+
'gazelle',
|
459 |
+
'gelatin',
|
460 |
+
'gemstone',
|
461 |
+
'generator',
|
462 |
+
'giant_panda',
|
463 |
+
'gift_wrap',
|
464 |
+
'ginger',
|
465 |
+
'giraffe',
|
466 |
+
'glass_(drink_container)',
|
467 |
+
'globe',
|
468 |
+
'glove',
|
469 |
+
'goat',
|
470 |
+
'goggles',
|
471 |
+
'goldfish',
|
472 |
+
'golf_club',
|
473 |
+
'golfcart',
|
474 |
+
'gondola_(boat)',
|
475 |
+
'goose',
|
476 |
+
'gorilla',
|
477 |
+
'gourd',
|
478 |
+
'grape',
|
479 |
+
'grater',
|
480 |
+
'gravestone',
|
481 |
+
'gravy_boat',
|
482 |
+
'green_bean',
|
483 |
+
'green_onion',
|
484 |
+
'grill',
|
485 |
+
'grits',
|
486 |
+
'grizzly',
|
487 |
+
'grocery_bag',
|
488 |
+
'guitar',
|
489 |
+
'gull',
|
490 |
+
'gun',
|
491 |
+
'hair_dryer',
|
492 |
+
'hairbrush',
|
493 |
+
'hairnet',
|
494 |
+
'halter_top',
|
495 |
+
'ham',
|
496 |
+
'hamburger',
|
497 |
+
'hammer',
|
498 |
+
'hammock',
|
499 |
+
'hamper',
|
500 |
+
'hamster',
|
501 |
+
'hand_glass',
|
502 |
+
'hand_towel',
|
503 |
+
'handbag',
|
504 |
+
'handcart',
|
505 |
+
'handcuff',
|
506 |
+
'handkerchief',
|
507 |
+
'handle',
|
508 |
+
'handsaw',
|
509 |
+
'hardback_book',
|
510 |
+
'harmonium',
|
511 |
+
'hat',
|
512 |
+
'hatbox',
|
513 |
+
'headband',
|
514 |
+
'headboard',
|
515 |
+
'headlight',
|
516 |
+
'headscarf',
|
517 |
+
'headset',
|
518 |
+
'headstall_(for_horses)',
|
519 |
+
'heart',
|
520 |
+
'heater',
|
521 |
+
'helicopter',
|
522 |
+
'helmet',
|
523 |
+
'heron',
|
524 |
+
'highchair',
|
525 |
+
'hinge',
|
526 |
+
'hippopotamus',
|
527 |
+
'hockey_stick',
|
528 |
+
'hog',
|
529 |
+
'honey',
|
530 |
+
'hook',
|
531 |
+
'hookah',
|
532 |
+
'horned_cow',
|
533 |
+
'hornet',
|
534 |
+
'horse',
|
535 |
+
'horse_buggy',
|
536 |
+
'horse_carriage',
|
537 |
+
'hose',
|
538 |
+
'hot-air_balloon',
|
539 |
+
'hot_sauce',
|
540 |
+
'hotplate',
|
541 |
+
'hourglass',
|
542 |
+
'houseboat',
|
543 |
+
'hummingbird',
|
544 |
+
'iPod',
|
545 |
+
'ice_maker',
|
546 |
+
'ice_pack',
|
547 |
+
'ice_skate',
|
548 |
+
'icecream',
|
549 |
+
'identity_card',
|
550 |
+
'igniter',
|
551 |
+
'inhaler',
|
552 |
+
'inkpad',
|
553 |
+
'iron_(for_clothing)',
|
554 |
+
'ironing_board',
|
555 |
+
'jacket',
|
556 |
+
'jam',
|
557 |
+
'jar',
|
558 |
+
'jean',
|
559 |
+
'jeep',
|
560 |
+
'jersey',
|
561 |
+
'jet_plane',
|
562 |
+
'jewel',
|
563 |
+
'jewelry',
|
564 |
+
'joystick',
|
565 |
+
'jumpsuit',
|
566 |
+
'kayak',
|
567 |
+
'keg',
|
568 |
+
'kennel',
|
569 |
+
'kettle',
|
570 |
+
'key',
|
571 |
+
'keycard',
|
572 |
+
'kilt',
|
573 |
+
'kimono',
|
574 |
+
'kitchen_sink',
|
575 |
+
'kitchen_table',
|
576 |
+
'kite',
|
577 |
+
'kitten',
|
578 |
+
'kiwi_fruit',
|
579 |
+
'knee_pad',
|
580 |
+
'knife',
|
581 |
+
'knitting_needle',
|
582 |
+
'knob',
|
583 |
+
'knocker_(on_a_door)',
|
584 |
+
'koala',
|
585 |
+
'lab_coat',
|
586 |
+
'ladder',
|
587 |
+
'ladle',
|
588 |
+
'ladybug',
|
589 |
+
'lamb-chop',
|
590 |
+
'lamb_(animal)',
|
591 |
+
'lamp',
|
592 |
+
'lamppost',
|
593 |
+
'lampshade',
|
594 |
+
'lantern',
|
595 |
+
'laptop_computer',
|
596 |
+
'lasagna',
|
597 |
+
'latch',
|
598 |
+
'lawn_mower',
|
599 |
+
'leather',
|
600 |
+
'legging_(clothing)',
|
601 |
+
'legume',
|
602 |
+
'lemon',
|
603 |
+
'lemonade',
|
604 |
+
'lettuce',
|
605 |
+
'license_plate',
|
606 |
+
'life_buoy',
|
607 |
+
'life_jacket',
|
608 |
+
'lightbulb',
|
609 |
+
'lightning_rod',
|
610 |
+
'lime',
|
611 |
+
'limousine',
|
612 |
+
'lion',
|
613 |
+
'lip_balm',
|
614 |
+
'liquor',
|
615 |
+
'lizard',
|
616 |
+
'locker',
|
617 |
+
'log',
|
618 |
+
'lollipop',
|
619 |
+
'loveseat',
|
620 |
+
'machine_gun',
|
621 |
+
'magazine',
|
622 |
+
'magnet',
|
623 |
+
'mail_slot',
|
624 |
+
'mailbox_(at_home)',
|
625 |
+
'mallard',
|
626 |
+
'mallet',
|
627 |
+
'mammoth',
|
628 |
+
'manatee',
|
629 |
+
'mandarin_orange',
|
630 |
+
'manger',
|
631 |
+
'manhole',
|
632 |
+
'map',
|
633 |
+
'marker',
|
634 |
+
'martini',
|
635 |
+
'mascot',
|
636 |
+
'mashed_potato',
|
637 |
+
'mask',
|
638 |
+
'mast',
|
639 |
+
'mat_(gym_equipment)',
|
640 |
+
'matchbox',
|
641 |
+
'mattress',
|
642 |
+
'measuring_cup',
|
643 |
+
'measuring_stick',
|
644 |
+
'meatball',
|
645 |
+
'medicine',
|
646 |
+
'melon',
|
647 |
+
'microphone',
|
648 |
+
'microscope',
|
649 |
+
'microwave_oven',
|
650 |
+
'milestone',
|
651 |
+
'milk',
|
652 |
+
'milk_can',
|
653 |
+
'milkshake',
|
654 |
+
'minivan',
|
655 |
+
'mint_candy',
|
656 |
+
'mirror',
|
657 |
+
'mitten',
|
658 |
+
'mixer_(kitchen_tool)',
|
659 |
+
'money',
|
660 |
+
'monitor_(computer_equipment) computer_monitor',
|
661 |
+
'monkey',
|
662 |
+
'mop',
|
663 |
+
'motor',
|
664 |
+
'motor_scooter',
|
665 |
+
'motor_vehicle',
|
666 |
+
'motorcycle',
|
667 |
+
'mound_(baseball)',
|
668 |
+
'mouse_(computer_equipment)',
|
669 |
+
'mousepad',
|
670 |
+
'muffin',
|
671 |
+
'mug',
|
672 |
+
'mushroom',
|
673 |
+
'music_stool',
|
674 |
+
'musical_instrument',
|
675 |
+
'nailfile',
|
676 |
+
'napkin',
|
677 |
+
'neckerchief',
|
678 |
+
'necklace',
|
679 |
+
'necktie',
|
680 |
+
'needle',
|
681 |
+
'nest',
|
682 |
+
'newspaper',
|
683 |
+
'newsstand',
|
684 |
+
'nightshirt',
|
685 |
+
'notebook',
|
686 |
+
'notepad',
|
687 |
+
'nut',
|
688 |
+
'nutcracker',
|
689 |
+
'oar',
|
690 |
+
'octopus_(animal)',
|
691 |
+
'octopus_(food)',
|
692 |
+
'oil_lamp',
|
693 |
+
'olive_oil',
|
694 |
+
'omelet',
|
695 |
+
'onion',
|
696 |
+
'orange_(fruit)',
|
697 |
+
'orange_juice',
|
698 |
+
'ostrich',
|
699 |
+
'ottoman',
|
700 |
+
'oven',
|
701 |
+
'overalls_(clothing)',
|
702 |
+
'owl',
|
703 |
+
'pacifier',
|
704 |
+
'packet',
|
705 |
+
'paddle',
|
706 |
+
'padlock',
|
707 |
+
'paintbrush',
|
708 |
+
'painting',
|
709 |
+
'pajamas',
|
710 |
+
'palette',
|
711 |
+
'pan_(for_cooking)',
|
712 |
+
'pan_(metal_container)',
|
713 |
+
'pancake',
|
714 |
+
'papaya',
|
715 |
+
'paper_plate',
|
716 |
+
'paper_towel',
|
717 |
+
'paperback_book',
|
718 |
+
'paperweight',
|
719 |
+
'parachute',
|
720 |
+
'parakeet',
|
721 |
+
'parasail_(sports)',
|
722 |
+
'parasol',
|
723 |
+
'parchment',
|
724 |
+
'parka',
|
725 |
+
'parking_meter',
|
726 |
+
'parrot',
|
727 |
+
'passenger_car_(part_of_a_train)',
|
728 |
+
'passenger_ship',
|
729 |
+
'passport',
|
730 |
+
'pastry',
|
731 |
+
'patty_(food)',
|
732 |
+
'pea_(food)',
|
733 |
+
'peach',
|
734 |
+
'peanut_butter',
|
735 |
+
'pear',
|
736 |
+
'peeler_(tool_for_fruit_and_vegetables)',
|
737 |
+
'pegboard',
|
738 |
+
'pelican',
|
739 |
+
'pen',
|
740 |
+
'pencil',
|
741 |
+
'pencil_box',
|
742 |
+
'pencil_sharpener',
|
743 |
+
'pendulum',
|
744 |
+
'penguin',
|
745 |
+
'pennant',
|
746 |
+
'penny_(coin)',
|
747 |
+
'pepper',
|
748 |
+
'pepper_mill',
|
749 |
+
'perfume',
|
750 |
+
'persimmon',
|
751 |
+
'person',
|
752 |
+
'pet',
|
753 |
+
'pew_(church_bench)',
|
754 |
+
'phonebook',
|
755 |
+
'phonograph_record',
|
756 |
+
'piano',
|
757 |
+
'pickle',
|
758 |
+
'pickup_truck',
|
759 |
+
'pie',
|
760 |
+
'pigeon',
|
761 |
+
'piggy_bank',
|
762 |
+
'pillow',
|
763 |
+
'pineapple',
|
764 |
+
'pinecone',
|
765 |
+
'ping-pong_ball',
|
766 |
+
'pinwheel',
|
767 |
+
'pipe',
|
768 |
+
'pipe_bowl',
|
769 |
+
'pirate_flag',
|
770 |
+
'pistol',
|
771 |
+
'pita_(bread)',
|
772 |
+
'pitcher_(vessel_for_liquid)',
|
773 |
+
'pitchfork',
|
774 |
+
'pizza',
|
775 |
+
'place_mat',
|
776 |
+
'plastic_bag',
|
777 |
+
'plate',
|
778 |
+
'platter',
|
779 |
+
'playpen',
|
780 |
+
'pliers',
|
781 |
+
'plow_(farm_equipment)',
|
782 |
+
'plume',
|
783 |
+
'pocket_watch',
|
784 |
+
'pocketknife',
|
785 |
+
'poker_(fire_stirring_tool)',
|
786 |
+
'poker_chip',
|
787 |
+
'polar_bear',
|
788 |
+
'pole',
|
789 |
+
'police_cruiser',
|
790 |
+
'polo_shirt',
|
791 |
+
'poncho',
|
792 |
+
'pony',
|
793 |
+
'pool_table',
|
794 |
+
'pop_(soda)',
|
795 |
+
'popsicle',
|
796 |
+
'postbox_(public)',
|
797 |
+
'postcard',
|
798 |
+
'poster',
|
799 |
+
'pot',
|
800 |
+
'potato',
|
801 |
+
'potholder',
|
802 |
+
'pottery',
|
803 |
+
'pouch',
|
804 |
+
'power_shovel',
|
805 |
+
'prawn',
|
806 |
+
'pretzel',
|
807 |
+
'printer',
|
808 |
+
'projectile_(weapon)',
|
809 |
+
'projector',
|
810 |
+
'propeller',
|
811 |
+
'prune',
|
812 |
+
'pudding',
|
813 |
+
'puffer_(fish)',
|
814 |
+
'puffin',
|
815 |
+
'pug-dog',
|
816 |
+
'pumpkin',
|
817 |
+
'puncher',
|
818 |
+
'puppet',
|
819 |
+
'puppy',
|
820 |
+
'quesadilla',
|
821 |
+
'quiche',
|
822 |
+
'quilt',
|
823 |
+
'rabbit',
|
824 |
+
'race_car',
|
825 |
+
'racket',
|
826 |
+
'radar',
|
827 |
+
'radiator',
|
828 |
+
'radio_receiver',
|
829 |
+
'radish',
|
830 |
+
'raft',
|
831 |
+
'rag_doll',
|
832 |
+
'railcar_(part_of_a_train)',
|
833 |
+
'raincoat',
|
834 |
+
'ram_(animal)',
|
835 |
+
'raspberry',
|
836 |
+
'rat',
|
837 |
+
'reamer_(juicer)',
|
838 |
+
'rearview_mirror',
|
839 |
+
'receipt',
|
840 |
+
'recliner',
|
841 |
+
'record_player',
|
842 |
+
'reflector',
|
843 |
+
'refrigerator',
|
844 |
+
'remote_control',
|
845 |
+
'rhinoceros',
|
846 |
+
'rib_(food)',
|
847 |
+
'rifle',
|
848 |
+
'ring',
|
849 |
+
'river_boat',
|
850 |
+
'road_map',
|
851 |
+
'robe',
|
852 |
+
'rocking_chair',
|
853 |
+
'rodent',
|
854 |
+
'roller_skate',
|
855 |
+
'rolling_pin',
|
856 |
+
'root_beer',
|
857 |
+
'router_(computer_equipment)',
|
858 |
+
'rubber_band',
|
859 |
+
'runner_(carpet)',
|
860 |
+
'saddle_(on_an_animal)',
|
861 |
+
'saddle_blanket',
|
862 |
+
'saddlebag',
|
863 |
+
'safety_pin',
|
864 |
+
'sail',
|
865 |
+
'salad',
|
866 |
+
'salad_plate',
|
867 |
+
'salami',
|
868 |
+
'salmon_(fish)',
|
869 |
+
'salmon_(food)',
|
870 |
+
'salsa',
|
871 |
+
'saltshaker',
|
872 |
+
'sandal_(type_of_shoe)',
|
873 |
+
'sandwich',
|
874 |
+
'satchel',
|
875 |
+
'saucepan',
|
876 |
+
'saucer',
|
877 |
+
'sausage',
|
878 |
+
'sawhorse',
|
879 |
+
'saxophone',
|
880 |
+
'scale_(measuring_instrument)',
|
881 |
+
'scarecrow',
|
882 |
+
'scarf',
|
883 |
+
'school_bus',
|
884 |
+
'scissors',
|
885 |
+
'scoreboard',
|
886 |
+
'scraper',
|
887 |
+
'screwdriver',
|
888 |
+
'scrubbing_brush',
|
889 |
+
'sculpture',
|
890 |
+
'seabird',
|
891 |
+
'seahorse',
|
892 |
+
'seaplane',
|
893 |
+
'seashell',
|
894 |
+
'sewing_machine',
|
895 |
+
'shaker',
|
896 |
+
'shampoo',
|
897 |
+
'shark',
|
898 |
+
'sharpener',
|
899 |
+
'shaver_(electric)',
|
900 |
+
'shaving_cream',
|
901 |
+
'shawl',
|
902 |
+
'shears',
|
903 |
+
'sheep',
|
904 |
+
'shepherd_dog',
|
905 |
+
'sherbert',
|
906 |
+
'shield',
|
907 |
+
'shirt',
|
908 |
+
'shoe',
|
909 |
+
'shopping_bag',
|
910 |
+
'shopping_cart',
|
911 |
+
'short_pants',
|
912 |
+
'shot_glass',
|
913 |
+
'shoulder_bag',
|
914 |
+
'shovel',
|
915 |
+
'shower_cap',
|
916 |
+
'shower_curtain',
|
917 |
+
'shower_head',
|
918 |
+
'shredder_(for_paper)',
|
919 |
+
'signboard',
|
920 |
+
'silo',
|
921 |
+
'sink',
|
922 |
+
'skateboard',
|
923 |
+
'skewer',
|
924 |
+
'ski',
|
925 |
+
'ski_boot',
|
926 |
+
'ski_parka',
|
927 |
+
'ski_pole',
|
928 |
+
'skirt',
|
929 |
+
'skullcap',
|
930 |
+
'sled',
|
931 |
+
'sleeping_bag',
|
932 |
+
'slide',
|
933 |
+
'slipper_(footwear)',
|
934 |
+
'smoothie',
|
935 |
+
'snake',
|
936 |
+
'snowboard',
|
937 |
+
'snowman',
|
938 |
+
'snowmobile',
|
939 |
+
'soap',
|
940 |
+
'soccer_ball',
|
941 |
+
'sock',
|
942 |
+
'sofa',
|
943 |
+
'sofa_bed',
|
944 |
+
'softball',
|
945 |
+
'solar_array',
|
946 |
+
'sombrero',
|
947 |
+
'soup',
|
948 |
+
'soup_bowl',
|
949 |
+
'soupspoon',
|
950 |
+
'soya_milk',
|
951 |
+
'space_shuttle',
|
952 |
+
'sparkler_(fireworks)',
|
953 |
+
'spatula',
|
954 |
+
'speaker_(stero_equipment)',
|
955 |
+
'spear',
|
956 |
+
'spectacles',
|
957 |
+
'spice_rack',
|
958 |
+
'spider',
|
959 |
+
'sponge',
|
960 |
+
'spoon',
|
961 |
+
'sportswear',
|
962 |
+
'spotlight',
|
963 |
+
'squid_(food)',
|
964 |
+
'squirrel',
|
965 |
+
'stagecoach',
|
966 |
+
'stapler_(stapling_machine)',
|
967 |
+
'starfish',
|
968 |
+
'statue_(sculpture)',
|
969 |
+
'steak_(food)',
|
970 |
+
'steak_knife',
|
971 |
+
'steering_wheel',
|
972 |
+
'step_stool',
|
973 |
+
'stepladder',
|
974 |
+
'stereo_(sound_system)',
|
975 |
+
'stew',
|
976 |
+
'stirrer',
|
977 |
+
'stirrup',
|
978 |
+
'stool',
|
979 |
+
'stop_sign',
|
980 |
+
'stove',
|
981 |
+
'strainer',
|
982 |
+
'strap',
|
983 |
+
'straw_(for_drinking)',
|
984 |
+
'strawberry',
|
985 |
+
'street_sign',
|
986 |
+
'streetlight',
|
987 |
+
'string_cheese',
|
988 |
+
'stylus',
|
989 |
+
'subwoofer',
|
990 |
+
'sugar_bowl',
|
991 |
+
'sugarcane_(plant)',
|
992 |
+
'suit_(clothing)',
|
993 |
+
'suitcase',
|
994 |
+
'sunflower',
|
995 |
+
'sunglasses',
|
996 |
+
'sunhat',
|
997 |
+
'surfboard',
|
998 |
+
'sushi',
|
999 |
+
'suspenders',
|
1000 |
+
'sweat_pants',
|
1001 |
+
'sweatband',
|
1002 |
+
'sweater',
|
1003 |
+
'sweatshirt',
|
1004 |
+
'sweet_potato',
|
1005 |
+
'swimsuit',
|
1006 |
+
'sword',
|
1007 |
+
'syringe',
|
1008 |
+
'table',
|
1009 |
+
'table-tennis_table',
|
1010 |
+
'table_lamp',
|
1011 |
+
'tablecloth',
|
1012 |
+
'tachometer',
|
1013 |
+
'taco',
|
1014 |
+
'tag',
|
1015 |
+
'taillight',
|
1016 |
+
'tambourine',
|
1017 |
+
'tank_(storage_vessel)',
|
1018 |
+
'tank_top_(clothing)',
|
1019 |
+
'tape_(sticky_cloth_or_paper)',
|
1020 |
+
'tape_measure',
|
1021 |
+
'tapestry',
|
1022 |
+
'tarp',
|
1023 |
+
'tartan',
|
1024 |
+
'tassel',
|
1025 |
+
'teacup',
|
1026 |
+
'teakettle',
|
1027 |
+
'teapot',
|
1028 |
+
'teddy_bear',
|
1029 |
+
'telephone',
|
1030 |
+
'telephone_booth',
|
1031 |
+
'telephone_pole',
|
1032 |
+
'telephoto_lens',
|
1033 |
+
'television_camera',
|
1034 |
+
'television_set',
|
1035 |
+
'tennis_ball',
|
1036 |
+
'tennis_racket',
|
1037 |
+
'tequila',
|
1038 |
+
'thermometer',
|
1039 |
+
'thermos_bottle',
|
1040 |
+
'thermostat',
|
1041 |
+
'thimble',
|
1042 |
+
'thread',
|
1043 |
+
'thumbtack',
|
1044 |
+
'tiara',
|
1045 |
+
'tiger',
|
1046 |
+
'tights_(clothing)',
|
1047 |
+
'timer',
|
1048 |
+
'tinfoil',
|
1049 |
+
'tinsel',
|
1050 |
+
'tissue_paper',
|
1051 |
+
'toast_(food)',
|
1052 |
+
'toaster',
|
1053 |
+
'toaster_oven',
|
1054 |
+
'tobacco_pipe',
|
1055 |
+
'toilet',
|
1056 |
+
'toilet_tissue',
|
1057 |
+
'tomato',
|
1058 |
+
'tongs',
|
1059 |
+
'toolbox',
|
1060 |
+
'toothbrush',
|
1061 |
+
'toothpaste',
|
1062 |
+
'toothpick',
|
1063 |
+
'tortilla',
|
1064 |
+
'tote_bag',
|
1065 |
+
'tow_truck',
|
1066 |
+
'towel',
|
1067 |
+
'towel_rack',
|
1068 |
+
'toy',
|
1069 |
+
'tractor_(farm_equipment)',
|
1070 |
+
'traffic_light',
|
1071 |
+
'trailer_truck',
|
1072 |
+
'train_(railroad_vehicle)',
|
1073 |
+
'trampoline',
|
1074 |
+
'trash_can',
|
1075 |
+
'tray',
|
1076 |
+
'trench_coat',
|
1077 |
+
'triangle_(musical_instrument)',
|
1078 |
+
'tricycle',
|
1079 |
+
'tripod',
|
1080 |
+
'trophy_cup',
|
1081 |
+
'trousers',
|
1082 |
+
'truck',
|
1083 |
+
'truffle_(chocolate)',
|
1084 |
+
'trunk',
|
1085 |
+
'turban',
|
1086 |
+
'turkey_(food)',
|
1087 |
+
'turnip',
|
1088 |
+
'turtle',
|
1089 |
+
'turtleneck_(clothing)',
|
1090 |
+
'tux',
|
1091 |
+
'typewriter',
|
1092 |
+
'umbrella',
|
1093 |
+
'underdrawers',
|
1094 |
+
'underwear',
|
1095 |
+
'unicycle',
|
1096 |
+
'urinal',
|
1097 |
+
'urn',
|
1098 |
+
'vacuum_cleaner',
|
1099 |
+
'vase',
|
1100 |
+
'veil',
|
1101 |
+
'vending_machine',
|
1102 |
+
'vent',
|
1103 |
+
'vest',
|
1104 |
+
'videotape',
|
1105 |
+
'vinegar',
|
1106 |
+
'violin',
|
1107 |
+
'visor',
|
1108 |
+
'vodka',
|
1109 |
+
'volleyball',
|
1110 |
+
'vulture',
|
1111 |
+
'waffle',
|
1112 |
+
'waffle_iron',
|
1113 |
+
'wagon',
|
1114 |
+
'walking_cane',
|
1115 |
+
'walking_stick',
|
1116 |
+
'wall_clock',
|
1117 |
+
'wall_socket',
|
1118 |
+
'wallet',
|
1119 |
+
'walrus',
|
1120 |
+
'wardrobe',
|
1121 |
+
'washbasin',
|
1122 |
+
'watch',
|
1123 |
+
'water_bottle',
|
1124 |
+
'water_cooler',
|
1125 |
+
'water_faucet',
|
1126 |
+
'water_gun',
|
1127 |
+
'water_heater',
|
1128 |
+
'water_jug',
|
1129 |
+
'water_scooter',
|
1130 |
+
'water_ski',
|
1131 |
+
'water_tower',
|
1132 |
+
'watering_can',
|
1133 |
+
'watermelon',
|
1134 |
+
'weathervane',
|
1135 |
+
'webcam',
|
1136 |
+
'wedding_cake',
|
1137 |
+
'wedding_ring',
|
1138 |
+
'wet_suit',
|
1139 |
+
'wheel',
|
1140 |
+
'wheelchair',
|
1141 |
+
'whipped_cream',
|
1142 |
+
'wig',
|
1143 |
+
'wind_chime',
|
1144 |
+
'windmill',
|
1145 |
+
'window_box_(for_plants)',
|
1146 |
+
'windsock',
|
1147 |
+
'wine_bottle',
|
1148 |
+
'wine_bucket',
|
1149 |
+
'wineglass',
|
1150 |
+
'wok',
|
1151 |
+
'wolf',
|
1152 |
+
'wooden_leg',
|
1153 |
+
'wooden_spoon',
|
1154 |
+
'wreath',
|
1155 |
+
'wrench',
|
1156 |
+
'wristband',
|
1157 |
+
'wristlet',
|
1158 |
+
'yacht',
|
1159 |
+
'yogurt',
|
1160 |
+
'zebra',
|
1161 |
+
'zucchini'
|
1162 |
+
]
|
openshape/demo/lvis_cats.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:71baf2d3f89884a082f1db75d0e94ac9a3b8036553877a3fdd98861cd01c4aec
|
3 |
+
size 5919467
|
openshape/demo/misc_utils.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import trimesh
|
3 |
+
import trimesh.sample
|
4 |
+
import trimesh.visual
|
5 |
+
import trimesh.proximity
|
6 |
+
import objaverse
|
7 |
+
import streamlit as st
|
8 |
+
import plotly.graph_objects as go
|
9 |
+
import matplotlib.pyplot as plotlib
|
10 |
+
|
11 |
+
|
12 |
+
def get_bytes(x: str):
|
13 |
+
import io, requests
|
14 |
+
return io.BytesIO(requests.get(x).content)
|
15 |
+
|
16 |
+
|
17 |
+
def get_image(x: str):
|
18 |
+
try:
|
19 |
+
return plotlib.imread(get_bytes(x), 'auto')
|
20 |
+
except Exception:
|
21 |
+
raise ValueError("Invalid image", x)
|
22 |
+
|
23 |
+
|
24 |
+
def model_to_pc(mesh: trimesh.Trimesh, n_sample_points=10000):
|
25 |
+
f32 = numpy.float32
|
26 |
+
rad = numpy.sqrt(mesh.area / (3 * n_sample_points))
|
27 |
+
for _ in range(24):
|
28 |
+
pcd, face_idx = trimesh.sample.sample_surface_even(mesh, n_sample_points, rad)
|
29 |
+
rad *= 0.85
|
30 |
+
if len(pcd) == n_sample_points:
|
31 |
+
break
|
32 |
+
else:
|
33 |
+
raise ValueError("Bad geometry, cannot finish sampling.", mesh.area)
|
34 |
+
if isinstance(mesh.visual, trimesh.visual.ColorVisuals):
|
35 |
+
rgba = mesh.visual.face_colors[face_idx]
|
36 |
+
elif isinstance(mesh.visual, trimesh.visual.TextureVisuals):
|
37 |
+
bc = trimesh.proximity.points_to_barycentric(mesh.triangles[face_idx], pcd)
|
38 |
+
if mesh.visual.uv is None or len(mesh.visual.uv) < mesh.faces[face_idx].max():
|
39 |
+
uv = numpy.zeros([len(bc), 2])
|
40 |
+
st.warning("Invalid UV, filling with zeroes")
|
41 |
+
else:
|
42 |
+
uv = numpy.einsum('ntc,nt->nc', mesh.visual.uv[mesh.faces[face_idx]], bc)
|
43 |
+
material = mesh.visual.material
|
44 |
+
if hasattr(material, 'materials'):
|
45 |
+
if len(material.materials) == 0:
|
46 |
+
rgba = numpy.ones_like(pcd) * 0.8
|
47 |
+
texture = None
|
48 |
+
st.warning("Empty MultiMaterial found, falling back to light grey")
|
49 |
+
else:
|
50 |
+
material = material.materials[0]
|
51 |
+
if hasattr(material, 'image'):
|
52 |
+
texture = material.image
|
53 |
+
if texture is None:
|
54 |
+
rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
|
55 |
+
elif hasattr(material, 'baseColorTexture'):
|
56 |
+
texture = material.baseColorTexture
|
57 |
+
if texture is None:
|
58 |
+
rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
|
59 |
+
else:
|
60 |
+
texture = None
|
61 |
+
rgba = numpy.ones_like(pcd) * 0.8
|
62 |
+
st.warning("Unknown material, falling back to light grey")
|
63 |
+
if texture is not None:
|
64 |
+
rgba = trimesh.visual.uv_to_interpolated_color(uv, texture)
|
65 |
+
if rgba.max() > 1:
|
66 |
+
if rgba.max() > 255:
|
67 |
+
rgba = rgba.astype(f32) / rgba.max()
|
68 |
+
else:
|
69 |
+
rgba = rgba.astype(f32) / 255.0
|
70 |
+
return numpy.concatenate([numpy.array(pcd, f32), numpy.array(rgba, f32)[:, :3]], axis=-1)
|
71 |
+
|
72 |
+
|
73 |
+
def trimesh_to_pc(scene_or_mesh):
|
74 |
+
if isinstance(scene_or_mesh, trimesh.Scene):
|
75 |
+
meshes = []
|
76 |
+
for node_name in scene_or_mesh.graph.nodes_geometry:
|
77 |
+
# which geometry does this node refer to
|
78 |
+
transform, geometry_name = scene_or_mesh.graph[node_name]
|
79 |
+
|
80 |
+
# get the actual potential mesh instance
|
81 |
+
geometry = scene_or_mesh.geometry[geometry_name].copy()
|
82 |
+
if not hasattr(geometry, 'triangles'):
|
83 |
+
continue
|
84 |
+
geometry: trimesh.Trimesh
|
85 |
+
geometry = geometry.apply_transform(transform)
|
86 |
+
meshes.append(geometry)
|
87 |
+
total_area = sum(geometry.area for geometry in meshes)
|
88 |
+
if total_area < 1e-6:
|
89 |
+
raise ValueError("Bad geometry: total area too small (< 1e-6)")
|
90 |
+
pcs = []
|
91 |
+
for geometry in meshes:
|
92 |
+
pcs.append(model_to_pc(geometry, max(1, round(geometry.area / total_area * 10000))))
|
93 |
+
if not len(pcs):
|
94 |
+
raise ValueError("Unsupported mesh object: no triangles found")
|
95 |
+
return numpy.concatenate(pcs)
|
96 |
+
else:
|
97 |
+
assert isinstance(scene_or_mesh, trimesh.Trimesh)
|
98 |
+
return model_to_pc(scene_or_mesh, 10000)
|
99 |
+
|
100 |
+
|
101 |
+
def input_3d_shape():
|
102 |
+
objaid = st.text_input("Enter an Objaverse ID")
|
103 |
+
model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
|
104 |
+
npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
|
105 |
+
swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")
|
106 |
+
f32 = numpy.float32
|
107 |
+
|
108 |
+
def load_data(prog):
|
109 |
+
# load the model
|
110 |
+
prog.progress(0.05, "Preparing Point Cloud")
|
111 |
+
if npy is not None:
|
112 |
+
pc: numpy.ndarray = numpy.load(npy)
|
113 |
+
elif model is not None:
|
114 |
+
pc = trimesh_to_pc(trimesh.load(model, model.name.split(".")[-1]))
|
115 |
+
elif objaid:
|
116 |
+
prog.progress(0.1, "Downloading Objaverse Object")
|
117 |
+
objamodel = objaverse.load_objects([objaid])[objaid]
|
118 |
+
prog.progress(0.2, "Preparing Point Cloud")
|
119 |
+
pc = trimesh_to_pc(trimesh.load(objamodel))
|
120 |
+
else:
|
121 |
+
raise ValueError("You have to supply 3D input!")
|
122 |
+
prog.progress(0.25, "Preprocessing Point Cloud")
|
123 |
+
assert pc.ndim == 2, "invalid pc shape: ndim = %d != 2" % pc.ndim
|
124 |
+
assert pc.shape[1] in [3, 6], "invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1]
|
125 |
+
if swap_yz_axes:
|
126 |
+
pc[:, [1, 2]] = pc[:, [2, 1]]
|
127 |
+
pc[:, :3] = pc[:, :3] - numpy.mean(pc[:, :3], axis=0)
|
128 |
+
pc[:, :3] = pc[:, :3] / numpy.linalg.norm(pc[:, :3], axis=-1).max()
|
129 |
+
if pc.shape[1] == 3:
|
130 |
+
pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
|
131 |
+
prog.progress(0.3, "Preprocessed Point Cloud")
|
132 |
+
return pc.astype(f32)
|
133 |
+
|
134 |
+
return load_data
|
135 |
+
|
136 |
+
|
137 |
+
def render_pc(pc):
|
138 |
+
rand = numpy.random.permutation(len(pc))[:2048]
|
139 |
+
pc = pc[rand]
|
140 |
+
rgb = (pc[:, 3:] * 255).astype(numpy.uint8)
|
141 |
+
g = go.Scatter3d(
|
142 |
+
x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
|
143 |
+
mode='markers',
|
144 |
+
marker=dict(size=2, color=[f'rgb({rgb[i, 0]}, {rgb[i, 1]}, {rgb[i, 2]})' for i in range(len(pc))]),
|
145 |
+
)
|
146 |
+
fig = go.Figure(data=[g])
|
147 |
+
fig.update_layout(scene_camera=dict(up=dict(x=0, y=1, z=0)))
|
148 |
+
fig.update_scenes(aspectmode="data")
|
149 |
+
col1, col2 = st.columns(2)
|
150 |
+
with col1:
|
151 |
+
st.plotly_chart(fig, use_container_width=True)
|
152 |
+
# st.caption("Point Cloud Preview")
|
153 |
+
return col2
|
openshape/demo/retrieval.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
|
6 |
+
|
7 |
+
meta = json.load(
|
8 |
+
open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
|
9 |
+
)
|
10 |
+
# {
|
11 |
+
# "u": "94db219c315742909fee67deeeacae15",
|
12 |
+
# "name": "knife", "like": 0, "view": 35,
|
13 |
+
# "tags": ["game-ready", "damascus", "damascus_steel", "kabar-knife", "knife", "blender", "blender3d", "gameready"],
|
14 |
+
# "cats": ["weapons-military"],
|
15 |
+
# "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
|
16 |
+
# "desc": "", "faces": 1724, "size": 11955, "lic": "by",
|
17 |
+
# "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
|
18 |
+
# }
|
19 |
+
meta = {x['u']: x for x in meta['entries']}
|
20 |
+
deser = torch.load(
|
21 |
+
hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse.pt", token=True, repo_type='dataset'), map_location='cpu'
|
22 |
+
)
|
23 |
+
us = deser['us']
|
24 |
+
feats = deser['feats']
|
25 |
+
|
26 |
+
|
27 |
+
def retrieve(embedding, top):
|
28 |
+
sims = []
|
29 |
+
embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
|
30 |
+
for chunk in torch.split(feats, 10240):
|
31 |
+
sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
|
32 |
+
sims = torch.cat(sims)
|
33 |
+
sims, idx = torch.topk(sims, top * 2)
|
34 |
+
results = []
|
35 |
+
for i, sim in zip(idx, sims):
|
36 |
+
if us[i] in meta:
|
37 |
+
results.append(dict(meta[us[i]], sim=sim))
|
38 |
+
if len(results) >= top:
|
39 |
+
break
|
40 |
+
return results
|
openshape/demo/sd_pc2img.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch_redstone as rst
|
3 |
+
import transformers
|
4 |
+
from diffusers import StableUnCLIPImg2ImgPipeline
|
5 |
+
|
6 |
+
|
7 |
+
class Wrapper(transformers.modeling_utils.PreTrainedModel):
|
8 |
+
def __init__(self) -> None:
|
9 |
+
super().__init__(transformers.configuration_utils.PretrainedConfig())
|
10 |
+
self.param = torch.nn.Parameter(torch.tensor(0.))
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
return rst.ObjectProxy(image_embeds=x)
|
14 |
+
|
15 |
+
|
16 |
+
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
17 |
+
"diffusers/stable-diffusion-2-1-unclip-i2i-l",
|
18 |
+
image_encoder = Wrapper()
|
19 |
+
)
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
|
22 |
+
pipe.enable_model_cpu_offload(torch.cuda.current_device())
|
23 |
+
|
24 |
+
|
25 |
+
@torch.no_grad()
|
26 |
+
def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
|
27 |
+
ref_dev = next(pc_encoder.parameters()).device
|
28 |
+
enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
|
29 |
+
return pipe(
|
30 |
+
prompt="best quality, super high resolution, " + prompt,
|
31 |
+
negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
|
32 |
+
image=torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2,
|
33 |
+
width=width, height=height,
|
34 |
+
guidance_scale=cfg_scale,
|
35 |
+
noise_level=noise_level,
|
36 |
+
callback=callback,
|
37 |
+
num_inference_steps=num_steps
|
38 |
+
).images[0]
|
openshape/pointnet_util.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from time import time
|
5 |
+
import numpy as np
|
6 |
+
import dgl.geometry
|
7 |
+
|
8 |
+
def timeit(tag, t):
|
9 |
+
print("{}: {}s".format(tag, time() - t))
|
10 |
+
return time()
|
11 |
+
|
12 |
+
def pc_normalize(pc):
|
13 |
+
l = pc.shape[0]
|
14 |
+
centroid = np.mean(pc, axis=0)
|
15 |
+
pc = pc - centroid
|
16 |
+
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
|
17 |
+
pc = pc / m
|
18 |
+
return pc
|
19 |
+
|
20 |
+
def square_distance(src, dst):
|
21 |
+
"""
|
22 |
+
Calculate Euclid distance between each two points.
|
23 |
+
|
24 |
+
src^T * dst = xn * xm + yn * ym + zn * zm;
|
25 |
+
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
26 |
+
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
27 |
+
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
28 |
+
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
|
29 |
+
|
30 |
+
Input:
|
31 |
+
src: source points, [B, N, C]
|
32 |
+
dst: target points, [B, M, C]
|
33 |
+
Output:
|
34 |
+
dist: per-point square distance, [B, N, M]
|
35 |
+
"""
|
36 |
+
B, N, _ = src.shape
|
37 |
+
_, M, _ = dst.shape
|
38 |
+
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
39 |
+
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
40 |
+
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
41 |
+
return dist
|
42 |
+
|
43 |
+
|
44 |
+
def index_points(points, idx):
|
45 |
+
"""
|
46 |
+
|
47 |
+
Input:
|
48 |
+
points: input points data, [B, N, C]
|
49 |
+
idx: sample index data, [B, S]
|
50 |
+
Return:
|
51 |
+
new_points:, indexed points data, [B, S, C]
|
52 |
+
"""
|
53 |
+
device = points.device
|
54 |
+
B = points.shape[0]
|
55 |
+
view_shape = list(idx.shape)
|
56 |
+
view_shape[1:] = [1] * (len(view_shape) - 1)
|
57 |
+
repeat_shape = list(idx.shape)
|
58 |
+
repeat_shape[0] = 1
|
59 |
+
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
60 |
+
new_points = points[batch_indices, idx, :]
|
61 |
+
return new_points
|
62 |
+
|
63 |
+
|
64 |
+
def farthest_point_sample(xyz, npoint):
|
65 |
+
"""
|
66 |
+
Input:
|
67 |
+
xyz: pointcloud data, [B, N, 3]
|
68 |
+
npoint: number of samples
|
69 |
+
Return:
|
70 |
+
centroids: sampled pointcloud index, [B, npoint]
|
71 |
+
"""
|
72 |
+
return dgl.geometry.farthest_point_sampler(xyz, npoint)
|
73 |
+
device = xyz.device
|
74 |
+
B, N, C = xyz.shape
|
75 |
+
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
|
76 |
+
distance = torch.ones(B, N).to(device) * 1e10
|
77 |
+
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
|
78 |
+
batch_indices = torch.arange(B, dtype=torch.long).to(device)
|
79 |
+
for i in range(npoint):
|
80 |
+
centroids[:, i] = farthest
|
81 |
+
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
|
82 |
+
dist = torch.sum((xyz - centroid) ** 2, -1)
|
83 |
+
mask = dist < distance
|
84 |
+
distance[mask] = dist[mask]
|
85 |
+
farthest = torch.max(distance, -1)[1]
|
86 |
+
return centroids
|
87 |
+
|
88 |
+
|
89 |
+
def query_ball_point(radius, nsample, xyz, new_xyz):
|
90 |
+
"""
|
91 |
+
Input:
|
92 |
+
radius: local region radius
|
93 |
+
nsample: max sample number in local region
|
94 |
+
xyz: all points, [B, N, 3]
|
95 |
+
new_xyz: query points, [B, S, 3]
|
96 |
+
Return:
|
97 |
+
group_idx: grouped points index, [B, S, nsample]
|
98 |
+
"""
|
99 |
+
device = xyz.device
|
100 |
+
B, N, C = xyz.shape
|
101 |
+
_, S, _ = new_xyz.shape
|
102 |
+
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
103 |
+
sqrdists = square_distance(new_xyz, xyz)
|
104 |
+
group_idx[sqrdists > radius ** 2] = N
|
105 |
+
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
106 |
+
group_first = group_idx[..., :1].repeat([1, 1, nsample])
|
107 |
+
mask = group_idx == N
|
108 |
+
group_idx[mask] = group_first[mask]
|
109 |
+
return group_idx
|
110 |
+
|
111 |
+
|
112 |
+
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
|
113 |
+
"""
|
114 |
+
Input:
|
115 |
+
npoint:
|
116 |
+
radius:
|
117 |
+
nsample:
|
118 |
+
xyz: input points position data, [B, N, 3]
|
119 |
+
points: input points data, [B, N, D]
|
120 |
+
Return:
|
121 |
+
new_xyz: sampled points position data, [B, npoint, nsample, 3]
|
122 |
+
new_points: sampled points data, [B, npoint, nsample, 3+D]
|
123 |
+
"""
|
124 |
+
B, N, C = xyz.shape
|
125 |
+
S = npoint
|
126 |
+
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
|
127 |
+
# torch.cuda.empty_cache()
|
128 |
+
new_xyz = index_points(xyz, fps_idx)
|
129 |
+
# torch.cuda.empty_cache()
|
130 |
+
idx = query_ball_point(radius, nsample, xyz, new_xyz)
|
131 |
+
# torch.cuda.empty_cache()
|
132 |
+
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
|
133 |
+
# torch.cuda.empty_cache()
|
134 |
+
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
|
135 |
+
# torch.cuda.empty_cache()
|
136 |
+
|
137 |
+
if points is not None:
|
138 |
+
grouped_points = index_points(points, idx)
|
139 |
+
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
|
140 |
+
else:
|
141 |
+
new_points = grouped_xyz_norm
|
142 |
+
if returnfps:
|
143 |
+
return new_xyz, new_points, grouped_xyz, fps_idx
|
144 |
+
else:
|
145 |
+
return new_xyz, new_points
|
146 |
+
|
147 |
+
|
148 |
+
def sample_and_group_all(xyz, points):
|
149 |
+
"""
|
150 |
+
Input:
|
151 |
+
xyz: input points position data, [B, N, 3]
|
152 |
+
points: input points data, [B, N, D]
|
153 |
+
Return:
|
154 |
+
new_xyz: sampled points position data, [B, 1, 3]
|
155 |
+
new_points: sampled points data, [B, 1, N, 3+D]
|
156 |
+
"""
|
157 |
+
device = xyz.device
|
158 |
+
B, N, C = xyz.shape
|
159 |
+
new_xyz = torch.zeros(B, 1, C).to(device)
|
160 |
+
grouped_xyz = xyz.view(B, 1, N, C)
|
161 |
+
if points is not None:
|
162 |
+
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
|
163 |
+
else:
|
164 |
+
new_points = grouped_xyz
|
165 |
+
return new_xyz, new_points
|
166 |
+
|
167 |
+
|
168 |
+
class PointNetSetAbstraction(nn.Module):
|
169 |
+
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
|
170 |
+
super(PointNetSetAbstraction, self).__init__()
|
171 |
+
self.npoint = npoint
|
172 |
+
self.radius = radius
|
173 |
+
self.nsample = nsample
|
174 |
+
self.mlp_convs = nn.ModuleList()
|
175 |
+
self.mlp_bns = nn.ModuleList()
|
176 |
+
last_channel = in_channel
|
177 |
+
for out_channel in mlp:
|
178 |
+
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
|
179 |
+
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
|
180 |
+
last_channel = out_channel
|
181 |
+
self.group_all = group_all
|
182 |
+
|
183 |
+
def forward(self, xyz, points):
|
184 |
+
"""
|
185 |
+
Input:
|
186 |
+
xyz: input points position data, [B, C, N]
|
187 |
+
points: input points data, [B, D, N]
|
188 |
+
Return:
|
189 |
+
new_xyz: sampled points position data, [B, C, S]
|
190 |
+
new_points_concat: sample points feature data, [B, D', S]
|
191 |
+
"""
|
192 |
+
xyz = xyz.permute(0, 2, 1)
|
193 |
+
if points is not None:
|
194 |
+
points = points.permute(0, 2, 1)
|
195 |
+
|
196 |
+
if self.group_all:
|
197 |
+
new_xyz, new_points = sample_and_group_all(xyz, points)
|
198 |
+
else:
|
199 |
+
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
|
200 |
+
# new_xyz: sampled points position data, [B, npoint, C]
|
201 |
+
# new_points: sampled points data, [B, npoint, nsample, C+D]
|
202 |
+
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
|
203 |
+
for i, conv in enumerate(self.mlp_convs):
|
204 |
+
bn = self.mlp_bns[i]
|
205 |
+
new_points = F.relu(bn(conv(new_points)))
|
206 |
+
|
207 |
+
new_points = torch.max(new_points, 2)[0]
|
208 |
+
new_xyz = new_xyz.permute(0, 2, 1)
|
209 |
+
return new_xyz, new_points
|
210 |
+
|
211 |
+
|
212 |
+
class PointNetSetAbstractionMsg(nn.Module):
|
213 |
+
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
|
214 |
+
super(PointNetSetAbstractionMsg, self).__init__()
|
215 |
+
self.npoint = npoint
|
216 |
+
self.radius_list = radius_list
|
217 |
+
self.nsample_list = nsample_list
|
218 |
+
self.conv_blocks = nn.ModuleList()
|
219 |
+
self.bn_blocks = nn.ModuleList()
|
220 |
+
for i in range(len(mlp_list)):
|
221 |
+
convs = nn.ModuleList()
|
222 |
+
bns = nn.ModuleList()
|
223 |
+
last_channel = in_channel + 3
|
224 |
+
for out_channel in mlp_list[i]:
|
225 |
+
convs.append(nn.Conv2d(last_channel, out_channel, 1))
|
226 |
+
bns.append(nn.BatchNorm2d(out_channel))
|
227 |
+
last_channel = out_channel
|
228 |
+
self.conv_blocks.append(convs)
|
229 |
+
self.bn_blocks.append(bns)
|
230 |
+
|
231 |
+
def forward(self, xyz, points):
|
232 |
+
"""
|
233 |
+
Input:
|
234 |
+
xyz: input points position data, [B, C, N]
|
235 |
+
points: input points data, [B, D, N]
|
236 |
+
Return:
|
237 |
+
new_xyz: sampled points position data, [B, C, S]
|
238 |
+
new_points_concat: sample points feature data, [B, D', S]
|
239 |
+
"""
|
240 |
+
xyz = xyz.permute(0, 2, 1)
|
241 |
+
if points is not None:
|
242 |
+
points = points.permute(0, 2, 1)
|
243 |
+
|
244 |
+
B, N, C = xyz.shape
|
245 |
+
S = self.npoint
|
246 |
+
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
|
247 |
+
new_points_list = []
|
248 |
+
for i, radius in enumerate(self.radius_list):
|
249 |
+
K = self.nsample_list[i]
|
250 |
+
group_idx = query_ball_point(radius, K, xyz, new_xyz)
|
251 |
+
grouped_xyz = index_points(xyz, group_idx)
|
252 |
+
grouped_xyz -= new_xyz.view(B, S, 1, C)
|
253 |
+
if points is not None:
|
254 |
+
grouped_points = index_points(points, group_idx)
|
255 |
+
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
|
256 |
+
else:
|
257 |
+
grouped_points = grouped_xyz
|
258 |
+
|
259 |
+
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
|
260 |
+
for j in range(len(self.conv_blocks[i])):
|
261 |
+
conv = self.conv_blocks[i][j]
|
262 |
+
bn = self.bn_blocks[i][j]
|
263 |
+
grouped_points = F.relu(bn(conv(grouped_points)))
|
264 |
+
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
|
265 |
+
new_points_list.append(new_points)
|
266 |
+
|
267 |
+
new_xyz = new_xyz.permute(0, 2, 1)
|
268 |
+
new_points_concat = torch.cat(new_points_list, dim=1)
|
269 |
+
return new_xyz, new_points_concat
|
270 |
+
|
271 |
+
|
272 |
+
class PointNetFeaturePropagation(nn.Module):
|
273 |
+
def __init__(self, in_channel, mlp):
|
274 |
+
super(PointNetFeaturePropagation, self).__init__()
|
275 |
+
self.mlp_convs = nn.ModuleList()
|
276 |
+
self.mlp_bns = nn.ModuleList()
|
277 |
+
last_channel = in_channel
|
278 |
+
for out_channel in mlp:
|
279 |
+
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
|
280 |
+
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
|
281 |
+
last_channel = out_channel
|
282 |
+
|
283 |
+
def forward(self, xyz1, xyz2, points1, points2):
|
284 |
+
"""
|
285 |
+
Input:
|
286 |
+
xyz1: input points position data, [B, C, N]
|
287 |
+
xyz2: sampled input points position data, [B, C, S]
|
288 |
+
points1: input points data, [B, D, N]
|
289 |
+
points2: input points data, [B, D, S]
|
290 |
+
Return:
|
291 |
+
new_points: upsampled points data, [B, D', N]
|
292 |
+
"""
|
293 |
+
xyz1 = xyz1.permute(0, 2, 1)
|
294 |
+
xyz2 = xyz2.permute(0, 2, 1)
|
295 |
+
|
296 |
+
points2 = points2.permute(0, 2, 1)
|
297 |
+
B, N, C = xyz1.shape
|
298 |
+
_, S, _ = xyz2.shape
|
299 |
+
|
300 |
+
if S == 1:
|
301 |
+
interpolated_points = points2.repeat(1, N, 1)
|
302 |
+
else:
|
303 |
+
dists = square_distance(xyz1, xyz2)
|
304 |
+
dists, idx = dists.sort(dim=-1)
|
305 |
+
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
|
306 |
+
|
307 |
+
dist_recip = 1.0 / (dists + 1e-8)
|
308 |
+
norm = torch.sum(dist_recip, dim=2, keepdim=True)
|
309 |
+
weight = dist_recip / norm
|
310 |
+
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
|
311 |
+
|
312 |
+
if points1 is not None:
|
313 |
+
points1 = points1.permute(0, 2, 1)
|
314 |
+
new_points = torch.cat([points1, interpolated_points], dim=-1)
|
315 |
+
else:
|
316 |
+
new_points = interpolated_points
|
317 |
+
|
318 |
+
new_points = new_points.permute(0, 2, 1)
|
319 |
+
for i, conv in enumerate(self.mlp_convs):
|
320 |
+
bn = self.mlp_bns[i]
|
321 |
+
new_points = F.relu(bn(conv(new_points)))
|
322 |
+
return new_points
|
323 |
+
|
openshape/ppat_rgb.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch_redstone as rst
|
4 |
+
from einops import rearrange
|
5 |
+
from .pointnet_util import PointNetSetAbstraction
|
6 |
+
|
7 |
+
|
8 |
+
class PreNorm(nn.Module):
|
9 |
+
def __init__(self, dim, fn):
|
10 |
+
super().__init__()
|
11 |
+
self.norm = nn.LayerNorm(dim)
|
12 |
+
self.fn = fn
|
13 |
+
def forward(self, x, *extra_args, **kwargs):
|
14 |
+
return self.fn(self.norm(x), *extra_args, **kwargs)
|
15 |
+
|
16 |
+
class FeedForward(nn.Module):
|
17 |
+
def __init__(self, dim, hidden_dim, dropout = 0.):
|
18 |
+
super().__init__()
|
19 |
+
self.net = nn.Sequential(
|
20 |
+
nn.Linear(dim, hidden_dim),
|
21 |
+
nn.GELU(),
|
22 |
+
nn.Dropout(dropout),
|
23 |
+
nn.Linear(hidden_dim, dim),
|
24 |
+
nn.Dropout(dropout)
|
25 |
+
)
|
26 |
+
def forward(self, x):
|
27 |
+
return self.net(x)
|
28 |
+
|
29 |
+
class Attention(nn.Module):
|
30 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rel_pe = False):
|
31 |
+
super().__init__()
|
32 |
+
inner_dim = dim_head * heads
|
33 |
+
project_out = not (heads == 1 and dim_head == dim)
|
34 |
+
|
35 |
+
self.heads = heads
|
36 |
+
self.scale = dim_head ** -0.5
|
37 |
+
|
38 |
+
self.attend = nn.Softmax(dim = -1)
|
39 |
+
self.dropout = nn.Dropout(dropout)
|
40 |
+
|
41 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
42 |
+
|
43 |
+
self.to_out = nn.Sequential(
|
44 |
+
nn.Linear(inner_dim, dim),
|
45 |
+
nn.Dropout(dropout)
|
46 |
+
) if project_out else nn.Identity()
|
47 |
+
|
48 |
+
self.rel_pe = rel_pe
|
49 |
+
if rel_pe:
|
50 |
+
self.pe = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU(), nn.Conv2d(64, 1, 1))
|
51 |
+
|
52 |
+
def forward(self, x, centroid_delta):
|
53 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
54 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
55 |
+
|
56 |
+
pe = self.pe(centroid_delta) if self.rel_pe else 0
|
57 |
+
dots = (torch.matmul(q, k.transpose(-1, -2)) + pe) * self.scale
|
58 |
+
|
59 |
+
attn = self.attend(dots)
|
60 |
+
attn = self.dropout(attn)
|
61 |
+
|
62 |
+
out = torch.matmul(attn, v)
|
63 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
64 |
+
return self.to_out(out)
|
65 |
+
|
66 |
+
|
67 |
+
class Transformer(nn.Module):
|
68 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rel_pe = False):
|
69 |
+
super().__init__()
|
70 |
+
self.layers = nn.ModuleList([])
|
71 |
+
for _ in range(depth):
|
72 |
+
self.layers.append(nn.ModuleList([
|
73 |
+
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rel_pe = rel_pe)),
|
74 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
75 |
+
]))
|
76 |
+
def forward(self, x, centroid_delta):
|
77 |
+
for attn, ff in self.layers:
|
78 |
+
x = attn(x, centroid_delta) + x
|
79 |
+
x = ff(x) + x
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
class PointPatchTransformer(nn.Module):
|
84 |
+
def __init__(self, dim, depth, heads, mlp_dim, sa_dim, patches, prad, nsamp, in_dim=3, dim_head=64, rel_pe=False, patch_dropout=0) -> None:
|
85 |
+
super().__init__()
|
86 |
+
self.patches = patches
|
87 |
+
self.patch_dropout = patch_dropout
|
88 |
+
self.sa = PointNetSetAbstraction(npoint=patches, radius=prad, nsample=nsamp, in_channel=in_dim + 3, mlp=[64, 64, sa_dim], group_all=False)
|
89 |
+
self.lift = nn.Sequential(nn.Conv1d(sa_dim + 3, dim, 1), rst.Lambda(lambda x: torch.permute(x, [0, 2, 1])), nn.LayerNorm([dim]))
|
90 |
+
self.cls_token = nn.Parameter(torch.randn(dim))
|
91 |
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, 0.0, rel_pe)
|
92 |
+
|
93 |
+
def forward(self, features):
|
94 |
+
self.sa.npoint = self.patches
|
95 |
+
if self.training:
|
96 |
+
self.sa.npoint -= self.patch_dropout
|
97 |
+
# print("input", features.shape)
|
98 |
+
centroids, feature = self.sa(features[:, :3], features)
|
99 |
+
# print("f", feature.shape, 'c', centroids.shape)
|
100 |
+
x = self.lift(torch.cat([centroids, feature], dim=1))
|
101 |
+
|
102 |
+
x = rst.supercat([self.cls_token, x], dim=-2)
|
103 |
+
centroids = rst.supercat([centroids.new_zeros(1), centroids], dim=-1)
|
104 |
+
|
105 |
+
centroid_delta = centroids.unsqueeze(-1) - centroids.unsqueeze(-2)
|
106 |
+
x = self.transformer(x, centroid_delta)
|
107 |
+
|
108 |
+
return x[:, 0]
|
109 |
+
|
110 |
+
|
111 |
+
class Projected(nn.Module):
|
112 |
+
def __init__(self, ppat, proj) -> None:
|
113 |
+
super().__init__()
|
114 |
+
self.ppat = ppat
|
115 |
+
self.proj = proj
|
116 |
+
|
117 |
+
def forward(self, features: torch.Tensor):
|
118 |
+
return self.proj(self.ppat(features))
|
setup.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import setuptools
|
2 |
+
|
3 |
+
|
4 |
+
def packages():
|
5 |
+
return setuptools.find_packages()
|
6 |
+
|
7 |
+
|
8 |
+
setuptools.setup(
|
9 |
+
name="openshape",
|
10 |
+
version="0.1",
|
11 |
+
author="flandre.info",
|
12 |
+
author_email="[email protected]",
|
13 |
+
description="Support library for OpenShape Demos.",
|
14 |
+
packages=packages(),
|
15 |
+
classifiers=[
|
16 |
+
"Programming Language :: Python :: 3 :: Only",
|
17 |
+
"License :: OSI Approved :: Apache Software License",
|
18 |
+
"Operating System :: OS Independent",
|
19 |
+
],
|
20 |
+
python_requires='~=3.7',
|
21 |
+
)
|