yuanze1024 commited on
Commit
f15a1cd
·
1 Parent(s): 6dc0a5f
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .cache
2
+ __pycache__/
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: LD T3D
3
- emoji: 🚀
4
  colorFrom: indigo
5
  colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.22.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: LD T3D
3
+ emoji: 🐳
4
  colorFrom: indigo
5
  colorTo: yellow
6
+ sdk: docker
7
+ app_port: 7860
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ import functools
6
+ from datasets import load_dataset
7
+ from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder
8
+
9
+ # os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
10
+ # os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'
11
+
12
+ MAX_BATCH_SIZE = 16
13
+ MAX_QUEUE_SIZE = 10
14
+ MAX_K_RETRIEVAL = 20
15
+ cache_dir = "./.cache"
16
+
17
+ encoder = Uni3dEmbeddingEncoder(cache_dir)
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ source_id_list = torch.load("data/source_id_list.pt")
20
+ source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}
21
+ dataset = load_dataset("VAST-AI/LD-T3D", name=f"rendered_imgs_diag_above", split="base", cache_dir=cache_dir)
22
+
23
+ @functools.lru_cache()
24
+ def get_embedding(option, modality, angle=None):
25
+ save_path = f'data/objaverse_{option}_{modality + (("_" + str(angle)) if angle is not None else "")}_embeddings.pt'
26
+ if os.path.exists(save_path):
27
+ return torch.load(save_path)
28
+ else:
29
+ return gr.Error(f"Embedding file not found: {save_path}")
30
+
31
+ def predict(xb, xq, top_k):
32
+ xb = xb.to(xq.device)
33
+ sim = xq @ xb.T # (nq, nb)
34
+ _, indices = sim.topk(k=top_k, largest=True)
35
+ return indices
36
+
37
+ def get_image(index):
38
+ return dataset[index]["image"]
39
+
40
+ def retrieve_3D_models(textual_query, top_k, modality_list):
41
+ if textual_query == "":
42
+ raise gr.Error("Please enter a textual query")
43
+ if len(textual_query.split()) > 20:
44
+ gr.Warning("Retrieval result may be inaccurate due to long textual query")
45
+ if len(modality_list) == 0:
46
+ raise gr.Error("Please select at least one modality")
47
+
48
+ def _retrieve_3D_models(query, top_k, modals:list):
49
+ option = "uni3d"
50
+ op = "add"
51
+ is_text = True if "text" in modals else False
52
+ is_3D = True if "3D" in modals else False
53
+ if is_text:
54
+ modals.remove("text")
55
+ if is_3D:
56
+ modals.remove("3D")
57
+ angles = modals
58
+
59
+ # get base embeddings
60
+ embeddings = []
61
+ if is_text:
62
+ embeddings.append(get_embedding(option, "text"))
63
+ if len(angles) > 0:
64
+ for angle in angles:
65
+ embeddings.append(get_embedding(option, "image", angle=angle))
66
+ if is_3D:
67
+ embeddings.append(get_embedding(option, "3D"))
68
+
69
+ ## fuse base embeddings
70
+ if len(embeddings) > 1:
71
+ if op == "concat":
72
+ embeddings = torch.cat(embeddings, dim=-1)
73
+ elif op == "add":
74
+ embeddings = sum(embeddings)
75
+ else:
76
+ raise ValueError(f"Unsupported operation: {op}")
77
+ embeddings /= embeddings.norm(dim=-1, keepdim=True)
78
+ else:
79
+ embeddings = embeddings[0]
80
+
81
+ # encode query embeddings
82
+ xq = encoder.encode_query(query)
83
+ if op == "concat":
84
+ xq = xq.repeat(1, embeddings.shape[-1] // xq.shape[-1]) # repeat to be aligned with the xb
85
+ xq /= xq.norm(dim=-1, keepdim=True)
86
+
87
+ pred_ind_list = predict(embeddings, xq, top_k)
88
+ return pred_ind_list[0].cpu().tolist() # we have only one query
89
+
90
+ indices = _retrieve_3D_models(textual_query, top_k, modality_list)
91
+ return [get_image(index) for index in indices]
92
+
93
+ def launch():
94
+ with gr.Blocks() as demo:
95
+ with gr.Row():
96
+ textual_query = gr.Textbox(label="Textual Query", autofocus=True,
97
+ placeholder="A chair with a wooden frame and a cushioned seat")
98
+ modality_list = gr.CheckboxGroup(label="Modality List", value=[],
99
+ choices=["text", "front", "back", "left", "right", "above",
100
+ "below", "diag_above", "diag_below", "3D"])
101
+ with gr.Row():
102
+ top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result",
103
+ value=5, scale=2)
104
+ run = gr.Button("Search", scale=1)
105
+ clear_button = gr.ClearButton(scale=1)
106
+ with gr.Row():
107
+ output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")
108
+ run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output,
109
+ # batch=True, max_batch_size=MAX_BATCH_SIZE
110
+ )
111
+ clear_button.click(lambda: ["", 5, [], []], outputs=[textual_query, top_k, modality_list, output])
112
+ examples = gr.Examples(examples=[["An ice cream with a cherry on top", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
113
+ ["A mid-age castle", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
114
+ ["A coke", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]]],
115
+ inputs=[textual_query, top_k, modality_list],
116
+ # cache_examples=True,
117
+ outputs=output,
118
+ fn=retrieve_3D_models)
119
+
120
+ demo.queue(max_size=10)
121
+
122
+ # os.environ.pop('HTTP_PROXY')
123
+ # os.environ.pop('HTTPS_PROXY')
124
+
125
+ demo.launch(server_name='0.0.0.0')
126
+
127
+ if __name__ == "__main__":
128
+ launch()
129
+ # print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))
change_setup.txt ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import os.path as osp
4
+
5
+ from setuptools import find_packages, setup
6
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7
+
8
+ this_dir = osp.dirname(osp.abspath(__file__))
9
+ _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
10
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
11
+ osp.join(_ext_src_root, "src", "*.cu")
12
+ )
13
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
14
+
15
+ requirements = ["torch>=1.4"]
16
+
17
+ exec(open(osp.join("pointnet2_ops", "_version.py")).read())
18
+
19
+ setup(
20
+ name="pointnet2_ops",
21
+ version=__version__,
22
+ author="Erik Wijmans",
23
+ packages=find_packages(),
24
+ install_requires=requirements,
25
+ ext_modules=[
26
+ CUDAExtension(
27
+ name="pointnet2_ops._ext",
28
+ sources=_ext_sources,
29
+ extra_compile_args={
30
+ "cxx": ["-O3"],
31
+ "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
32
+ },
33
+ include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
34
+ )
35
+ ],
36
+ cmdclass={"build_ext": BuildExtension},
37
+ include_package_data=True,
38
+ )
data/objaverse_uni3d_3D_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b05400ab75009785535bd78d859db0a902176fbeb5df2ef73e55a95990ded1b8
3
+ size 365511995
data/objaverse_uni3d_image_above_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0708d9bfb4df4e6f86a21bd5a1096401c8c037e84575e6d0397efdb1b138289
3
+ size 365512104
data/objaverse_uni3d_image_back_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5667981bc1215e1f60c034ff8e2d214da6186a2f3212061b8ed3e1c32073ad6e
3
+ size 365512104
data/objaverse_uni3d_image_below_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f91df0329424657666dd9a5b3181d52f9155ad545dc22a2f725f24f9b854abbd
3
+ size 365512104
data/objaverse_uni3d_image_diag_above_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b44e2ee38885128e9080c75ee1d311fee8f718375e867c2209273649455c89a7
3
+ size 365512035
data/objaverse_uni3d_image_diag_below_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79eb0da600d75874e22bbfcca6001669eb14f06ec37326bf5148521db82f3e34
3
+ size 365512035
data/objaverse_uni3d_image_front_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:016208fa7a76e959840c128c30e178a0b43a570cf7a8e6cfd6fcdb442f6b72db
3
+ size 365512104
data/objaverse_uni3d_image_left_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5db0c17a56ebbb0fa1323b105dfe04386f8d7f88c876bc24b943e8713a01076
3
+ size 365512035
data/objaverse_uni3d_image_right_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5fb149475c79b465157d5b2cfe2af4ad8947ff23f99577da264c2632bc9d770
3
+ size 365512035
data/objaverse_uni3d_text_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2d908630bcc8a5a231e8b5d11714c63a3e8b6d78427a82a833da9219b2a7263
3
+ size 365512020
data/source_id_list.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c218ccb58d0045b0b6671c1378ee43362054b890f9895d7cac3de727683a9a76
3
+ size 3747900
dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ FROM nvcr.io/nvidia/pytorch:23.08
3
+
4
+ LABEL maintainer="yuanze"
5
+ LABEL email="[email protected]"
6
+
7
+ # Install webp support
8
+ RUN apt update && apt install libwebp-dev -y
9
+
10
+ RUN pip install -r requirements.txt
11
+
12
+ # note that you may need to modify the TORCH_CUDA_ARCH_LIST in the setup.py file
13
+ ENV TORCH_CUDA_ARCH_LIST="8.6"
14
+
15
+ # Install Pointnet2_PyTorch
16
+ RUN git clone https://github.com/erikwijmans/Pointnet2_PyTorch.git \
17
+ && mv -f backup_install.txt Pointnet2_PyTorch/pointnet2_ops_lib/setup.py \
18
+ && cd Pointnet2_PyTorch/pointnet2_ops_lib \
19
+ && python install .
feature_extractors/__init__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+ from abc import ABC, abstractmethod
3
+ import torch
4
+ from PIL.Image import Image
5
+
6
+ class FeatureExtractor(ABC):
7
+ @abstractmethod
8
+ def encode_image(self, img_list: Sequence[Image]) -> torch.Tensor:
9
+ """
10
+ Encode the input images and return the corresponding embeddings.
11
+
12
+ Args:
13
+ img_list: A list of PIL.Image.Image objects.
14
+
15
+ Returns:
16
+ The embeddings of the input images. The shape should be (len(img_list), embedding_dim).
17
+ """
18
+ raise NotImplementedError
19
+
20
+ @abstractmethod
21
+ def encode_text(self, text_list: Sequence[str]) -> torch.Tensor:
22
+ """
23
+ Encode the input text data and return the corresponding embeddings.
24
+
25
+ Args:
26
+ text_list: A list of strings.
27
+
28
+ Returns:
29
+ The embeddings of the input text data. The shape should be (len(text_list), embedding_dim).
30
+ """
31
+ raise NotImplementedError
32
+
33
+ @abstractmethod
34
+ def encode_3D(self, pc_tensor: torch.Tensor) -> torch.Tensor:
35
+ """
36
+ Encode the input 3D point cloud and return the corresponding embeddings.
37
+
38
+ Args:
39
+ pc_tensor: A tensor of shape (B, N, 3 + 3).
40
+
41
+ Returns:
42
+ The embeddings of the input 3D point cloud. The shape should be (B, embedding_dim).
43
+ """
44
+ raise NotImplementedError
45
+
46
+ @abstractmethod
47
+ def encode_query(self, queries: Sequence[str]) -> torch.Tensor:
48
+ """Encode the queries and return the corresponding embeddings.
49
+
50
+ Args:
51
+ queries: A list of strings.
52
+
53
+ Returns:
54
+ The embeddings of the input text data. The shape should be (len(input_text), embedding_dim).
55
+ """
56
+ raise NotImplementedError
feature_extractors/uni3d_embedding_encoder.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ See https://github.com/baaivision/Uni3D for source code
3
+ """
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import timm
8
+ import numpy as np
9
+ from pointnet2_ops import pointnet2_utils
10
+ import open_clip
11
+ from huggingface_hub import hf_hub_download
12
+ import sys
13
+ sys.path.append('')
14
+ from feature_extractors import FeatureExtractor
15
+ from utils.tokenizer import SimpleTokenizer
16
+
17
+ import logging
18
+
19
+ def fps(data, number):
20
+ '''
21
+ data B N 3
22
+ number int
23
+ '''
24
+ fps_idx = pointnet2_utils.furthest_point_sample(data, number)
25
+ fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
26
+ return fps_data
27
+
28
+ # https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
29
+ def knn_point(nsample, xyz, new_xyz):
30
+ """
31
+ Input:
32
+ nsample: max sample number in local region
33
+ xyz: all points, [B, N, C]
34
+ new_xyz: query points, [B, S, C]
35
+ Return:
36
+ group_idx: grouped points index, [B, S, nsample]
37
+ """
38
+ sqrdists = square_distance(new_xyz, xyz)
39
+ _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False)
40
+ return group_idx
41
+
42
+ def square_distance(src, dst):
43
+ """
44
+ Calculate Euclid distance between each two points.
45
+ src^T * dst = xn * xm + yn * ym + zn * zm;
46
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
47
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
48
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
49
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
50
+ Input:
51
+ src: source points, [B, N, C]
52
+ dst: target points, [B, M, C]
53
+ Output:
54
+ dist: per-point square distance, [B, N, M]
55
+ """
56
+ B, N, _ = src.shape
57
+ _, M, _ = dst.shape
58
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
59
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
60
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
61
+ return dist
62
+
63
+
64
+ class PatchDropout(nn.Module):
65
+ """
66
+ https://arxiv.org/abs/2212.00794
67
+ """
68
+
69
+ def __init__(self, prob, exclude_first_token=True):
70
+ super().__init__()
71
+ assert 0 <= prob < 1.
72
+ self.prob = prob
73
+ self.exclude_first_token = exclude_first_token # exclude CLS token
74
+ logging.info("patch dropout prob is {}".format(prob))
75
+
76
+ def forward(self, x):
77
+ # if not self.training or self.prob == 0.:
78
+ # return x
79
+
80
+ if self.exclude_first_token:
81
+ cls_tokens, x = x[:, :1], x[:, 1:]
82
+ else:
83
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
84
+
85
+ batch = x.size()[0]
86
+ num_tokens = x.size()[1]
87
+
88
+ batch_indices = torch.arange(batch)
89
+ batch_indices = batch_indices[..., None]
90
+
91
+ keep_prob = 1 - self.prob
92
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
93
+
94
+ rand = torch.randn(batch, num_tokens)
95
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
96
+
97
+ x = x[batch_indices, patch_indices_keep]
98
+
99
+ if self.exclude_first_token:
100
+ x = torch.cat((cls_tokens, x), dim=1)
101
+
102
+ return x
103
+
104
+
105
+ class Group(nn.Module):
106
+ def __init__(self, num_group, group_size):
107
+ super().__init__()
108
+ self.num_group = num_group
109
+ self.group_size = group_size
110
+
111
+ def forward(self, xyz, color):
112
+ '''
113
+ input: B N 3
114
+ ---------------------------
115
+ output: B G M 3
116
+ center : B G 3
117
+ '''
118
+ batch_size, num_points, _ = xyz.shape
119
+ # fps the centers out
120
+ center = fps(xyz, self.num_group) # B G 3
121
+ # knn to get the neighborhood
122
+ # _, idx = self.knn(xyz, center) # B G M
123
+ idx = knn_point(self.group_size, xyz, center) # B G M
124
+ assert idx.size(1) == self.num_group
125
+ assert idx.size(2) == self.group_size
126
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
127
+ idx = idx + idx_base
128
+ idx = idx.view(-1)
129
+ neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
130
+ neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
131
+
132
+ neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
133
+ neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()
134
+
135
+ # normalize
136
+ neighborhood = neighborhood - center.unsqueeze(2)
137
+
138
+ features = torch.cat((neighborhood, neighborhood_color), dim=-1)
139
+ return neighborhood, center, features
140
+
141
+ class Encoder(nn.Module):
142
+ def __init__(self, encoder_channel):
143
+ super().__init__()
144
+ self.encoder_channel = encoder_channel
145
+ self.first_conv = nn.Sequential(
146
+ nn.Conv1d(6, 128, 1),
147
+ nn.BatchNorm1d(128),
148
+ nn.ReLU(inplace=True),
149
+ nn.Conv1d(128, 256, 1)
150
+ )
151
+ self.second_conv = nn.Sequential(
152
+ nn.Conv1d(512, 512, 1),
153
+ nn.BatchNorm1d(512),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(512, self.encoder_channel, 1)
156
+ )
157
+ def forward(self, point_groups):
158
+ '''
159
+ point_groups : B G N 3
160
+ -----------------
161
+ feature_global : B G C
162
+ '''
163
+ bs, g, n , _ = point_groups.shape
164
+ point_groups = point_groups.reshape(bs * g, n, 6)
165
+ # encoder
166
+ feature = self.first_conv(point_groups.transpose(2,1)) # BG 256 n
167
+ feature_global = torch.max(feature,dim=2,keepdim=True)[0] # BG 256 1
168
+ feature = torch.cat([feature_global.expand(-1,-1,n), feature], dim=1)# BG 512 n
169
+ feature = self.second_conv(feature) # BG 1024 n
170
+ feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
171
+ return feature_global.reshape(bs, g, self.encoder_channel)
172
+
173
+ class PointcloudEncoder(nn.Module):
174
+ def __init__(self, point_transformer):
175
+ # use the giant branch of uni3d
176
+ super().__init__()
177
+ from easydict import EasyDict
178
+ self.trans_dim = 1408
179
+ self.embed_dim = 1024
180
+ self.group_size = 64
181
+ self.num_group = 512
182
+ # grouper
183
+ self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
184
+ # define the encoder
185
+ self.encoder_dim = 512
186
+ self.encoder = Encoder(encoder_channel = self.encoder_dim)
187
+
188
+ # bridge encoder and transformer
189
+ self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
190
+
191
+ # bridge transformer and clip embedding
192
+ self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
193
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
194
+ self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
195
+
196
+ self.pos_embed = nn.Sequential(
197
+ nn.Linear(3, 128),
198
+ nn.GELU(),
199
+ nn.Linear(128, self.trans_dim)
200
+ )
201
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
202
+ self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
203
+ self.visual = point_transformer
204
+
205
+
206
+ def forward(self, pts, colors):
207
+ # divide the point cloud in the same form. This is important
208
+ _, center, features = self.group_divider(pts, colors)
209
+
210
+ # encoder the input cloud patches
211
+ group_input_tokens = self.encoder(features) # B G N
212
+ group_input_tokens = self.encoder2trans(group_input_tokens)
213
+ # prepare cls
214
+ cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
215
+ cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
216
+ # add pos embedding
217
+ pos = self.pos_embed(center)
218
+ # final input
219
+ x = torch.cat((cls_tokens, group_input_tokens), dim=1)
220
+ pos = torch.cat((cls_pos, pos), dim=1)
221
+ # transformer
222
+ x = x + pos
223
+ # x = x.half()
224
+
225
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
226
+ x = self.patch_dropout(x)
227
+
228
+ x = self.visual.pos_drop(x)
229
+
230
+ # ModuleList not support forward
231
+ for i, blk in enumerate(self.visual.blocks):
232
+ x = blk(x)
233
+ x = self.visual.norm(x[:, 0, :])
234
+ x = self.visual.fc_norm(x)
235
+
236
+ x = self.trans2embed(x)
237
+ return x
238
+
239
+ class Uni3D(nn.Module):
240
+ def __init__(self, point_encoder):
241
+ super().__init__()
242
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
243
+ self.point_encoder = point_encoder
244
+
245
+ def encode_pc(self, pc):
246
+ xyz = pc[:,:,:3].contiguous()
247
+ color = pc[:,:,3:].contiguous()
248
+ pc_feat = self.point_encoder(xyz, color)
249
+ return pc_feat
250
+
251
+ def forward(self, pc, text, image):
252
+ text_embed_all = text
253
+ image_embed = image
254
+ pc_embed = self.encode_pc(pc)
255
+ return {'text_embed': text_embed_all,
256
+ 'pc_embed': pc_embed,
257
+ 'image_embed': image_embed,
258
+ 'logit_scale': self.logit_scale.exp()}
259
+
260
+ def get_metric_names(model):
261
+ return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']
262
+
263
+ def create_uni3d(uni3d_path):
264
+ # create transformer blocks for point cloud via timm
265
+ point_transformer = timm.create_model("eva_giant_patch14_560")
266
+
267
+ # create whole point cloud encoder
268
+ point_encoder = PointcloudEncoder(point_transformer)
269
+
270
+ # uni3d model
271
+ model = Uni3D(point_encoder=point_encoder,)
272
+
273
+ checkpoint = torch.load(uni3d_path, map_location='cpu')
274
+ logging.info('loaded checkpoint {}'.format(uni3d_path))
275
+ sd = checkpoint['module']
276
+ if next(iter(sd.items()))[0].startswith('module'):
277
+ sd = {k[len('module.'):]: v for k, v in sd.items()}
278
+ model.load_state_dict(sd)
279
+ return model
280
+
281
+ class Uni3dEmbeddingEncoder(FeatureExtractor):
282
+ def __init__(self, cache_dir, **kwargs) -> None:
283
+ bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
284
+ uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt") # concat the subfolder as hf_hub_download will put it here
285
+ clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")
286
+
287
+ if not os.path.exists(uni3d_path):
288
+ hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g", cache_dir=cache_dir,
289
+ local_dir=cache_dir + os.sep + "Uni3D")
290
+ if not os.path.exists(clip_path):
291
+ hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin",
292
+ cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")
293
+
294
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
295
+ self.tokenizer = SimpleTokenizer(bpe_path)
296
+ self.model = create_uni3d(uni3d_path)
297
+ self.model.eval()
298
+ self.model.to(self.device)
299
+ self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path)
300
+ self.clip_model.to(self.device)
301
+
302
+ def pc_norm(self, pc):
303
+ """ pc: NxC, return NxC """
304
+ centroid = np.mean(pc, axis=0)
305
+ pc = pc - centroid
306
+ m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
307
+ pc = pc / m
308
+ return pc
309
+
310
+ @torch.no_grad()
311
+ def encode_3D(self, data):
312
+ pc = data.to(device=self.device, non_blocking=True)
313
+ pc_features = self.model.encode_pc(pc)
314
+ pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
315
+ return pc_features.float()
316
+
317
+ @torch.no_grad()
318
+ def encode_text(self, input_text):
319
+ texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
320
+ if len(texts.shape) < 2:
321
+ texts = texts[None, ...]
322
+ class_embeddings = self.clip_model.encode_text(texts)
323
+ class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
324
+ return class_embeddings.float()
325
+
326
+ @torch.no_grad()
327
+ def encode_image(self, img_tensor_list):
328
+ image = img_tensor_list.to(device=self.device, non_blocking=True)
329
+ image_features = self.clip_model.encode_image(image)
330
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
331
+ return image_features.float()
332
+
333
+ def encode_query(self, query_list):
334
+ return self.encode_text(query_list)
335
+
336
+ def get_img_transform(self):
337
+ return self.preprocess
packages ADDED
@@ -0,0 +1 @@
 
 
1
+ libwebp-dev
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ datasets
3
+ timm
4
+ pillow
5
+ open-clip-torch
6
+ huggingface_hub
7
+ ftfy
8
+ regex
9
+ easydict
utils/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
utils/tokenizer.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copied from github.com/baaivision/Uni3D
2
+ # # Modified from github.com/openai/CLIP
3
+ import gzip
4
+ import html
5
+ import os
6
+ from functools import lru_cache
7
+
8
+ import ftfy
9
+ import regex as re
10
+ import torch
11
+
12
+
13
+ @lru_cache()
14
+ def bytes_to_unicode():
15
+ """
16
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
17
+ The reversible bpe codes work on unicode strings.
18
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
19
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
20
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
21
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
22
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
23
+ """
24
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
25
+ cs = bs[:]
26
+ n = 0
27
+ for b in range(2**8):
28
+ if b not in bs:
29
+ bs.append(b)
30
+ cs.append(2**8+n)
31
+ n += 1
32
+ cs = [chr(n) for n in cs]
33
+ return dict(zip(bs, cs))
34
+
35
+
36
+ def get_pairs(word):
37
+ """Return set of symbol pairs in a word.
38
+ Word is represented as tuple of symbols (symbols being variable-length strings).
39
+ """
40
+ pairs = set()
41
+ prev_char = word[0]
42
+ for char in word[1:]:
43
+ pairs.add((prev_char, char))
44
+ prev_char = char
45
+ return pairs
46
+
47
+
48
+ def basic_clean(text):
49
+ text = ftfy.fix_text(text)
50
+ text = html.unescape(html.unescape(text))
51
+ return text.strip()
52
+
53
+
54
+ def whitespace_clean(text):
55
+ text = re.sub(r'\s+', ' ', text)
56
+ text = text.strip()
57
+ return text
58
+
59
+
60
+ class SimpleTokenizer(object):
61
+ def __init__(self, bpe_path):
62
+ self.byte_encoder = bytes_to_unicode()
63
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
64
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
65
+ merges = merges[1:49152-256-2+1]
66
+ merges = [tuple(merge.split()) for merge in merges]
67
+ vocab = list(bytes_to_unicode().values())
68
+ vocab = vocab + [v+'</w>' for v in vocab]
69
+ for merge in merges:
70
+ vocab.append(''.join(merge))
71
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
72
+ self.encoder = dict(zip(vocab, range(len(vocab))))
73
+ self.decoder = {v: k for k, v in self.encoder.items()}
74
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
75
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
76
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
77
+
78
+ def bpe(self, token):
79
+ if token in self.cache:
80
+ return self.cache[token]
81
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
82
+ pairs = get_pairs(word)
83
+
84
+ if not pairs:
85
+ return token+'</w>'
86
+
87
+ while True:
88
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
89
+ if bigram not in self.bpe_ranks:
90
+ break
91
+ first, second = bigram
92
+ new_word = []
93
+ i = 0
94
+ while i < len(word):
95
+ try:
96
+ j = word.index(first, i)
97
+ new_word.extend(word[i:j])
98
+ i = j
99
+ except:
100
+ new_word.extend(word[i:])
101
+ break
102
+
103
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
104
+ new_word.append(first+second)
105
+ i += 2
106
+ else:
107
+ new_word.append(word[i])
108
+ i += 1
109
+ new_word = tuple(new_word)
110
+ word = new_word
111
+ if len(word) == 1:
112
+ break
113
+ else:
114
+ pairs = get_pairs(word)
115
+ word = ' '.join(word)
116
+ self.cache[token] = word
117
+ return word
118
+
119
+ def encode(self, text):
120
+ bpe_tokens = []
121
+ text = whitespace_clean(basic_clean(text)).lower()
122
+ for token in re.findall(self.pat, text):
123
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
124
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
125
+ return bpe_tokens
126
+
127
+ def decode(self, tokens):
128
+ text = ''.join([self.decoder[token] for token in tokens])
129
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
130
+ return text
131
+
132
+ def __call__(self, texts, context_length=77):
133
+ if isinstance(texts, str):
134
+ texts = [texts]
135
+
136
+ sot_token = self.encoder["<|startoftext|>"]
137
+ eot_token = self.encoder["<|endoftext|>"]
138
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
139
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
140
+
141
+ for i, tokens in enumerate(all_tokens):
142
+ tokens = tokens[:context_length]
143
+ result[i, :len(tokens)] = torch.tensor(tokens)
144
+
145
+ if len(result) == 1:
146
+ return result[0]
147
+ return result