Spaces:
Running
on
Zero
Running
on
Zero
gokaygokay
commited on
Upload 43 files
Browse files- .gitattributes +39 -35
- README.md +12 -12
- app.py +131 -154
- demo_files/comp.gif +3 -0
- demo_files/examples/animal_character.png +3 -0
- demo_files/examples/animal_character_2.png +3 -0
- demo_files/examples/axe.png +0 -0
- demo_files/examples/chair1.png +0 -0
- demo_files/examples/character1.png +0 -0
- demo_files/examples/otter_samurai.png +0 -0
- demo_files/examples/raccoon_wizard.png +0 -0
- demo_files/examples/stylized-rocks.png +0 -0
- demo_files/examples/tree.png +0 -0
- demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
- demo_files/hdri/metro_noord_1k.hdr +0 -0
- demo_files/hdri/neon_photostudio_1k.hdr +0 -0
- demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
- demo_files/hdri/rainforest_trail_1k.hdr +0 -0
- demo_files/hdri/studio_small_08_1k.hdr +0 -0
- demo_files/hdri/urban_alley_01_1k.hdr +0 -0
- demo_files/scatterplot.jpg +0 -0
- demo_files/teaser.gif +3 -0
- flux_lora.py +109 -0
- load/tets/160_tets.npz +3 -0
- requirements.txt +19 -11
- sf3d/box_uv_unwrap.py +610 -0
- sf3d/models/camera.py +32 -0
- sf3d/models/global_estimator/multi_head_estimator.py +118 -0
- sf3d/models/image_estimator/clip_based_estimator.py +168 -0
- sf3d/models/isosurface.py +229 -0
- sf3d/models/mesh.py +172 -0
- sf3d/models/network.py +195 -0
- sf3d/models/tokenizers/dinov2.py +1196 -0
- sf3d/models/tokenizers/image.py +99 -0
- sf3d/models/tokenizers/triplane.py +49 -0
- sf3d/models/transformers/attention.py +31 -0
- sf3d/models/transformers/backbone.py +515 -0
- sf3d/models/utils.py +292 -0
- sf3d/system.py +482 -0
- sf3d/texture_baker.py +87 -0
- sf3d/texture_baker.slang +93 -0
- sf3d/utils.py +91 -0
- stable_fast.py +355 -0
.gitattributes
CHANGED
@@ -1,35 +1,39 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
demo_files/comp.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
demo_files/examples/animal_character_2.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
demo_files/examples/animal_character.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
demo_files/teaser.gif filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
---
|
2 |
-
title: FLUX.1-dev + Captioner
|
3 |
-
emoji: 🐨
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.37.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
12 |
-
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: FLUX.1-dev + Captioner
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.37.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,154 +1,131 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
import
|
4 |
-
|
5 |
-
|
6 |
-
from
|
7 |
-
import
|
8 |
-
|
9 |
-
import
|
10 |
-
import
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
.
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
133 |
-
height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
134 |
-
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
|
135 |
-
num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
|
136 |
-
|
137 |
-
generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
|
138 |
-
|
139 |
-
with gr.Column(scale=1):
|
140 |
-
with gr.Group(elem_classes="output-group"):
|
141 |
-
output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
|
142 |
-
final_prompt = gr.Textbox(label="Final Prompt Used")
|
143 |
-
used_seed = gr.Number(label="Seed Used")
|
144 |
-
|
145 |
-
generate_btn.click(
|
146 |
-
fn=process_workflow,
|
147 |
-
inputs=[
|
148 |
-
input_image, text_prompt, use_enhancer, seed, randomize_seed,
|
149 |
-
width, height, guidance_scale, num_inference_steps
|
150 |
-
],
|
151 |
-
outputs=[output_image, final_prompt, used_seed]
|
152 |
-
)
|
153 |
-
|
154 |
-
demo.launch(debug=True)
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
import gradio as gr
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from diffusers import FluxPipeline
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from sf3d.system import SF3D
|
10 |
+
import sf3d.utils as sf3d_utils
|
11 |
+
from gradio_litmodel3d import LitModel3D
|
12 |
+
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
dtype = torch.bfloat16
|
15 |
+
|
16 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
17 |
+
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
18 |
+
# Set up environment and cache
|
19 |
+
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
|
20 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
21 |
+
os.environ["HF_HUB_CACHE"] = cache_path
|
22 |
+
os.environ["HF_HOME"] = cache_path
|
23 |
+
|
24 |
+
if not os.path.exists(cache_path):
|
25 |
+
os.makedirs(cache_path, exist_ok=True)
|
26 |
+
|
27 |
+
# Initialize Flux pipeline
|
28 |
+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=huggingface_token)
|
29 |
+
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
|
30 |
+
pipe.fuse_lora(lora_scale=0.125)
|
31 |
+
pipe.to(device="cuda", dtype=torch.bfloat16)
|
32 |
+
|
33 |
+
# Initialize SF3D model
|
34 |
+
sf3d_model = SF3D.from_pretrained(
|
35 |
+
"stabilityai/stable-fast-3d",
|
36 |
+
config_name="config.yaml",
|
37 |
+
weight_name="model.safetensors",
|
38 |
+
token=huggingface_token
|
39 |
+
|
40 |
+
)
|
41 |
+
sf3d_model.eval().cuda()
|
42 |
+
|
43 |
+
# Constants for SF3D
|
44 |
+
COND_WIDTH, COND_HEIGHT = 512, 512
|
45 |
+
COND_DISTANCE, COND_FOVY_DEG = 1.6, 40
|
46 |
+
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
47 |
+
|
48 |
+
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
|
49 |
+
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
50 |
+
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
51 |
+
)
|
52 |
+
|
53 |
+
def generate_image(prompt, height, width, steps, scales, seed):
|
54 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
55 |
+
return pipe(
|
56 |
+
prompt=[prompt],
|
57 |
+
generator=torch.Generator().manual_seed(int(seed)),
|
58 |
+
num_inference_steps=int(steps),
|
59 |
+
guidance_scale=float(scales),
|
60 |
+
height=int(height),
|
61 |
+
width=int(width),
|
62 |
+
max_sequence_length=256
|
63 |
+
).images[0]
|
64 |
+
|
65 |
+
def create_batch(input_image: Image.Image) -> dict:
|
66 |
+
img_cond = torch.from_numpy(
|
67 |
+
np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
|
68 |
+
).float().clip(0, 1)
|
69 |
+
mask_cond = img_cond[:, :, -1:]
|
70 |
+
rgb_cond = torch.lerp(
|
71 |
+
torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
|
72 |
+
)
|
73 |
+
|
74 |
+
batch_elem = {
|
75 |
+
"rgb_cond": rgb_cond,
|
76 |
+
"mask_cond": mask_cond,
|
77 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
78 |
+
"intrinsic_cond": intrinsic.unsqueeze(0),
|
79 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
80 |
+
}
|
81 |
+
return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
82 |
+
|
83 |
+
def generate_3d_model(input_image):
|
84 |
+
with torch.no_grad():
|
85 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
86 |
+
model_batch = create_batch(input_image)
|
87 |
+
model_batch = {k: v.cuda() for k, v in model_batch.items()}
|
88 |
+
trimesh_mesh, _ = sf3d_model.generate_mesh(model_batch, 1024)
|
89 |
+
trimesh_mesh = trimesh_mesh[0]
|
90 |
+
|
91 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
92 |
+
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
93 |
+
return tmp_file.name
|
94 |
+
|
95 |
+
def process_and_generate(prompt, height, width, steps, scales, seed):
|
96 |
+
# Generate image from prompt
|
97 |
+
generated_image = generate_image(prompt, height, width, steps, scales, seed)
|
98 |
+
|
99 |
+
# Generate 3D model from the image
|
100 |
+
glb_file = generate_3d_model(generated_image)
|
101 |
+
|
102 |
+
return generated_image, glb_file
|
103 |
+
|
104 |
+
# Gradio interface
|
105 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
106 |
+
gr.Markdown("# Text-to-3D Model Generator")
|
107 |
+
|
108 |
+
with gr.Row():
|
109 |
+
with gr.Column(scale=3):
|
110 |
+
prompt = gr.Textbox(label="Your Image Description", lines=3)
|
111 |
+
with gr.Accordion("Advanced Settings", open=False):
|
112 |
+
height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
|
113 |
+
width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
|
114 |
+
steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
|
115 |
+
scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
|
116 |
+
seed = gr.Number(label="Seed", value=3413, precision=0)
|
117 |
+
|
118 |
+
generate_btn = gr.Button("Generate 3D Model", variant="primary")
|
119 |
+
|
120 |
+
with gr.Column(scale=4):
|
121 |
+
output_image = gr.Image(label="Generated Image")
|
122 |
+
output_3d = LitModel3D(label="3D Model", clear_color=[0.0, 0.0, 0.0, 0.0])
|
123 |
+
|
124 |
+
generate_btn.click(
|
125 |
+
process_and_generate,
|
126 |
+
inputs=[prompt, height, width, steps, scales, seed],
|
127 |
+
outputs=[output_image, output_3d]
|
128 |
+
)
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo_files/comp.gif
ADDED
Git LFS Details
|
demo_files/examples/animal_character.png
ADDED
Git LFS Details
|
demo_files/examples/animal_character_2.png
ADDED
Git LFS Details
|
demo_files/examples/axe.png
ADDED
demo_files/examples/chair1.png
ADDED
demo_files/examples/character1.png
ADDED
demo_files/examples/otter_samurai.png
ADDED
demo_files/examples/raccoon_wizard.png
ADDED
demo_files/examples/stylized-rocks.png
ADDED
demo_files/examples/tree.png
ADDED
demo_files/hdri/abandoned_tiled_room_1k.hdr
ADDED
Binary file (478 kB). View file
|
|
demo_files/hdri/metro_noord_1k.hdr
ADDED
Binary file (467 kB). View file
|
|
demo_files/hdri/neon_photostudio_1k.hdr
ADDED
Binary file (438 kB). View file
|
|
demo_files/hdri/peppermint_powerplant_1k.hdr
ADDED
Binary file (473 kB). View file
|
|
demo_files/hdri/rainforest_trail_1k.hdr
ADDED
Binary file (512 kB). View file
|
|
demo_files/hdri/studio_small_08_1k.hdr
ADDED
Binary file (412 kB). View file
|
|
demo_files/hdri/urban_alley_01_1k.hdr
ADDED
Binary file (458 kB). View file
|
|
demo_files/scatterplot.jpg
ADDED
demo_files/teaser.gif
ADDED
Git LFS Details
|
flux_lora.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from os import path
|
6 |
+
from safetensors.torch import load_file
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
|
9 |
+
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
10 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
11 |
+
os.environ["HF_HUB_CACHE"] = cache_path
|
12 |
+
os.environ["HF_HOME"] = cache_path
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
import torch
|
16 |
+
from diffusers import FluxPipeline
|
17 |
+
|
18 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
19 |
+
|
20 |
+
class timer:
|
21 |
+
def __init__(self, method_name="timed process"):
|
22 |
+
self.method = method_name
|
23 |
+
def __enter__(self):
|
24 |
+
self.start = time.time()
|
25 |
+
print(f"{self.method} starts")
|
26 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
27 |
+
end = time.time()
|
28 |
+
print(f"{self.method} took {str(round(end - self.start, 2))}s")
|
29 |
+
|
30 |
+
if not path.exists(cache_path):
|
31 |
+
os.makedirs(cache_path, exist_ok=True)
|
32 |
+
|
33 |
+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
34 |
+
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
|
35 |
+
pipe.fuse_lora(lora_scale=0.125)
|
36 |
+
pipe.to(device="cuda", dtype=torch.bfloat16)
|
37 |
+
|
38 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
39 |
+
gr.Markdown(
|
40 |
+
"""
|
41 |
+
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
42 |
+
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Hyper-FLUX-8steps-LoRA</h1>
|
43 |
+
<p style="font-size: 1rem; margin-bottom: 1.5rem;">AutoML team from ByteDance</p>
|
44 |
+
</div>
|
45 |
+
"""
|
46 |
+
)
|
47 |
+
|
48 |
+
with gr.Row():
|
49 |
+
with gr.Column(scale=3):
|
50 |
+
with gr.Group():
|
51 |
+
prompt = gr.Textbox(
|
52 |
+
label="Your Image Description",
|
53 |
+
placeholder="E.g., A serene landscape with mountains and a lake at sunset",
|
54 |
+
lines=3
|
55 |
+
)
|
56 |
+
|
57 |
+
with gr.Accordion("Advanced Settings", open=False):
|
58 |
+
with gr.Group():
|
59 |
+
with gr.Row():
|
60 |
+
height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
|
61 |
+
width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
|
62 |
+
|
63 |
+
with gr.Row():
|
64 |
+
steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
|
65 |
+
scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
|
66 |
+
|
67 |
+
seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)
|
68 |
+
|
69 |
+
generate_btn = gr.Button("Generate Image", variant="primary", scale=1)
|
70 |
+
|
71 |
+
with gr.Column(scale=4):
|
72 |
+
output = gr.Image(label="Your Generated Image")
|
73 |
+
|
74 |
+
gr.Markdown(
|
75 |
+
"""
|
76 |
+
<div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
|
77 |
+
<h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
|
78 |
+
<ol style="padding-left: 1.5rem;">
|
79 |
+
<li>Enter a detailed description of the image you want to create.</li>
|
80 |
+
<li>Adjust advanced settings if desired (tap to expand).</li>
|
81 |
+
<li>Tap "Generate Image" and wait for your creation!</li>
|
82 |
+
</ol>
|
83 |
+
<p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
|
84 |
+
</div>
|
85 |
+
"""
|
86 |
+
)
|
87 |
+
|
88 |
+
@spaces.GPU
|
89 |
+
def process_image(height, width, steps, scales, prompt, seed):
|
90 |
+
global pipe
|
91 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
|
92 |
+
return pipe(
|
93 |
+
prompt=[prompt],
|
94 |
+
generator=torch.Generator().manual_seed(int(seed)),
|
95 |
+
num_inference_steps=int(steps),
|
96 |
+
guidance_scale=float(scales),
|
97 |
+
height=int(height),
|
98 |
+
width=int(width),
|
99 |
+
max_sequence_length=256
|
100 |
+
).images[0]
|
101 |
+
|
102 |
+
generate_btn.click(
|
103 |
+
process_image,
|
104 |
+
inputs=[height, width, steps, scales, prompt, seed],
|
105 |
+
outputs=output
|
106 |
+
)
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
demo.launch()
|
load/tets/160_tets.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
|
3 |
+
size 15408790
|
requirements.txt
CHANGED
@@ -1,11 +1,19 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.1.2
|
2 |
+
torchvision>=0.16.2
|
3 |
+
einops>=0.7.0
|
4 |
+
jaxtyping>=0.2.31
|
5 |
+
omegaconf>=2.3.0
|
6 |
+
transformers>=4.43.3
|
7 |
+
slangtorch>=1.2.2
|
8 |
+
open_clip_torch>=2.24.0
|
9 |
+
trimesh>=4.4.1
|
10 |
+
numpy>=1.26.4
|
11 |
+
huggingface-hub>=0.23.4
|
12 |
+
rembg[gpu]>=2.0.57
|
13 |
+
gradio-litmodel3d>=0.0.1
|
14 |
+
accelerate
|
15 |
+
diffusers>=0.30.0
|
16 |
+
invisible_watermark
|
17 |
+
xformers
|
18 |
+
sentencepiece
|
19 |
+
peft
|
sf3d/box_uv_unwrap.py
ADDED
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from sf3d.models.utils import dot, triangle_intersection_2d
|
10 |
+
|
11 |
+
|
12 |
+
def _box_assign_vertex_to_cube_face(
|
13 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
14 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
15 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
16 |
+
bbox: Float[Tensor, "2 3"],
|
17 |
+
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
|
18 |
+
# Test to not have a scaled model to fit the space better
|
19 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
20 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
21 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
22 |
+
|
23 |
+
# Create a [0, 1] normalized vertex position
|
24 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
25 |
+
# And to [-1, 1]
|
26 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
27 |
+
|
28 |
+
# Get all vertex positions for each triangle
|
29 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
30 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
31 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
32 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
33 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
34 |
+
|
35 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
36 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
37 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
38 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
39 |
+
|
40 |
+
# Just average the normals per face
|
41 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
42 |
+
|
43 |
+
# Now decide based on the face normal in which box map we project
|
44 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
45 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
46 |
+
|
47 |
+
axis = torch.tensor(
|
48 |
+
[
|
49 |
+
[1, 0, 0], # 0
|
50 |
+
[-1, 0, 0], # 1
|
51 |
+
[0, 1, 0], # 2
|
52 |
+
[0, -1, 0], # 3
|
53 |
+
[0, 0, 1], # 4
|
54 |
+
[0, 0, -1], # 5
|
55 |
+
],
|
56 |
+
device=face_normal.device,
|
57 |
+
dtype=face_normal.dtype,
|
58 |
+
)
|
59 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
60 |
+
index = face_normal_axis.argmax(-1)
|
61 |
+
|
62 |
+
max_axis, uc, vc = (
|
63 |
+
torch.ones_like(abs_x),
|
64 |
+
torch.zeros_like(tri_stack[..., :1]),
|
65 |
+
torch.zeros_like(tri_stack[..., :1]),
|
66 |
+
)
|
67 |
+
mask_pos_x = index == 0
|
68 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
69 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
70 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
71 |
+
|
72 |
+
mask_neg_x = index == 1
|
73 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
74 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
75 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
76 |
+
|
77 |
+
mask_pos_y = index == 2
|
78 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
79 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
80 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
81 |
+
|
82 |
+
mask_neg_y = index == 3
|
83 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
84 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
85 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
86 |
+
|
87 |
+
mask_pos_z = index == 4
|
88 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
89 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
90 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
91 |
+
|
92 |
+
mask_neg_z = index == 5
|
93 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
94 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
95 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
96 |
+
|
97 |
+
# UC from [-1, 1] to [0, 1]
|
98 |
+
max_dim_div = max_axis.max(dim=0, keepdims=True).values
|
99 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
100 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
101 |
+
|
102 |
+
uv = torch.stack([uc, vc], dim=-1)
|
103 |
+
|
104 |
+
return uv, index
|
105 |
+
|
106 |
+
|
107 |
+
def _assign_faces_uv_to_atlas_index(
|
108 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
109 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
110 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
111 |
+
face_index: Integer[Tensor, "Nf 3"],
|
112 |
+
) -> Integer[Tensor, "Nf"]: # noqa: F821
|
113 |
+
triangle_pos = vertex_positions[triangle_idxs]
|
114 |
+
# We need to do perform 3 overlap checks.
|
115 |
+
# The first set is placed in the upper two thirds of the UV atlas.
|
116 |
+
# Conceptually, this is the direct visible surfaces from the each cube side
|
117 |
+
# The second set is placed in the lower thirds and the left half of the UV atlas.
|
118 |
+
# This is the first set of occluded surfaces. They will also be saved in the projected fashion
|
119 |
+
# The third pass finds all non assigned faces. They will be placed in the bottom right half of
|
120 |
+
# the UV atlas in scattered fashion.
|
121 |
+
assign_idx = face_index.clone()
|
122 |
+
for overlap_step in range(3):
|
123 |
+
overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
|
124 |
+
for i in range(overlap_step * 6, (overlap_step + 1) * 6):
|
125 |
+
mask = assign_idx == i
|
126 |
+
if not mask.any():
|
127 |
+
continue
|
128 |
+
# Get all elements belonging to the projection face
|
129 |
+
uv_triangle = face_uv[mask]
|
130 |
+
cur_triangle_pos = triangle_pos[mask]
|
131 |
+
# Find the center of the uv coordinates
|
132 |
+
center_uv = uv_triangle.mean(dim=1, keepdim=True)
|
133 |
+
# And also the radius of the triangle
|
134 |
+
uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
|
135 |
+
|
136 |
+
potentially_overlapping_mask = (
|
137 |
+
# Find all close triangles
|
138 |
+
(center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
|
139 |
+
# Do not select the same element by offseting with an large valued identity matrix
|
140 |
+
+ torch.eye(
|
141 |
+
uv_triangle.shape[0],
|
142 |
+
device=uv_triangle.device,
|
143 |
+
dtype=uv_triangle.dtype,
|
144 |
+
).unsqueeze(-1)
|
145 |
+
* 1000
|
146 |
+
)
|
147 |
+
# Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
|
148 |
+
potentially_overlapping_mask = (
|
149 |
+
potentially_overlapping_mask
|
150 |
+
<= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
|
151 |
+
).squeeze(-1)
|
152 |
+
overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
|
153 |
+
|
154 |
+
# Only unique triangles (A|B and B|A should be the same)
|
155 |
+
f = torch.min(overlap_coords, dim=-1).values
|
156 |
+
s = torch.max(overlap_coords, dim=-1).values
|
157 |
+
overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
|
158 |
+
first, second = overlap_coords.unbind(-1)
|
159 |
+
|
160 |
+
# Get the triangles
|
161 |
+
tri_1 = uv_triangle[first]
|
162 |
+
tri_2 = uv_triangle[second]
|
163 |
+
|
164 |
+
# Perform the actual set with the reduced number of potentially overlapping triangles
|
165 |
+
its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
|
166 |
+
|
167 |
+
# So we now need to detect which triangles are the occluded ones.
|
168 |
+
# We always assume the first to be the visible one (the others should move)
|
169 |
+
# In the previous step we use a lexigraphical sort to get the unique pairs
|
170 |
+
# In this we use a sort based on the orthographic projection
|
171 |
+
ax = 0 if i < 2 else 1 if i < 4 else 2
|
172 |
+
use_max = i % 2 == 1
|
173 |
+
|
174 |
+
tri1_c = cur_triangle_pos[first].mean(dim=1)
|
175 |
+
tri2_c = cur_triangle_pos[second].mean(dim=1)
|
176 |
+
|
177 |
+
mark_first = (
|
178 |
+
(tri1_c[..., ax] > tri2_c[..., ax])
|
179 |
+
if use_max
|
180 |
+
else (tri1_c[..., ax] < tri2_c[..., ax])
|
181 |
+
)
|
182 |
+
first[mark_first] = second[mark_first]
|
183 |
+
|
184 |
+
# Lastly the same index can be tested multiple times.
|
185 |
+
# If one marks it as overlapping we keep it marked as such.
|
186 |
+
# We do this by testing if it has been marked at least once.
|
187 |
+
unique_idx, rev_idx = torch.unique(first, return_inverse=True)
|
188 |
+
|
189 |
+
add = torch.zeros_like(unique_idx, dtype=torch.float32)
|
190 |
+
add.index_add_(0, rev_idx, its.float())
|
191 |
+
its_mask = add > 0
|
192 |
+
|
193 |
+
# And fill it in the overlapping indicator
|
194 |
+
idx = torch.where(mask)[0][unique_idx]
|
195 |
+
overlapping_indicator[idx] = its_mask
|
196 |
+
|
197 |
+
# Move the index to the overlap regions (shift by 6)
|
198 |
+
assign_idx[overlapping_indicator] += 6
|
199 |
+
|
200 |
+
# We do not care about the correct face placement after the first 2 slices
|
201 |
+
max_idx = 6 * 2
|
202 |
+
return assign_idx.clamp(0, max_idx)
|
203 |
+
|
204 |
+
|
205 |
+
def _find_slice_offset_and_scale(
|
206 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
207 |
+
) -> Tuple[
|
208 |
+
Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
|
209 |
+
]: # noqa: F821
|
210 |
+
# 6 due to the 6 cube faces
|
211 |
+
off = 1 / 3
|
212 |
+
dupl_off = 1 / 6
|
213 |
+
|
214 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
215 |
+
def x_offset_calc(x, i):
|
216 |
+
offset_calc = i // 6
|
217 |
+
# Initial coordinates - just 3x2 grid
|
218 |
+
if offset_calc == 0:
|
219 |
+
return off * x
|
220 |
+
else:
|
221 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
222 |
+
# second overlap
|
223 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
224 |
+
|
225 |
+
def y_offset_calc(x, i):
|
226 |
+
offset_calc = i // 6
|
227 |
+
# Initial coordinates - just a 3x2 grid
|
228 |
+
if offset_calc == 0:
|
229 |
+
return off * x
|
230 |
+
else:
|
231 |
+
# Smaller coordinates in the lowest row
|
232 |
+
return dupl_off * x + off * 2
|
233 |
+
|
234 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
235 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
236 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
237 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
238 |
+
for i in range(index.max().item() + 1):
|
239 |
+
mask = index == i
|
240 |
+
if not mask.any():
|
241 |
+
continue
|
242 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
243 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
244 |
+
|
245 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
246 |
+
# All overlap elements are saved in half scale
|
247 |
+
div_x[index >= 6] = 6
|
248 |
+
div_y = div_x.clone() # Same for y
|
249 |
+
# Except for the random overlaps
|
250 |
+
div_x[index >= 12] = 2
|
251 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
252 |
+
div_y[index >= 12] = 3
|
253 |
+
|
254 |
+
return offset_x, offset_y, div_x, div_y
|
255 |
+
|
256 |
+
|
257 |
+
def rotation_flip_matrix_2d(
|
258 |
+
rad: float, flip_x: bool = False, flip_y: bool = False
|
259 |
+
) -> Float[Tensor, "2 2"]:
|
260 |
+
cos = math.cos(rad)
|
261 |
+
sin = math.sin(rad)
|
262 |
+
rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
|
263 |
+
flip_mat = torch.tensor(
|
264 |
+
[
|
265 |
+
[-1 if flip_x else 1, 0],
|
266 |
+
[0, -1 if flip_y else 1],
|
267 |
+
],
|
268 |
+
dtype=torch.float32,
|
269 |
+
)
|
270 |
+
|
271 |
+
return flip_mat @ rot_mat
|
272 |
+
|
273 |
+
|
274 |
+
def calculate_tangents(
|
275 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
276 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
277 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
278 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
279 |
+
) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
|
280 |
+
vn_idx = [None] * 3
|
281 |
+
pos = [None] * 3
|
282 |
+
tex = face_uv.unbind(1)
|
283 |
+
for i in range(0, 3):
|
284 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
285 |
+
# t_nrm_idx is always the same as t_pos_idx
|
286 |
+
vn_idx[i] = triangle_idxs[:, i]
|
287 |
+
|
288 |
+
tangents = torch.zeros_like(vertex_normals)
|
289 |
+
tansum = torch.zeros_like(vertex_normals)
|
290 |
+
|
291 |
+
# Compute tangent space for each triangle
|
292 |
+
duv1 = tex[1] - tex[0]
|
293 |
+
duv2 = tex[2] - tex[0]
|
294 |
+
dpos1 = pos[1] - pos[0]
|
295 |
+
dpos2 = pos[2] - pos[0]
|
296 |
+
|
297 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
298 |
+
|
299 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
300 |
+
|
301 |
+
# Avoid division by zero for degenerated texture coordinates
|
302 |
+
denom_safe = denom.clip(1e-6)
|
303 |
+
tang = tng_nom / denom_safe
|
304 |
+
|
305 |
+
# Update all 3 vertices
|
306 |
+
for i in range(0, 3):
|
307 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
308 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
309 |
+
tansum.scatter_add_(
|
310 |
+
0, idx, torch.ones_like(tang)
|
311 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
312 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
313 |
+
# triangles influence the tangent space more
|
314 |
+
tangents = tangents / tansum
|
315 |
+
|
316 |
+
# Normalize and make sure tangent is perpendicular to normal
|
317 |
+
tangents = F.normalize(tangents, dim=1)
|
318 |
+
tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
|
319 |
+
|
320 |
+
return tangents
|
321 |
+
|
322 |
+
|
323 |
+
def _rotate_uv_slices_consistent_space(
|
324 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
325 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
326 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
327 |
+
uv: Float[Tensor, "Nf 3 2"],
|
328 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
329 |
+
):
|
330 |
+
tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
|
331 |
+
pos_stack = torch.stack(
|
332 |
+
[
|
333 |
+
-vertex_positions[..., 1],
|
334 |
+
vertex_positions[..., 0],
|
335 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
336 |
+
],
|
337 |
+
dim=-1,
|
338 |
+
)
|
339 |
+
expected_tangents = F.normalize(
|
340 |
+
torch.linalg.cross(
|
341 |
+
vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
|
342 |
+
),
|
343 |
+
-1,
|
344 |
+
)
|
345 |
+
|
346 |
+
actual_tangents = tangents[triangle_idxs]
|
347 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
348 |
+
|
349 |
+
def rotation_matrix_2d(theta):
|
350 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
351 |
+
return torch.tensor([[c, -s], [s, c]])
|
352 |
+
|
353 |
+
# Now find the rotation
|
354 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
355 |
+
for i in range(6):
|
356 |
+
mask = index_mod == i
|
357 |
+
if not mask.any():
|
358 |
+
continue
|
359 |
+
|
360 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
361 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
362 |
+
|
363 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
364 |
+
cross_product = (
|
365 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
366 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
367 |
+
)
|
368 |
+
angle = torch.atan2(cross_product, dot_product)
|
369 |
+
|
370 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
371 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
372 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
373 |
+
# Rotate it
|
374 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
375 |
+
|
376 |
+
# Rescale uv[mask] to be within the 0-1 range
|
377 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
378 |
+
|
379 |
+
return uv
|
380 |
+
|
381 |
+
|
382 |
+
def _handle_slice_uvs(
|
383 |
+
uv: Float[Tensor, "Nf 3 2"],
|
384 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
385 |
+
island_padding: float,
|
386 |
+
max_index: int = 6 * 2,
|
387 |
+
) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
|
388 |
+
uc, vc = uv.unbind(-1)
|
389 |
+
|
390 |
+
# Get the second slice (The first overlap)
|
391 |
+
index_filter = [index == i for i in range(6, max_index)]
|
392 |
+
|
393 |
+
# Normalize them to always fully fill the atlas patch
|
394 |
+
for i, fi in enumerate(index_filter):
|
395 |
+
if fi.sum() > 0:
|
396 |
+
# Scale the slice but only up to a factor of 2
|
397 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
398 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
|
399 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
|
400 |
+
|
401 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
402 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
403 |
+
|
404 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
405 |
+
|
406 |
+
|
407 |
+
def _handle_remaining_uvs(
|
408 |
+
uv: Float[Tensor, "Nf 3 2"],
|
409 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
410 |
+
island_padding: float,
|
411 |
+
) -> Float[Tensor, "Nf 3 2"]:
|
412 |
+
uc, vc = uv.unbind(-1)
|
413 |
+
# Get all remaining elements
|
414 |
+
remaining_filter = index >= 6 * 2
|
415 |
+
squares_left = remaining_filter.sum()
|
416 |
+
|
417 |
+
if squares_left == 0:
|
418 |
+
return uv
|
419 |
+
|
420 |
+
uc = uc[remaining_filter]
|
421 |
+
vc = vc[remaining_filter]
|
422 |
+
|
423 |
+
# Or remaining triangles are distributed in a rectangle
|
424 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
425 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
426 |
+
# sqrt(744/(0.5*(1/3)))
|
427 |
+
|
428 |
+
mult = math.sqrt(squares_left / ratio)
|
429 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
430 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
431 |
+
|
432 |
+
width = 1 / num_square_width
|
433 |
+
height = 1 / num_square_height
|
434 |
+
|
435 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
436 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
437 |
+
# assumes full coverage.
|
438 |
+
clip_val = min(width, height) * 1.5
|
439 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
440 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
441 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
442 |
+
).clip(clip_val)
|
443 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
444 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
445 |
+
).clip(clip_val)
|
446 |
+
# Add a small padding
|
447 |
+
uc = (
|
448 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
449 |
+
+ island_padding * num_square_width * 0.25
|
450 |
+
).clip(0, 1)
|
451 |
+
vc = (
|
452 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
453 |
+
+ island_padding * num_square_height * 0.25
|
454 |
+
).clip(0, 1)
|
455 |
+
|
456 |
+
uc = uc * width
|
457 |
+
vc = vc * height
|
458 |
+
|
459 |
+
# And calculate offsets for each element
|
460 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
461 |
+
x_idx = idx % num_square_width
|
462 |
+
y_idx = idx // num_square_width
|
463 |
+
# And move each triangle to its own spot
|
464 |
+
uc = uc + x_idx[:, None] * width
|
465 |
+
vc = vc + y_idx[:, None] * height
|
466 |
+
|
467 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
468 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
469 |
+
|
470 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
471 |
+
|
472 |
+
return uv
|
473 |
+
|
474 |
+
|
475 |
+
def _distribute_individual_uvs_in_atlas(
|
476 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
477 |
+
assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
|
478 |
+
offset_x: Float[Tensor, "Nf"], # noqa: F821
|
479 |
+
offset_y: Float[Tensor, "Nf"], # noqa: F821
|
480 |
+
div_x: Float[Tensor, "Nf"], # noqa: F821
|
481 |
+
div_y: Float[Tensor, "Nf"], # noqa: F821
|
482 |
+
island_padding: float,
|
483 |
+
):
|
484 |
+
# Place the slice first
|
485 |
+
placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
486 |
+
# Then handle the remaining overlap elements
|
487 |
+
placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
|
488 |
+
|
489 |
+
uc, vc = placed_uv.unbind(-1)
|
490 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
491 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
|
492 |
+
|
493 |
+
uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
|
494 |
+
|
495 |
+
return uv
|
496 |
+
|
497 |
+
|
498 |
+
def _get_unique_face_uv(
|
499 |
+
uv: Float[Tensor, "Nf 3 2"],
|
500 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
501 |
+
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
|
502 |
+
# And add the face to uv index mapping
|
503 |
+
vtex_idx = unique_idx.view(-1, 3)
|
504 |
+
|
505 |
+
return unique_uv, vtex_idx
|
506 |
+
|
507 |
+
|
508 |
+
def _align_mesh_with_main_axis(
|
509 |
+
vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
|
510 |
+
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
|
511 |
+
# Use pca to find the 2 main axis (third is derived by cross product)
|
512 |
+
# Set the random seed so it's repeatable
|
513 |
+
torch.manual_seed(0)
|
514 |
+
_, _, v = torch.pca_lowrank(vertex_positions, q=2)
|
515 |
+
main_axis, seconday_axis = v[:, 0], v[:, 1]
|
516 |
+
|
517 |
+
main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
|
518 |
+
# Orthogonalize the second axis
|
519 |
+
seconday_axis: Float[Tensor, "3"] = F.normalize(
|
520 |
+
seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
|
521 |
+
)
|
522 |
+
# Create perpendicular third axis
|
523 |
+
third_axis: Float[Tensor, "3"] = F.normalize(
|
524 |
+
torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
|
525 |
+
)
|
526 |
+
|
527 |
+
# Check to which canonical axis each aligns
|
528 |
+
main_axis_max_idx = main_axis.abs().argmax().item()
|
529 |
+
seconday_axis_max_idx = seconday_axis.abs().argmax().item()
|
530 |
+
third_axis_max_idx = third_axis.abs().argmax().item()
|
531 |
+
|
532 |
+
# Now sort the axes based on the argmax so they align with thecanonoical axes
|
533 |
+
# If two axes have the same argmax move one of them
|
534 |
+
all_possible_axis = {0, 1, 2}
|
535 |
+
cur_index = 1
|
536 |
+
while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
|
537 |
+
# Find missing axis
|
538 |
+
missing_axis = all_possible_axis - set(
|
539 |
+
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
|
540 |
+
)
|
541 |
+
missing_axis = missing_axis.pop()
|
542 |
+
# Just assign it to third axis as it had the smallest contribution to the
|
543 |
+
# overall shape
|
544 |
+
if cur_index == 1:
|
545 |
+
third_axis_max_idx = missing_axis
|
546 |
+
elif cur_index == 2:
|
547 |
+
seconday_axis_max_idx = missing_axis
|
548 |
+
else:
|
549 |
+
raise ValueError("Could not find 3 unique axis")
|
550 |
+
cur_index += 1
|
551 |
+
|
552 |
+
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
|
553 |
+
raise ValueError("Could not find 3 unique axis")
|
554 |
+
|
555 |
+
axes = [None] * 3
|
556 |
+
axes[main_axis_max_idx] = main_axis
|
557 |
+
axes[seconday_axis_max_idx] = seconday_axis
|
558 |
+
axes[third_axis_max_idx] = third_axis
|
559 |
+
# Create rotation matrix from the individual axes
|
560 |
+
rot_mat = torch.stack(axes, dim=1).T
|
561 |
+
|
562 |
+
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
|
563 |
+
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
|
564 |
+
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
|
565 |
+
|
566 |
+
return vertex_positions, vertex_normals
|
567 |
+
|
568 |
+
|
569 |
+
def box_projection_uv_unwrap(
|
570 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
571 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
572 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
573 |
+
island_padding: float,
|
574 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
575 |
+
# Align the mesh with main axis directions first
|
576 |
+
vertex_positions, vertex_normals = _align_mesh_with_main_axis(
|
577 |
+
vertex_positions, vertex_normals
|
578 |
+
)
|
579 |
+
|
580 |
+
bbox: Float[Tensor, "2 3"] = torch.stack(
|
581 |
+
[vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
|
582 |
+
)
|
583 |
+
# First decide in which cube face the triangle is placed
|
584 |
+
face_uv, face_index = _box_assign_vertex_to_cube_face(
|
585 |
+
vertex_positions, vertex_normals, triangle_idxs, bbox
|
586 |
+
)
|
587 |
+
|
588 |
+
# Rotate the UV islands in a way that they align with the radial z tangent space
|
589 |
+
face_uv = _rotate_uv_slices_consistent_space(
|
590 |
+
vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
|
591 |
+
)
|
592 |
+
|
593 |
+
# Then find where where the face is placed in the atlas.
|
594 |
+
# This has to detect potential overlaps
|
595 |
+
assigned_atlas_index = _assign_faces_uv_to_atlas_index(
|
596 |
+
vertex_positions, triangle_idxs, face_uv, face_index
|
597 |
+
)
|
598 |
+
|
599 |
+
# Then figure out the final place in the atlas based on the assignment
|
600 |
+
offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
|
601 |
+
assigned_atlas_index
|
602 |
+
)
|
603 |
+
|
604 |
+
# Next distribute the faces in the uv atlas
|
605 |
+
placed_uv = _distribute_individual_uvs_in_atlas(
|
606 |
+
face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
|
607 |
+
)
|
608 |
+
|
609 |
+
# And get the unique per-triangle UV coordinates
|
610 |
+
return _get_unique_face_uv(placed_uv)
|
sf3d/models/camera.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from sf3d.models.utils import BaseModule
|
8 |
+
|
9 |
+
|
10 |
+
class LinearCameraEmbedder(BaseModule):
|
11 |
+
@dataclass
|
12 |
+
class Config(BaseModule.Config):
|
13 |
+
in_channels: int = 25
|
14 |
+
out_channels: int = 768
|
15 |
+
conditions: List[str] = field(default_factory=list)
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
|
21 |
+
|
22 |
+
def forward(self, **kwargs):
|
23 |
+
cond_tensors = []
|
24 |
+
for cond_name in self.cfg.conditions:
|
25 |
+
assert cond_name in kwargs
|
26 |
+
cond = kwargs[cond_name]
|
27 |
+
# cond in shape (B, Nv, ...)
|
28 |
+
cond_tensors.append(cond.view(*cond.shape[:2], -1))
|
29 |
+
cond_tensor = torch.cat(cond_tensors, dim=-1)
|
30 |
+
assert cond_tensor.shape[-1] == self.cfg.in_channels
|
31 |
+
embedding = self.linear(cond_tensor)
|
32 |
+
return embedding
|
sf3d/models/global_estimator/multi_head_estimator.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, List, Optional
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from jaxtyping import Float
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from sf3d.models.network import get_activation
|
9 |
+
from sf3d.models.utils import BaseModule
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class HeadSpec:
|
14 |
+
name: str
|
15 |
+
out_channels: int
|
16 |
+
n_hidden_layers: int
|
17 |
+
output_activation: Optional[str] = None
|
18 |
+
output_bias: float = 0.0
|
19 |
+
add_to_decoder_features: bool = False
|
20 |
+
shape: Optional[list[int]] = None
|
21 |
+
|
22 |
+
|
23 |
+
class MultiHeadEstimator(BaseModule):
|
24 |
+
@dataclass
|
25 |
+
class Config(BaseModule.Config):
|
26 |
+
triplane_features: int = 1024
|
27 |
+
|
28 |
+
n_layers: int = 2
|
29 |
+
hidden_features: int = 512
|
30 |
+
activation: str = "relu"
|
31 |
+
|
32 |
+
pool: str = "max"
|
33 |
+
# Literal["mean", "max"] = "mean" # noqa: F821
|
34 |
+
|
35 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
36 |
+
|
37 |
+
cfg: Config
|
38 |
+
|
39 |
+
def configure(self):
|
40 |
+
layers = []
|
41 |
+
cur_features = self.cfg.triplane_features * 3
|
42 |
+
for _ in range(self.cfg.n_layers):
|
43 |
+
layers.append(
|
44 |
+
nn.Conv2d(
|
45 |
+
cur_features,
|
46 |
+
self.cfg.hidden_features,
|
47 |
+
kernel_size=3,
|
48 |
+
padding=0,
|
49 |
+
stride=2,
|
50 |
+
)
|
51 |
+
)
|
52 |
+
layers.append(self.make_activation(self.cfg.activation))
|
53 |
+
|
54 |
+
cur_features = self.cfg.hidden_features
|
55 |
+
|
56 |
+
self.layers = nn.Sequential(*layers)
|
57 |
+
|
58 |
+
assert len(self.cfg.heads) > 0
|
59 |
+
heads = {}
|
60 |
+
for head in self.cfg.heads:
|
61 |
+
head_layers = []
|
62 |
+
for i in range(head.n_hidden_layers):
|
63 |
+
head_layers += [
|
64 |
+
nn.Linear(
|
65 |
+
self.cfg.hidden_features,
|
66 |
+
self.cfg.hidden_features,
|
67 |
+
),
|
68 |
+
self.make_activation(self.cfg.activation),
|
69 |
+
]
|
70 |
+
head_layers += [
|
71 |
+
nn.Linear(
|
72 |
+
self.cfg.hidden_features,
|
73 |
+
head.out_channels,
|
74 |
+
),
|
75 |
+
]
|
76 |
+
heads[head.name] = nn.Sequential(*head_layers)
|
77 |
+
self.heads = nn.ModuleDict(heads)
|
78 |
+
|
79 |
+
def make_activation(self, activation):
|
80 |
+
if activation == "relu":
|
81 |
+
return nn.ReLU(inplace=True)
|
82 |
+
elif activation == "silu":
|
83 |
+
return nn.SiLU(inplace=True)
|
84 |
+
else:
|
85 |
+
raise NotImplementedError
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
triplane: Float[Tensor, "B 3 F Ht Wt"],
|
90 |
+
) -> dict[str, Any]:
|
91 |
+
x = self.layers(
|
92 |
+
triplane.reshape(
|
93 |
+
triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
|
94 |
+
)
|
95 |
+
)
|
96 |
+
|
97 |
+
if self.cfg.pool == "max":
|
98 |
+
x = x.amax(dim=[-2, -1])
|
99 |
+
elif self.cfg.pool == "mean":
|
100 |
+
x = x.mean(dim=[-2, -1])
|
101 |
+
else:
|
102 |
+
raise NotImplementedError
|
103 |
+
|
104 |
+
out = {
|
105 |
+
("decoder_" if head.add_to_decoder_features else "")
|
106 |
+
+ head.name: get_activation(head.output_activation)(
|
107 |
+
self.heads[head.name](x) + head.output_bias
|
108 |
+
)
|
109 |
+
for head in self.cfg.heads
|
110 |
+
}
|
111 |
+
for head in self.cfg.heads:
|
112 |
+
if head.shape:
|
113 |
+
head_name = (
|
114 |
+
"decoder_" if head.add_to_decoder_features else ""
|
115 |
+
) + head.name
|
116 |
+
out[head_name] = out[head_name].reshape(*head.shape)
|
117 |
+
|
118 |
+
return out
|
sf3d/models/image_estimator/clip_based_estimator.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, List, Optional
|
3 |
+
|
4 |
+
import open_clip
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
from torchvision.transforms import Normalize
|
10 |
+
|
11 |
+
from sf3d.models.network import get_activation
|
12 |
+
from sf3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class HeadSpec:
|
17 |
+
name: str
|
18 |
+
out_channels: int
|
19 |
+
n_hidden_layers: int
|
20 |
+
output_activation: Optional[str] = None
|
21 |
+
output_bias: float = 0.0
|
22 |
+
add_to_decoder_features: bool = False
|
23 |
+
shape: Optional[list[int]] = None
|
24 |
+
|
25 |
+
|
26 |
+
class ClipBasedHeadEstimator(BaseModule):
|
27 |
+
@dataclass
|
28 |
+
class Config(BaseModule.Config):
|
29 |
+
model: str = "ViT-B-32"
|
30 |
+
pretrain: str = "laion2b_s34b_b79k"
|
31 |
+
|
32 |
+
distribution: str = "beta"
|
33 |
+
|
34 |
+
# ["mean", "mode", "sample", "sample_mean"]
|
35 |
+
distribution_eval: str = "mode"
|
36 |
+
|
37 |
+
activation: str = "relu"
|
38 |
+
hidden_features: int = 512
|
39 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
40 |
+
|
41 |
+
cfg: Config
|
42 |
+
|
43 |
+
def configure(self):
|
44 |
+
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
45 |
+
self.cfg.model, pretrained=self.cfg.pretrain
|
46 |
+
)
|
47 |
+
self.model.eval()
|
48 |
+
|
49 |
+
# Do not add the weights in self.model to the optimizer
|
50 |
+
for param in self.model.parameters():
|
51 |
+
param.requires_grad = False
|
52 |
+
|
53 |
+
assert len(self.cfg.heads) > 0
|
54 |
+
heads = {}
|
55 |
+
for head in self.cfg.heads:
|
56 |
+
head_layers = []
|
57 |
+
|
58 |
+
for i in range(head.n_hidden_layers):
|
59 |
+
head_layers += [
|
60 |
+
nn.Linear(
|
61 |
+
self.cfg.hidden_features,
|
62 |
+
self.cfg.hidden_features,
|
63 |
+
),
|
64 |
+
self.make_activation(self.cfg.activation),
|
65 |
+
]
|
66 |
+
|
67 |
+
head_layers = [nn.Sequential(*head_layers)]
|
68 |
+
head_layers += [
|
69 |
+
nn.Sequential(
|
70 |
+
nn.Linear(
|
71 |
+
self.cfg.hidden_features,
|
72 |
+
self.cfg.hidden_features,
|
73 |
+
),
|
74 |
+
self.make_activation(self.cfg.activation),
|
75 |
+
nn.Linear(self.cfg.hidden_features, 1),
|
76 |
+
)
|
77 |
+
for _ in range(2)
|
78 |
+
]
|
79 |
+
heads[head.name] = nn.ModuleList(head_layers)
|
80 |
+
self.heads = nn.ModuleDict(heads)
|
81 |
+
|
82 |
+
def make_activation(self, activation):
|
83 |
+
if activation == "relu":
|
84 |
+
return nn.ReLU(inplace=True)
|
85 |
+
elif activation == "silu":
|
86 |
+
return nn.SiLU(inplace=True)
|
87 |
+
else:
|
88 |
+
raise NotImplementedError
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
cond_image: Float[Tensor, "B 1 H W 3"],
|
93 |
+
sample: bool = True,
|
94 |
+
) -> dict[str, Any]:
|
95 |
+
# Run the model
|
96 |
+
# Resize cond_image to 224
|
97 |
+
cond_image = nn.functional.interpolate(
|
98 |
+
cond_image.flatten(0, 1).permute(0, 3, 1, 2),
|
99 |
+
size=(224, 224),
|
100 |
+
mode="bilinear",
|
101 |
+
align_corners=False,
|
102 |
+
)
|
103 |
+
cond_image = Normalize(
|
104 |
+
mean=open_clip.constants.OPENAI_DATASET_MEAN,
|
105 |
+
std=open_clip.constants.OPENAI_DATASET_STD,
|
106 |
+
)(cond_image)
|
107 |
+
image_features = self.model.encode_image(cond_image)
|
108 |
+
|
109 |
+
# Run the heads
|
110 |
+
outputs = {}
|
111 |
+
|
112 |
+
for head_dict in self.cfg.heads:
|
113 |
+
head_name = head_dict.name
|
114 |
+
shared_head, d1_h, d2_h = self.heads[head_name]
|
115 |
+
shared_features = shared_head(image_features)
|
116 |
+
d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
|
117 |
+
if self.cfg.distribution == "normal":
|
118 |
+
mean = d1
|
119 |
+
var = d2
|
120 |
+
if mean.shape[-1] == 1:
|
121 |
+
outputs[head_name] = torch.distributions.Normal(
|
122 |
+
mean + head_dict.output_bias,
|
123 |
+
torch.nn.functional.softplus(var),
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
outputs[head_name] = torch.distributions.MultivariateNormal(
|
127 |
+
mean + head_dict.output_bias,
|
128 |
+
torch.nn.functional.softplus(var).diag_embed(),
|
129 |
+
)
|
130 |
+
elif self.cfg.distribution == "beta":
|
131 |
+
outputs[head_name] = torch.distributions.Beta(
|
132 |
+
torch.nn.functional.softplus(d1 + head_dict.output_bias),
|
133 |
+
torch.nn.functional.softplus(d2 + head_dict.output_bias),
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
raise NotImplementedError
|
137 |
+
|
138 |
+
if sample:
|
139 |
+
for head_dict in self.cfg.heads:
|
140 |
+
head_name = head_dict.name
|
141 |
+
dist = outputs[head_name]
|
142 |
+
|
143 |
+
if self.cfg.distribution_eval == "mean":
|
144 |
+
out = dist.mean
|
145 |
+
elif self.cfg.distribution_eval == "mode":
|
146 |
+
out = dist.mode
|
147 |
+
elif self.cfg.distribution_eval == "sample_mean":
|
148 |
+
out = dist.sample([10]).mean(-1)
|
149 |
+
else:
|
150 |
+
# use rsample if gradient is needed
|
151 |
+
out = dist.rsample() if self.training else dist.sample()
|
152 |
+
|
153 |
+
outputs[head_name] = get_activation(head_dict.output_activation)(out)
|
154 |
+
outputs[f"{head_name}_dist"] = dist
|
155 |
+
|
156 |
+
for head in self.cfg.heads:
|
157 |
+
if head.shape:
|
158 |
+
if not sample:
|
159 |
+
raise ValueError(
|
160 |
+
"Cannot reshape non-sampled probabilisitic outputs"
|
161 |
+
)
|
162 |
+
outputs[head.name] = outputs[head.name].reshape(*head.shape)
|
163 |
+
|
164 |
+
if head.add_to_decoder_features:
|
165 |
+
outputs[f"decoder_{head.name}"] = outputs[head.name]
|
166 |
+
del outputs[head.name]
|
167 |
+
|
168 |
+
return outputs
|
sf3d/models/isosurface.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from .mesh import Mesh
|
10 |
+
|
11 |
+
|
12 |
+
class IsosurfaceHelper(nn.Module):
|
13 |
+
points_range: Tuple[float, float] = (0, 1)
|
14 |
+
|
15 |
+
@property
|
16 |
+
def grid_vertices(self) -> Float[Tensor, "N 3"]:
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
@property
|
20 |
+
def requires_instance_per_batch(self) -> bool:
|
21 |
+
return False
|
22 |
+
|
23 |
+
|
24 |
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
|
25 |
+
def __init__(self, resolution: int, tets_path: str):
|
26 |
+
super().__init__()
|
27 |
+
self.resolution = resolution
|
28 |
+
self.tets_path = tets_path
|
29 |
+
|
30 |
+
self.triangle_table: Float[Tensor, "..."]
|
31 |
+
self.register_buffer(
|
32 |
+
"triangle_table",
|
33 |
+
torch.as_tensor(
|
34 |
+
[
|
35 |
+
[-1, -1, -1, -1, -1, -1],
|
36 |
+
[1, 0, 2, -1, -1, -1],
|
37 |
+
[4, 0, 3, -1, -1, -1],
|
38 |
+
[1, 4, 2, 1, 3, 4],
|
39 |
+
[3, 1, 5, -1, -1, -1],
|
40 |
+
[2, 3, 0, 2, 5, 3],
|
41 |
+
[1, 4, 0, 1, 5, 4],
|
42 |
+
[4, 2, 5, -1, -1, -1],
|
43 |
+
[4, 5, 2, -1, -1, -1],
|
44 |
+
[4, 1, 0, 4, 5, 1],
|
45 |
+
[3, 2, 0, 3, 5, 2],
|
46 |
+
[1, 3, 5, -1, -1, -1],
|
47 |
+
[4, 1, 2, 4, 3, 1],
|
48 |
+
[3, 0, 4, -1, -1, -1],
|
49 |
+
[2, 0, 1, -1, -1, -1],
|
50 |
+
[-1, -1, -1, -1, -1, -1],
|
51 |
+
],
|
52 |
+
dtype=torch.long,
|
53 |
+
),
|
54 |
+
persistent=False,
|
55 |
+
)
|
56 |
+
self.num_triangles_table: Integer[Tensor, "..."]
|
57 |
+
self.register_buffer(
|
58 |
+
"num_triangles_table",
|
59 |
+
torch.as_tensor(
|
60 |
+
[0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
|
61 |
+
),
|
62 |
+
persistent=False,
|
63 |
+
)
|
64 |
+
self.base_tet_edges: Integer[Tensor, "..."]
|
65 |
+
self.register_buffer(
|
66 |
+
"base_tet_edges",
|
67 |
+
torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
|
68 |
+
persistent=False,
|
69 |
+
)
|
70 |
+
|
71 |
+
tets = np.load(self.tets_path)
|
72 |
+
self._grid_vertices: Float[Tensor, "..."]
|
73 |
+
self.register_buffer(
|
74 |
+
"_grid_vertices",
|
75 |
+
torch.from_numpy(tets["vertices"]).float(),
|
76 |
+
persistent=False,
|
77 |
+
)
|
78 |
+
self.indices: Integer[Tensor, "..."]
|
79 |
+
self.register_buffer(
|
80 |
+
"indices", torch.from_numpy(tets["indices"]).long(), persistent=False
|
81 |
+
)
|
82 |
+
|
83 |
+
self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
84 |
+
|
85 |
+
center_indices, boundary_indices = self.get_center_boundary_index(
|
86 |
+
self._grid_vertices
|
87 |
+
)
|
88 |
+
self.center_indices: Integer[Tensor, "..."]
|
89 |
+
self.register_buffer("center_indices", center_indices, persistent=False)
|
90 |
+
self.boundary_indices: Integer[Tensor, "..."]
|
91 |
+
self.register_buffer("boundary_indices", boundary_indices, persistent=False)
|
92 |
+
|
93 |
+
def get_center_boundary_index(self, verts):
|
94 |
+
magn = torch.sum(verts**2, dim=-1)
|
95 |
+
|
96 |
+
center_idx = torch.argmin(magn)
|
97 |
+
boundary_neg = verts == verts.max()
|
98 |
+
boundary_pos = verts == verts.min()
|
99 |
+
|
100 |
+
boundary = torch.bitwise_or(boundary_pos, boundary_neg)
|
101 |
+
boundary = torch.sum(boundary.float(), dim=-1)
|
102 |
+
|
103 |
+
boundary_idx = torch.nonzero(boundary)
|
104 |
+
return center_idx, boundary_idx.squeeze(dim=-1)
|
105 |
+
|
106 |
+
def normalize_grid_deformation(
|
107 |
+
self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
|
108 |
+
) -> Float[Tensor, "Nv 3"]:
|
109 |
+
return (
|
110 |
+
(self.points_range[1] - self.points_range[0])
|
111 |
+
/ self.resolution # half tet size is approximately 1 / self.resolution
|
112 |
+
* torch.tanh(grid_vertex_offsets)
|
113 |
+
) # FIXME: hard-coded activation
|
114 |
+
|
115 |
+
@property
|
116 |
+
def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
|
117 |
+
return self._grid_vertices
|
118 |
+
|
119 |
+
@property
|
120 |
+
def all_edges(self) -> Integer[Tensor, "Ne 2"]:
|
121 |
+
if self._all_edges is None:
|
122 |
+
# compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
|
123 |
+
edges = torch.tensor(
|
124 |
+
[0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
|
125 |
+
dtype=torch.long,
|
126 |
+
device=self.indices.device,
|
127 |
+
)
|
128 |
+
_all_edges = self.indices[:, edges].reshape(-1, 2)
|
129 |
+
_all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
|
130 |
+
_all_edges = torch.unique(_all_edges_sorted, dim=0)
|
131 |
+
self._all_edges = _all_edges
|
132 |
+
return self._all_edges
|
133 |
+
|
134 |
+
def sort_edges(self, edges_ex2):
|
135 |
+
with torch.no_grad():
|
136 |
+
order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
|
137 |
+
order = order.unsqueeze(dim=1)
|
138 |
+
|
139 |
+
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
140 |
+
b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
|
141 |
+
|
142 |
+
return torch.stack([a, b], -1)
|
143 |
+
|
144 |
+
def _forward(self, pos_nx3, sdf_n, tet_fx4):
|
145 |
+
with torch.no_grad():
|
146 |
+
occ_n = sdf_n > 0
|
147 |
+
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
148 |
+
occ_sum = torch.sum(occ_fx4, -1)
|
149 |
+
valid_tets = (occ_sum > 0) & (occ_sum < 4)
|
150 |
+
occ_sum = occ_sum[valid_tets]
|
151 |
+
|
152 |
+
# find all vertices
|
153 |
+
all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
|
154 |
+
all_edges = self.sort_edges(all_edges)
|
155 |
+
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
156 |
+
|
157 |
+
unique_edges = unique_edges.long()
|
158 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
159 |
+
mapping = (
|
160 |
+
torch.ones(
|
161 |
+
(unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
|
162 |
+
)
|
163 |
+
* -1
|
164 |
+
)
|
165 |
+
mapping[mask_edges] = torch.arange(
|
166 |
+
mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
|
167 |
+
)
|
168 |
+
idx_map = mapping[idx_map] # map edges to verts
|
169 |
+
|
170 |
+
interp_v = unique_edges[mask_edges]
|
171 |
+
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
|
172 |
+
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
|
173 |
+
edges_to_interp_sdf[:, -1] *= -1
|
174 |
+
|
175 |
+
denominator = edges_to_interp_sdf.sum(1, keepdim=True)
|
176 |
+
|
177 |
+
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
|
178 |
+
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
179 |
+
|
180 |
+
idx_map = idx_map.reshape(-1, 6)
|
181 |
+
|
182 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
|
183 |
+
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
184 |
+
num_triangles = self.num_triangles_table[tetindex]
|
185 |
+
|
186 |
+
# Generate triangle indices
|
187 |
+
faces = torch.cat(
|
188 |
+
(
|
189 |
+
torch.gather(
|
190 |
+
input=idx_map[num_triangles == 1],
|
191 |
+
dim=1,
|
192 |
+
index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
|
193 |
+
).reshape(-1, 3),
|
194 |
+
torch.gather(
|
195 |
+
input=idx_map[num_triangles == 2],
|
196 |
+
dim=1,
|
197 |
+
index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
|
198 |
+
).reshape(-1, 3),
|
199 |
+
),
|
200 |
+
dim=0,
|
201 |
+
)
|
202 |
+
|
203 |
+
return verts, faces
|
204 |
+
|
205 |
+
def forward(
|
206 |
+
self,
|
207 |
+
level: Float[Tensor, "N3 1"],
|
208 |
+
deformation: Optional[Float[Tensor, "N3 3"]] = None,
|
209 |
+
) -> Mesh:
|
210 |
+
if deformation is not None:
|
211 |
+
grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
|
212 |
+
deformation
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
grid_vertices = self.grid_vertices
|
216 |
+
|
217 |
+
v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
|
218 |
+
|
219 |
+
mesh = Mesh(
|
220 |
+
v_pos=v_pos,
|
221 |
+
t_pos_idx=t_pos_idx,
|
222 |
+
# extras
|
223 |
+
grid_vertices=grid_vertices,
|
224 |
+
tet_edges=self.all_edges,
|
225 |
+
grid_level=level,
|
226 |
+
grid_deformation=deformation,
|
227 |
+
)
|
228 |
+
|
229 |
+
return mesh
|
sf3d/models/mesh.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from jaxtyping import Float, Integer
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.box_uv_unwrap import box_projection_uv_unwrap
|
11 |
+
from sf3d.models.utils import dot
|
12 |
+
|
13 |
+
|
14 |
+
class Mesh:
|
15 |
+
def __init__(
|
16 |
+
self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
|
17 |
+
) -> None:
|
18 |
+
self.v_pos: Float[Tensor, "Nv 3"] = v_pos
|
19 |
+
self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
|
20 |
+
self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
|
21 |
+
self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
|
22 |
+
self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
|
23 |
+
self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
24 |
+
self.extras: Dict[str, Any] = {}
|
25 |
+
for k, v in kwargs.items():
|
26 |
+
self.add_extra(k, v)
|
27 |
+
|
28 |
+
def add_extra(self, k, v) -> None:
|
29 |
+
self.extras[k] = v
|
30 |
+
|
31 |
+
@property
|
32 |
+
def requires_grad(self):
|
33 |
+
return self.v_pos.requires_grad
|
34 |
+
|
35 |
+
@property
|
36 |
+
def v_nrm(self):
|
37 |
+
if self._v_nrm is None:
|
38 |
+
self._v_nrm = self._compute_vertex_normal()
|
39 |
+
return self._v_nrm
|
40 |
+
|
41 |
+
@property
|
42 |
+
def v_tng(self):
|
43 |
+
if self._v_tng is None:
|
44 |
+
self._v_tng = self._compute_vertex_tangent()
|
45 |
+
return self._v_tng
|
46 |
+
|
47 |
+
@property
|
48 |
+
def v_tex(self):
|
49 |
+
if self._v_tex is None:
|
50 |
+
self.unwrap_uv()
|
51 |
+
return self._v_tex
|
52 |
+
|
53 |
+
@property
|
54 |
+
def edges(self):
|
55 |
+
if self._edges is None:
|
56 |
+
self._edges = self._compute_edges()
|
57 |
+
return self._edges
|
58 |
+
|
59 |
+
def _compute_vertex_normal(self):
|
60 |
+
i0 = self.t_pos_idx[:, 0]
|
61 |
+
i1 = self.t_pos_idx[:, 1]
|
62 |
+
i2 = self.t_pos_idx[:, 2]
|
63 |
+
|
64 |
+
v0 = self.v_pos[i0, :]
|
65 |
+
v1 = self.v_pos[i1, :]
|
66 |
+
v2 = self.v_pos[i2, :]
|
67 |
+
|
68 |
+
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
69 |
+
|
70 |
+
# Splat face normals to vertices
|
71 |
+
v_nrm = torch.zeros_like(self.v_pos)
|
72 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
73 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
74 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
75 |
+
|
76 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
77 |
+
v_nrm = torch.where(
|
78 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
79 |
+
)
|
80 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
81 |
+
|
82 |
+
if torch.is_anomaly_enabled():
|
83 |
+
assert torch.all(torch.isfinite(v_nrm))
|
84 |
+
|
85 |
+
return v_nrm
|
86 |
+
|
87 |
+
def _compute_vertex_tangent(self):
|
88 |
+
vn_idx = [None] * 3
|
89 |
+
pos = [None] * 3
|
90 |
+
tex = [None] * 3
|
91 |
+
for i in range(0, 3):
|
92 |
+
pos[i] = self.v_pos[self.t_pos_idx[:, i]]
|
93 |
+
tex[i] = self.v_tex[self.t_pos_idx[:, i]]
|
94 |
+
# t_nrm_idx is always the same as t_pos_idx
|
95 |
+
vn_idx[i] = self.t_pos_idx[:, i]
|
96 |
+
|
97 |
+
tangents = torch.zeros_like(self.v_nrm)
|
98 |
+
tansum = torch.zeros_like(self.v_nrm)
|
99 |
+
|
100 |
+
# Compute tangent space for each triangle
|
101 |
+
duv1 = tex[1] - tex[0]
|
102 |
+
duv2 = tex[2] - tex[0]
|
103 |
+
dpos1 = pos[1] - pos[0]
|
104 |
+
dpos2 = pos[2] - pos[0]
|
105 |
+
|
106 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
107 |
+
|
108 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
109 |
+
|
110 |
+
# Avoid division by zero for degenerated texture coordinates
|
111 |
+
denom_safe = denom.clip(1e-6)
|
112 |
+
tang = tng_nom / denom_safe
|
113 |
+
|
114 |
+
# Update all 3 vertices
|
115 |
+
for i in range(0, 3):
|
116 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
117 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
118 |
+
tansum.scatter_add_(
|
119 |
+
0, idx, torch.ones_like(tang)
|
120 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
121 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
122 |
+
# triangles influence the tangent space more
|
123 |
+
tangents = tangents / tansum
|
124 |
+
|
125 |
+
# Normalize and make sure tangent is perpendicular to normal
|
126 |
+
tangents = F.normalize(tangents, dim=1)
|
127 |
+
tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
|
128 |
+
|
129 |
+
if torch.is_anomaly_enabled():
|
130 |
+
assert torch.all(torch.isfinite(tangents))
|
131 |
+
|
132 |
+
return tangents
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
def unwrap_uv(
|
136 |
+
self,
|
137 |
+
island_padding: float = 0.02,
|
138 |
+
) -> Mesh:
|
139 |
+
uv, indices = box_projection_uv_unwrap(
|
140 |
+
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
|
141 |
+
)
|
142 |
+
|
143 |
+
# Do store per vertex UVs.
|
144 |
+
# This means we need to duplicate some vertices at the seams
|
145 |
+
individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
|
146 |
+
individual_faces = torch.arange(
|
147 |
+
individual_vertices.shape[0],
|
148 |
+
device=individual_vertices.device,
|
149 |
+
dtype=self.t_pos_idx.dtype,
|
150 |
+
).reshape(-1, 3)
|
151 |
+
uv_flat = uv[indices].reshape((-1, 2))
|
152 |
+
# uv_flat[:, 1] = 1 - uv_flat[:, 1]
|
153 |
+
|
154 |
+
self.v_pos = individual_vertices
|
155 |
+
self.t_pos_idx = individual_faces
|
156 |
+
self._v_tex = uv_flat
|
157 |
+
self._v_nrm = self._compute_vertex_normal()
|
158 |
+
self._v_tng = self._compute_vertex_tangent()
|
159 |
+
|
160 |
+
def _compute_edges(self):
|
161 |
+
# Compute edges
|
162 |
+
edges = torch.cat(
|
163 |
+
[
|
164 |
+
self.t_pos_idx[:, [0, 1]],
|
165 |
+
self.t_pos_idx[:, [1, 2]],
|
166 |
+
self.t_pos_idx[:, [2, 0]],
|
167 |
+
],
|
168 |
+
dim=0,
|
169 |
+
)
|
170 |
+
edges = edges.sort()[0]
|
171 |
+
edges = torch.unique(edges, dim=0)
|
172 |
+
return edges
|
sf3d/models/network.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from jaxtyping import Float
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.autograd import Function
|
11 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
12 |
+
|
13 |
+
from sf3d.models.utils import BaseModule, normalize
|
14 |
+
|
15 |
+
|
16 |
+
class PixelShuffleUpsampleNetwork(BaseModule):
|
17 |
+
@dataclass
|
18 |
+
class Config(BaseModule.Config):
|
19 |
+
in_channels: int = 1024
|
20 |
+
out_channels: int = 40
|
21 |
+
scale_factor: int = 4
|
22 |
+
|
23 |
+
conv_layers: int = 4
|
24 |
+
conv_kernel_size: int = 3
|
25 |
+
|
26 |
+
cfg: Config
|
27 |
+
|
28 |
+
def configure(self) -> None:
|
29 |
+
layers = []
|
30 |
+
output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
|
31 |
+
|
32 |
+
in_channels = self.cfg.in_channels
|
33 |
+
for i in range(self.cfg.conv_layers):
|
34 |
+
cur_out_channels = (
|
35 |
+
in_channels if i != self.cfg.conv_layers - 1 else output_channels
|
36 |
+
)
|
37 |
+
layers.append(
|
38 |
+
nn.Conv2d(
|
39 |
+
in_channels,
|
40 |
+
cur_out_channels,
|
41 |
+
self.cfg.conv_kernel_size,
|
42 |
+
padding=(self.cfg.conv_kernel_size - 1) // 2,
|
43 |
+
)
|
44 |
+
)
|
45 |
+
if i != self.cfg.conv_layers - 1:
|
46 |
+
layers.append(nn.ReLU(inplace=True))
|
47 |
+
|
48 |
+
layers.append(nn.PixelShuffle(self.cfg.scale_factor))
|
49 |
+
|
50 |
+
self.upsample = nn.Sequential(*layers)
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
|
54 |
+
) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
|
55 |
+
return rearrange(
|
56 |
+
self.upsample(
|
57 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
58 |
+
),
|
59 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
60 |
+
Np=3,
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
class _TruncExp(Function): # pylint: disable=abstract-method
|
65 |
+
# Implementation from torch-ngp:
|
66 |
+
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
67 |
+
@staticmethod
|
68 |
+
@custom_fwd(cast_inputs=torch.float32)
|
69 |
+
def forward(ctx, x): # pylint: disable=arguments-differ
|
70 |
+
ctx.save_for_backward(x)
|
71 |
+
return torch.exp(x)
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
@custom_bwd
|
75 |
+
def backward(ctx, g): # pylint: disable=arguments-differ
|
76 |
+
x = ctx.saved_tensors[0]
|
77 |
+
return g * torch.exp(torch.clamp(x, max=15))
|
78 |
+
|
79 |
+
|
80 |
+
trunc_exp = _TruncExp.apply
|
81 |
+
|
82 |
+
|
83 |
+
def get_activation(name) -> Callable:
|
84 |
+
if name is None:
|
85 |
+
return lambda x: x
|
86 |
+
name = name.lower()
|
87 |
+
if name == "none" or name == "linear" or name == "identity":
|
88 |
+
return lambda x: x
|
89 |
+
elif name == "lin2srgb":
|
90 |
+
return lambda x: torch.where(
|
91 |
+
x > 0.0031308,
|
92 |
+
torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
|
93 |
+
12.92 * x,
|
94 |
+
).clamp(0.0, 1.0)
|
95 |
+
elif name == "exp":
|
96 |
+
return lambda x: torch.exp(x)
|
97 |
+
elif name == "shifted_exp":
|
98 |
+
return lambda x: torch.exp(x - 1.0)
|
99 |
+
elif name == "trunc_exp":
|
100 |
+
return trunc_exp
|
101 |
+
elif name == "shifted_trunc_exp":
|
102 |
+
return lambda x: trunc_exp(x - 1.0)
|
103 |
+
elif name == "sigmoid":
|
104 |
+
return lambda x: torch.sigmoid(x)
|
105 |
+
elif name == "tanh":
|
106 |
+
return lambda x: torch.tanh(x)
|
107 |
+
elif name == "shifted_softplus":
|
108 |
+
return lambda x: F.softplus(x - 1.0)
|
109 |
+
elif name == "scale_-11_01":
|
110 |
+
return lambda x: x * 0.5 + 0.5
|
111 |
+
elif name == "negative":
|
112 |
+
return lambda x: -x
|
113 |
+
elif name == "normalize_channel_last":
|
114 |
+
return lambda x: normalize(x)
|
115 |
+
elif name == "normalize_channel_first":
|
116 |
+
return lambda x: normalize(x, dim=1)
|
117 |
+
else:
|
118 |
+
try:
|
119 |
+
return getattr(F, name)
|
120 |
+
except AttributeError:
|
121 |
+
raise ValueError(f"Unknown activation function: {name}")
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
|
125 |
+
class HeadSpec:
|
126 |
+
name: str
|
127 |
+
out_channels: int
|
128 |
+
n_hidden_layers: int
|
129 |
+
output_activation: Optional[str] = None
|
130 |
+
out_bias: float = 0.0
|
131 |
+
|
132 |
+
|
133 |
+
class MaterialMLP(BaseModule):
|
134 |
+
@dataclass
|
135 |
+
class Config(BaseModule.Config):
|
136 |
+
in_channels: int = 120
|
137 |
+
n_neurons: int = 64
|
138 |
+
activation: str = "silu"
|
139 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
140 |
+
|
141 |
+
cfg: Config
|
142 |
+
|
143 |
+
def configure(self) -> None:
|
144 |
+
assert len(self.cfg.heads) > 0
|
145 |
+
heads = {}
|
146 |
+
for head in self.cfg.heads:
|
147 |
+
head_layers = []
|
148 |
+
for i in range(head.n_hidden_layers):
|
149 |
+
head_layers += [
|
150 |
+
nn.Linear(
|
151 |
+
self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
|
152 |
+
self.cfg.n_neurons,
|
153 |
+
),
|
154 |
+
self.make_activation(self.cfg.activation),
|
155 |
+
]
|
156 |
+
head_layers += [
|
157 |
+
nn.Linear(
|
158 |
+
self.cfg.n_neurons,
|
159 |
+
head.out_channels,
|
160 |
+
),
|
161 |
+
]
|
162 |
+
heads[head.name] = nn.Sequential(*head_layers)
|
163 |
+
self.heads = nn.ModuleDict(heads)
|
164 |
+
|
165 |
+
def make_activation(self, activation):
|
166 |
+
if activation == "relu":
|
167 |
+
return nn.ReLU(inplace=True)
|
168 |
+
elif activation == "silu":
|
169 |
+
return nn.SiLU(inplace=True)
|
170 |
+
else:
|
171 |
+
raise NotImplementedError
|
172 |
+
|
173 |
+
def keys(self):
|
174 |
+
return self.heads.keys()
|
175 |
+
|
176 |
+
def forward(
|
177 |
+
self, x, include: Optional[List] = None, exclude: Optional[List] = None
|
178 |
+
):
|
179 |
+
if include is not None and exclude is not None:
|
180 |
+
raise ValueError("Cannot specify both include and exclude.")
|
181 |
+
if include is not None:
|
182 |
+
heads = [h for h in self.cfg.heads if h.name in include]
|
183 |
+
elif exclude is not None:
|
184 |
+
heads = [h for h in self.cfg.heads if h.name not in exclude]
|
185 |
+
else:
|
186 |
+
heads = self.cfg.heads
|
187 |
+
|
188 |
+
out = {
|
189 |
+
head.name: get_activation(head.output_activation)(
|
190 |
+
self.heads[head.name](x) + head.out_bias
|
191 |
+
)
|
192 |
+
for head in heads
|
193 |
+
}
|
194 |
+
|
195 |
+
return out
|
sf3d/models/tokenizers/dinov2.py
ADDED
@@ -0,0 +1,1196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch DINOv2 model."""
|
16 |
+
|
17 |
+
import collections.abc
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BackboneOutput,
|
30 |
+
BaseModelOutput,
|
31 |
+
BaseModelOutputWithPooling,
|
32 |
+
ImageClassifierOutput,
|
33 |
+
)
|
34 |
+
from transformers.modeling_utils import PreTrainedModel
|
35 |
+
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
|
36 |
+
from transformers.pytorch_utils import (
|
37 |
+
find_pruneable_heads_and_indices,
|
38 |
+
prune_linear_layer,
|
39 |
+
)
|
40 |
+
from transformers.utils import (
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from transformers.utils.backbone_utils import BackboneMixin
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
# General docstring
|
52 |
+
_CONFIG_FOR_DOC = "Dinov2Config"
|
53 |
+
|
54 |
+
# Base docstring
|
55 |
+
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
56 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
57 |
+
|
58 |
+
# Image classification docstring
|
59 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
60 |
+
|
61 |
+
|
62 |
+
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
63 |
+
"facebook/dinov2-base",
|
64 |
+
# See all DINOv2 models at https://huggingface.co/models?filter=dinov2
|
65 |
+
]
|
66 |
+
|
67 |
+
|
68 |
+
class Dinov2Embeddings(nn.Module):
|
69 |
+
"""
|
70 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, config: Dinov2Config) -> None:
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
77 |
+
# register as mask token as it's not used in optimization
|
78 |
+
# to avoid the use of find_unused_parameters_true
|
79 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
|
80 |
+
self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
|
81 |
+
self.patch_embeddings = Dinov2PatchEmbeddings(config)
|
82 |
+
num_patches = self.patch_embeddings.num_patches
|
83 |
+
self.position_embeddings = nn.Parameter(
|
84 |
+
torch.randn(1, num_patches + 1, config.hidden_size)
|
85 |
+
)
|
86 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
87 |
+
self.config = config
|
88 |
+
|
89 |
+
def interpolate_pos_encoding(
|
90 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
91 |
+
) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
94 |
+
resolution images.
|
95 |
+
|
96 |
+
Source:
|
97 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
98 |
+
"""
|
99 |
+
|
100 |
+
num_patches = embeddings.shape[1] - 1
|
101 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
102 |
+
if num_patches == num_positions and height == width:
|
103 |
+
return self.position_embeddings
|
104 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
105 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
106 |
+
dim = embeddings.shape[-1]
|
107 |
+
height = height // self.config.patch_size
|
108 |
+
width = width // self.config.patch_size
|
109 |
+
# we add a small number to avoid floating point error in the interpolation
|
110 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
111 |
+
height, width = height + 0.1, width + 0.1
|
112 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
113 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
114 |
+
)
|
115 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
116 |
+
patch_pos_embed = nn.functional.interpolate(
|
117 |
+
patch_pos_embed,
|
118 |
+
scale_factor=(
|
119 |
+
height / math.sqrt(num_positions),
|
120 |
+
width / math.sqrt(num_positions),
|
121 |
+
),
|
122 |
+
mode="bicubic",
|
123 |
+
align_corners=False,
|
124 |
+
)
|
125 |
+
if (
|
126 |
+
int(height) != patch_pos_embed.shape[-2]
|
127 |
+
or int(width) != patch_pos_embed.shape[-1]
|
128 |
+
):
|
129 |
+
raise ValueError(
|
130 |
+
"Width or height does not match with the interpolated position embeddings"
|
131 |
+
)
|
132 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
133 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
pixel_values: torch.Tensor,
|
138 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
139 |
+
) -> torch.Tensor:
|
140 |
+
batch_size, _, height, width = pixel_values.shape
|
141 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
142 |
+
embeddings = patch_embeddings
|
143 |
+
|
144 |
+
if bool_masked_pos is not None:
|
145 |
+
embeddings = torch.where(
|
146 |
+
bool_masked_pos.unsqueeze(-1),
|
147 |
+
self.mask_token.to(embeddings.dtype).unsqueeze(0),
|
148 |
+
embeddings,
|
149 |
+
)
|
150 |
+
|
151 |
+
# add the [CLS] token to the embedded patch tokens
|
152 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
153 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
154 |
+
|
155 |
+
# add positional encoding to each token
|
156 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
157 |
+
embeddings, height, width
|
158 |
+
)
|
159 |
+
|
160 |
+
embeddings = self.dropout(embeddings)
|
161 |
+
|
162 |
+
return embeddings
|
163 |
+
|
164 |
+
|
165 |
+
class Dinov2PatchEmbeddings(nn.Module):
|
166 |
+
"""
|
167 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
168 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
169 |
+
Transformer.
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(self, config):
|
173 |
+
super().__init__()
|
174 |
+
image_size, patch_size = config.image_size, config.patch_size
|
175 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
176 |
+
|
177 |
+
image_size = (
|
178 |
+
image_size
|
179 |
+
if isinstance(image_size, collections.abc.Iterable)
|
180 |
+
else (image_size, image_size)
|
181 |
+
)
|
182 |
+
patch_size = (
|
183 |
+
patch_size
|
184 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
185 |
+
else (patch_size, patch_size)
|
186 |
+
)
|
187 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
188 |
+
image_size[0] // patch_size[0]
|
189 |
+
)
|
190 |
+
self.image_size = image_size
|
191 |
+
self.patch_size = patch_size
|
192 |
+
self.num_channels = num_channels
|
193 |
+
self.num_patches = num_patches
|
194 |
+
|
195 |
+
self.projection = nn.Conv2d(
|
196 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
197 |
+
)
|
198 |
+
|
199 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
200 |
+
"""
|
201 |
+
num_channels = pixel_values.shape[1]
|
202 |
+
if num_channels != self.num_channels:
|
203 |
+
raise ValueError(
|
204 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
205 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
206 |
+
)
|
207 |
+
"""
|
208 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
209 |
+
return embeddings
|
210 |
+
|
211 |
+
|
212 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
213 |
+
class Dinov2SelfAttention(nn.Module):
|
214 |
+
def __init__(self, config: Dinov2Config) -> None:
|
215 |
+
super().__init__()
|
216 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
217 |
+
config, "embedding_size"
|
218 |
+
):
|
219 |
+
raise ValueError(
|
220 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
221 |
+
f"heads {config.num_attention_heads}."
|
222 |
+
)
|
223 |
+
|
224 |
+
self.num_attention_heads = config.num_attention_heads
|
225 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
226 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
227 |
+
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
228 |
+
|
229 |
+
self.query = nn.Linear(
|
230 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
231 |
+
)
|
232 |
+
self.key = nn.Linear(
|
233 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
234 |
+
)
|
235 |
+
self.value = nn.Linear(
|
236 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
237 |
+
)
|
238 |
+
|
239 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
240 |
+
|
241 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
242 |
+
new_x_shape = x.size()[:-1] + (
|
243 |
+
self.num_attention_heads,
|
244 |
+
self.attention_head_size,
|
245 |
+
)
|
246 |
+
x = x.view(new_x_shape)
|
247 |
+
return x.permute(0, 2, 1, 3)
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
hidden_states,
|
252 |
+
head_mask: Optional[torch.Tensor] = None,
|
253 |
+
output_attentions: bool = False,
|
254 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
255 |
+
mixed_query_layer = self.query(hidden_states)
|
256 |
+
|
257 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
258 |
+
assert head_mask is None and not output_attentions
|
259 |
+
new_size = hidden_states.size()[:-1] + (
|
260 |
+
self.num_attention_heads,
|
261 |
+
self.attention_head_size,
|
262 |
+
)
|
263 |
+
key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
|
264 |
+
value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
|
265 |
+
query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
|
266 |
+
context_layer = F.scaled_dot_product_attention(
|
267 |
+
query_layer,
|
268 |
+
key_layer,
|
269 |
+
value_layer,
|
270 |
+
dropout_p=self.attention_probs_dropout_prob,
|
271 |
+
is_causal=False,
|
272 |
+
)
|
273 |
+
context_layer = context_layer.transpose(1, 2).reshape(
|
274 |
+
*hidden_states.size()[:-1], -1
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
278 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
279 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
280 |
+
|
281 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
282 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
283 |
+
|
284 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
285 |
+
|
286 |
+
# Normalize the attention scores to probabilities.
|
287 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
288 |
+
|
289 |
+
# This is actually dropping out entire tokens to attend to, which might
|
290 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
291 |
+
attention_probs = self.dropout(attention_probs)
|
292 |
+
|
293 |
+
# Mask heads if we want to
|
294 |
+
if head_mask is not None:
|
295 |
+
attention_probs = attention_probs * head_mask
|
296 |
+
|
297 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
298 |
+
|
299 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
300 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
301 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
302 |
+
|
303 |
+
outputs = (
|
304 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
305 |
+
)
|
306 |
+
|
307 |
+
return outputs
|
308 |
+
|
309 |
+
|
310 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
311 |
+
class Dinov2SelfOutput(nn.Module):
|
312 |
+
"""
|
313 |
+
The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
|
314 |
+
layernorm applied before each block.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, config: Dinov2Config) -> None:
|
318 |
+
super().__init__()
|
319 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
320 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
324 |
+
) -> torch.Tensor:
|
325 |
+
hidden_states = self.dense(hidden_states)
|
326 |
+
hidden_states = self.dropout(hidden_states)
|
327 |
+
|
328 |
+
return hidden_states
|
329 |
+
|
330 |
+
|
331 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
|
332 |
+
class Dinov2Attention(nn.Module):
|
333 |
+
def __init__(self, config: Dinov2Config) -> None:
|
334 |
+
super().__init__()
|
335 |
+
self.attention = Dinov2SelfAttention(config)
|
336 |
+
self.output = Dinov2SelfOutput(config)
|
337 |
+
self.pruned_heads = set()
|
338 |
+
|
339 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
340 |
+
if len(heads) == 0:
|
341 |
+
return
|
342 |
+
heads, index = find_pruneable_heads_and_indices(
|
343 |
+
heads,
|
344 |
+
self.attention.num_attention_heads,
|
345 |
+
self.attention.attention_head_size,
|
346 |
+
self.pruned_heads,
|
347 |
+
)
|
348 |
+
|
349 |
+
# Prune linear layers
|
350 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
351 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
352 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
353 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
354 |
+
|
355 |
+
# Update hyper params and store pruned heads
|
356 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(
|
357 |
+
heads
|
358 |
+
)
|
359 |
+
self.attention.all_head_size = (
|
360 |
+
self.attention.attention_head_size * self.attention.num_attention_heads
|
361 |
+
)
|
362 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
head_mask: Optional[torch.Tensor] = None,
|
368 |
+
output_attentions: bool = False,
|
369 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
370 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
371 |
+
|
372 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
373 |
+
|
374 |
+
outputs = (attention_output,) + self_outputs[
|
375 |
+
1:
|
376 |
+
] # add attentions if we output them
|
377 |
+
return outputs
|
378 |
+
|
379 |
+
|
380 |
+
class Dinov2LayerScale(nn.Module):
|
381 |
+
def __init__(self, config) -> None:
|
382 |
+
super().__init__()
|
383 |
+
self.lambda1 = nn.Parameter(
|
384 |
+
config.layerscale_value * torch.ones(config.hidden_size)
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
388 |
+
return hidden_state * self.lambda1
|
389 |
+
|
390 |
+
|
391 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
392 |
+
def drop_path(
|
393 |
+
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
|
394 |
+
) -> torch.Tensor:
|
395 |
+
"""
|
396 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
397 |
+
|
398 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
399 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
400 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
401 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
402 |
+
argument.
|
403 |
+
"""
|
404 |
+
if drop_prob == 0.0 or not training:
|
405 |
+
return input
|
406 |
+
keep_prob = 1 - drop_prob
|
407 |
+
shape = (input.shape[0],) + (1,) * (
|
408 |
+
input.ndim - 1
|
409 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
410 |
+
random_tensor = keep_prob + torch.rand(
|
411 |
+
shape, dtype=input.dtype, device=input.device
|
412 |
+
)
|
413 |
+
random_tensor.floor_() # binarize
|
414 |
+
output = input.div(keep_prob) * random_tensor
|
415 |
+
return output
|
416 |
+
|
417 |
+
|
418 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
419 |
+
class Dinov2DropPath(nn.Module):
|
420 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
421 |
+
|
422 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
423 |
+
super().__init__()
|
424 |
+
self.drop_prob = drop_prob
|
425 |
+
|
426 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
427 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
428 |
+
|
429 |
+
def extra_repr(self) -> str:
|
430 |
+
return "p={}".format(self.drop_prob)
|
431 |
+
|
432 |
+
|
433 |
+
class Dinov2MLP(nn.Module):
|
434 |
+
def __init__(self, config) -> None:
|
435 |
+
super().__init__()
|
436 |
+
in_features = out_features = config.hidden_size
|
437 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
438 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
439 |
+
if isinstance(config.hidden_act, str):
|
440 |
+
self.activation = ACT2FN[config.hidden_act]
|
441 |
+
else:
|
442 |
+
self.activation = config.hidden_act
|
443 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
444 |
+
|
445 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
446 |
+
hidden_state = self.fc1(hidden_state)
|
447 |
+
hidden_state = self.activation(hidden_state)
|
448 |
+
hidden_state = self.fc2(hidden_state)
|
449 |
+
return hidden_state
|
450 |
+
|
451 |
+
|
452 |
+
class Dinov2SwiGLUFFN(nn.Module):
|
453 |
+
def __init__(self, config) -> None:
|
454 |
+
super().__init__()
|
455 |
+
in_features = out_features = config.hidden_size
|
456 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
457 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
458 |
+
|
459 |
+
self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
|
460 |
+
self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
|
461 |
+
|
462 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
463 |
+
hidden_state = self.weights_in(hidden_state)
|
464 |
+
x1, x2 = hidden_state.chunk(2, dim=-1)
|
465 |
+
hidden = nn.functional.silu(x1) * x2
|
466 |
+
return self.weights_out(hidden)
|
467 |
+
|
468 |
+
|
469 |
+
class Dinov2Layer(nn.Module):
|
470 |
+
"""This corresponds to the Block class in the original implementation."""
|
471 |
+
|
472 |
+
def __init__(self, config: Dinov2Config) -> None:
|
473 |
+
super().__init__()
|
474 |
+
|
475 |
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
476 |
+
self.norm1_modulation = None
|
477 |
+
self.attention = Dinov2Attention(config)
|
478 |
+
self.layer_scale1 = Dinov2LayerScale(config)
|
479 |
+
self.drop_path1 = (
|
480 |
+
Dinov2DropPath(config.drop_path_rate)
|
481 |
+
if config.drop_path_rate > 0.0
|
482 |
+
else nn.Identity()
|
483 |
+
)
|
484 |
+
|
485 |
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
486 |
+
self.norm2_modulation = None
|
487 |
+
|
488 |
+
if config.use_swiglu_ffn:
|
489 |
+
self.mlp = Dinov2SwiGLUFFN(config)
|
490 |
+
else:
|
491 |
+
self.mlp = Dinov2MLP(config)
|
492 |
+
self.layer_scale2 = Dinov2LayerScale(config)
|
493 |
+
self.drop_path2 = (
|
494 |
+
Dinov2DropPath(config.drop_path_rate)
|
495 |
+
if config.drop_path_rate > 0.0
|
496 |
+
else nn.Identity()
|
497 |
+
)
|
498 |
+
|
499 |
+
def forward(
|
500 |
+
self,
|
501 |
+
hidden_states: torch.Tensor,
|
502 |
+
head_mask: Optional[torch.Tensor] = None,
|
503 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
504 |
+
output_attentions: bool = False,
|
505 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
506 |
+
hidden_states_norm = self.norm1(hidden_states)
|
507 |
+
if self.norm1_modulation is not None:
|
508 |
+
assert modulation_cond is not None
|
509 |
+
hidden_states_norm = self.norm1_modulation(
|
510 |
+
hidden_states_norm, modulation_cond
|
511 |
+
)
|
512 |
+
self_attention_outputs = self.attention(
|
513 |
+
hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
|
514 |
+
head_mask,
|
515 |
+
output_attentions=output_attentions,
|
516 |
+
)
|
517 |
+
attention_output = self_attention_outputs[0]
|
518 |
+
|
519 |
+
attention_output = self.layer_scale1(attention_output)
|
520 |
+
outputs = self_attention_outputs[
|
521 |
+
1:
|
522 |
+
] # add self attentions if we output attention weights
|
523 |
+
|
524 |
+
# first residual connection
|
525 |
+
hidden_states = attention_output + hidden_states
|
526 |
+
|
527 |
+
# in Dinov2, layernorm is also applied after self-attention
|
528 |
+
layer_output = self.norm2(hidden_states)
|
529 |
+
if self.norm2_modulation is not None:
|
530 |
+
assert modulation_cond is not None
|
531 |
+
layer_output = self.norm2_modulation(layer_output, modulation_cond)
|
532 |
+
layer_output = self.mlp(layer_output)
|
533 |
+
layer_output = self.layer_scale2(layer_output)
|
534 |
+
|
535 |
+
# second residual connection
|
536 |
+
layer_output = layer_output + hidden_states
|
537 |
+
|
538 |
+
outputs = (layer_output,) + outputs
|
539 |
+
|
540 |
+
return outputs
|
541 |
+
|
542 |
+
def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
|
543 |
+
self.norm1_modulation = norm1_mod
|
544 |
+
self.norm2_modulation = norm2_mod
|
545 |
+
|
546 |
+
|
547 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
|
548 |
+
class Dinov2Encoder(nn.Module):
|
549 |
+
def __init__(self, config: Dinov2Config) -> None:
|
550 |
+
super().__init__()
|
551 |
+
self.config = config
|
552 |
+
self.layer = nn.ModuleList(
|
553 |
+
[Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
|
554 |
+
)
|
555 |
+
self.gradient_checkpointing = False
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self,
|
559 |
+
hidden_states: torch.Tensor,
|
560 |
+
head_mask: Optional[torch.Tensor] = None,
|
561 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
562 |
+
output_attentions: bool = False,
|
563 |
+
output_hidden_states: bool = False,
|
564 |
+
return_dict: bool = True,
|
565 |
+
) -> Union[tuple, BaseModelOutput]:
|
566 |
+
all_hidden_states = () if output_hidden_states else None
|
567 |
+
all_self_attentions = () if output_attentions else None
|
568 |
+
|
569 |
+
for i, layer_module in enumerate(self.layer):
|
570 |
+
if output_hidden_states:
|
571 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
572 |
+
|
573 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
574 |
+
|
575 |
+
if self.gradient_checkpointing and self.training:
|
576 |
+
|
577 |
+
def create_custom_forward(module):
|
578 |
+
def custom_forward(*inputs):
|
579 |
+
return module(*inputs, output_attentions)
|
580 |
+
|
581 |
+
return custom_forward
|
582 |
+
|
583 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
584 |
+
create_custom_forward(layer_module),
|
585 |
+
hidden_states,
|
586 |
+
layer_head_mask,
|
587 |
+
modulation_cond,
|
588 |
+
use_reentrant=False,
|
589 |
+
)
|
590 |
+
else:
|
591 |
+
layer_outputs = layer_module(
|
592 |
+
hidden_states, layer_head_mask, modulation_cond, output_attentions
|
593 |
+
)
|
594 |
+
|
595 |
+
hidden_states = layer_outputs[0]
|
596 |
+
|
597 |
+
if output_attentions:
|
598 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
599 |
+
|
600 |
+
if output_hidden_states:
|
601 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
602 |
+
|
603 |
+
if not return_dict:
|
604 |
+
return tuple(
|
605 |
+
v
|
606 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
607 |
+
if v is not None
|
608 |
+
)
|
609 |
+
return BaseModelOutput(
|
610 |
+
last_hidden_state=hidden_states,
|
611 |
+
hidden_states=all_hidden_states,
|
612 |
+
attentions=all_self_attentions,
|
613 |
+
)
|
614 |
+
|
615 |
+
|
616 |
+
class Dinov2PreTrainedModel(PreTrainedModel):
|
617 |
+
"""
|
618 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
619 |
+
models.
|
620 |
+
"""
|
621 |
+
|
622 |
+
config_class = Dinov2Config
|
623 |
+
base_model_prefix = "dinov2"
|
624 |
+
main_input_name = "pixel_values"
|
625 |
+
supports_gradient_checkpointing = True
|
626 |
+
|
627 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
628 |
+
"""Initialize the weights"""
|
629 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
630 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
631 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
632 |
+
module.weight.data = nn.init.trunc_normal_(
|
633 |
+
module.weight.data.to(torch.float32),
|
634 |
+
mean=0.0,
|
635 |
+
std=self.config.initializer_range,
|
636 |
+
).to(module.weight.dtype)
|
637 |
+
if module.bias is not None:
|
638 |
+
module.bias.data.zero_()
|
639 |
+
elif isinstance(module, nn.LayerNorm):
|
640 |
+
module.bias.data.zero_()
|
641 |
+
module.weight.data.fill_(1.0)
|
642 |
+
elif isinstance(module, Dinov2Embeddings):
|
643 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
644 |
+
module.position_embeddings.data.to(torch.float32),
|
645 |
+
mean=0.0,
|
646 |
+
std=self.config.initializer_range,
|
647 |
+
).to(module.position_embeddings.dtype)
|
648 |
+
|
649 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
650 |
+
module.cls_token.data.to(torch.float32),
|
651 |
+
mean=0.0,
|
652 |
+
std=self.config.initializer_range,
|
653 |
+
).to(module.cls_token.dtype)
|
654 |
+
|
655 |
+
def _set_gradient_checkpointing(
|
656 |
+
self, module: Dinov2Encoder, value: bool = False
|
657 |
+
) -> None:
|
658 |
+
if isinstance(module, Dinov2Encoder):
|
659 |
+
module.gradient_checkpointing = value
|
660 |
+
|
661 |
+
|
662 |
+
DINOV2_START_DOCSTRING = r"""
|
663 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
664 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
665 |
+
behavior.
|
666 |
+
|
667 |
+
Parameters:
|
668 |
+
config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
|
669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
671 |
+
"""
|
672 |
+
|
673 |
+
DINOV2_BASE_INPUTS_DOCSTRING = r"""
|
674 |
+
Args:
|
675 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
676 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
677 |
+
[`BitImageProcessor.preprocess`] for details.
|
678 |
+
|
679 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
|
680 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
|
681 |
+
pre-training.
|
682 |
+
|
683 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
684 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
685 |
+
|
686 |
+
- 1 indicates the head is **not masked**,
|
687 |
+
- 0 indicates the head is **masked**.
|
688 |
+
|
689 |
+
output_attentions (`bool`, *optional*):
|
690 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
691 |
+
tensors for more detail.
|
692 |
+
output_hidden_states (`bool`, *optional*):
|
693 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
694 |
+
more detail.
|
695 |
+
return_dict (`bool`, *optional*):
|
696 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
697 |
+
"""
|
698 |
+
|
699 |
+
DINOV2_INPUTS_DOCSTRING = r"""
|
700 |
+
Args:
|
701 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
702 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
703 |
+
[`BitImageProcessor.preprocess`] for details.
|
704 |
+
|
705 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
706 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
707 |
+
|
708 |
+
- 1 indicates the head is **not masked**,
|
709 |
+
- 0 indicates the head is **masked**.
|
710 |
+
|
711 |
+
output_attentions (`bool`, *optional*):
|
712 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
713 |
+
tensors for more detail.
|
714 |
+
output_hidden_states (`bool`, *optional*):
|
715 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
716 |
+
more detail.
|
717 |
+
return_dict (`bool`, *optional*):
|
718 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
719 |
+
"""
|
720 |
+
|
721 |
+
|
722 |
+
@dataclass
|
723 |
+
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
|
724 |
+
patch_embeddings: Optional[torch.FloatTensor] = None
|
725 |
+
|
726 |
+
|
727 |
+
@add_start_docstrings(
|
728 |
+
"The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
|
729 |
+
DINOV2_START_DOCSTRING,
|
730 |
+
)
|
731 |
+
class Dinov2Model(Dinov2PreTrainedModel):
|
732 |
+
def __init__(self, config: Dinov2Config):
|
733 |
+
super().__init__(config)
|
734 |
+
self.config = config
|
735 |
+
|
736 |
+
self.embeddings = Dinov2Embeddings(config)
|
737 |
+
self.encoder = Dinov2Encoder(config)
|
738 |
+
|
739 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
740 |
+
|
741 |
+
# Initialize weights and apply final processing
|
742 |
+
self.post_init()
|
743 |
+
|
744 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
745 |
+
return self.embeddings.patch_embeddings
|
746 |
+
|
747 |
+
def expand_input_channels(self, extra_input_channels: int) -> None:
|
748 |
+
if extra_input_channels == 0:
|
749 |
+
return
|
750 |
+
conv_old = self.embeddings.patch_embeddings.projection
|
751 |
+
conv_new = nn.Conv2d(
|
752 |
+
self.config.num_channels + extra_input_channels,
|
753 |
+
self.config.hidden_size,
|
754 |
+
kernel_size=self.config.patch_size,
|
755 |
+
stride=self.config.patch_size,
|
756 |
+
).to(self.device)
|
757 |
+
with torch.no_grad():
|
758 |
+
conv_new.weight[:, :3] = conv_old.weight
|
759 |
+
conv_new.bias = conv_old.bias
|
760 |
+
self.embeddings.patch_embeddings.projection = conv_new
|
761 |
+
del conv_old
|
762 |
+
|
763 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
764 |
+
"""
|
765 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
766 |
+
class PreTrainedModel
|
767 |
+
"""
|
768 |
+
for layer, heads in heads_to_prune.items():
|
769 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
770 |
+
|
771 |
+
@add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
|
772 |
+
@add_code_sample_docstrings(
|
773 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
774 |
+
output_type=BaseModelOutputWithPooling,
|
775 |
+
config_class=_CONFIG_FOR_DOC,
|
776 |
+
modality="vision",
|
777 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
778 |
+
)
|
779 |
+
def forward(
|
780 |
+
self,
|
781 |
+
pixel_values: Optional[torch.Tensor] = None,
|
782 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
783 |
+
head_mask: Optional[torch.Tensor] = None,
|
784 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
785 |
+
output_attentions: Optional[bool] = None,
|
786 |
+
output_hidden_states: Optional[bool] = None,
|
787 |
+
return_dict: Optional[bool] = None,
|
788 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
789 |
+
output_attentions = (
|
790 |
+
output_attentions
|
791 |
+
if output_attentions is not None
|
792 |
+
else self.config.output_attentions
|
793 |
+
)
|
794 |
+
output_hidden_states = (
|
795 |
+
output_hidden_states
|
796 |
+
if output_hidden_states is not None
|
797 |
+
else self.config.output_hidden_states
|
798 |
+
)
|
799 |
+
return_dict = (
|
800 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
801 |
+
)
|
802 |
+
|
803 |
+
if pixel_values is None:
|
804 |
+
raise ValueError("You have to specify pixel_values")
|
805 |
+
|
806 |
+
# Prepare head mask if needed
|
807 |
+
# 1.0 in head_mask indicate we keep the head
|
808 |
+
# attention_probs has shape bsz x n_heads x N x N
|
809 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
810 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
811 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
812 |
+
|
813 |
+
embedding_output = self.embeddings(
|
814 |
+
pixel_values, bool_masked_pos=bool_masked_pos
|
815 |
+
)
|
816 |
+
|
817 |
+
encoder_outputs = self.encoder(
|
818 |
+
embedding_output,
|
819 |
+
head_mask=head_mask,
|
820 |
+
modulation_cond=modulation_cond,
|
821 |
+
output_attentions=output_attentions,
|
822 |
+
output_hidden_states=output_hidden_states,
|
823 |
+
return_dict=return_dict,
|
824 |
+
)
|
825 |
+
sequence_output = encoder_outputs[0]
|
826 |
+
sequence_output = self.layernorm(sequence_output)
|
827 |
+
pooled_output = sequence_output[:, 0, :]
|
828 |
+
|
829 |
+
if not return_dict:
|
830 |
+
head_outputs = (sequence_output, pooled_output)
|
831 |
+
return head_outputs + encoder_outputs[1:]
|
832 |
+
|
833 |
+
return CustomBaseModelOutputWithPooling(
|
834 |
+
last_hidden_state=sequence_output,
|
835 |
+
pooler_output=pooled_output,
|
836 |
+
hidden_states=encoder_outputs.hidden_states,
|
837 |
+
attentions=encoder_outputs.attentions,
|
838 |
+
patch_embeddings=embedding_output,
|
839 |
+
)
|
840 |
+
|
841 |
+
def set_gradient_checkpointing(self, value: bool = False) -> None:
|
842 |
+
self._set_gradient_checkpointing(self.encoder, value)
|
843 |
+
|
844 |
+
|
845 |
+
@add_start_docstrings(
|
846 |
+
"""
|
847 |
+
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
848 |
+
of the [CLS] token) e.g. for ImageNet.
|
849 |
+
""",
|
850 |
+
DINOV2_START_DOCSTRING,
|
851 |
+
)
|
852 |
+
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
853 |
+
def __init__(self, config: Dinov2Config) -> None:
|
854 |
+
super().__init__(config)
|
855 |
+
|
856 |
+
self.num_labels = config.num_labels
|
857 |
+
self.dinov2 = Dinov2Model(config)
|
858 |
+
|
859 |
+
# Classifier head
|
860 |
+
self.classifier = (
|
861 |
+
nn.Linear(config.hidden_size * 2, config.num_labels)
|
862 |
+
if config.num_labels > 0
|
863 |
+
else nn.Identity()
|
864 |
+
)
|
865 |
+
|
866 |
+
# Initialize weights and apply final processing
|
867 |
+
self.post_init()
|
868 |
+
|
869 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
870 |
+
@add_code_sample_docstrings(
|
871 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
872 |
+
output_type=ImageClassifierOutput,
|
873 |
+
config_class=_CONFIG_FOR_DOC,
|
874 |
+
)
|
875 |
+
def forward(
|
876 |
+
self,
|
877 |
+
pixel_values: Optional[torch.Tensor] = None,
|
878 |
+
head_mask: Optional[torch.Tensor] = None,
|
879 |
+
labels: Optional[torch.Tensor] = None,
|
880 |
+
output_attentions: Optional[bool] = None,
|
881 |
+
output_hidden_states: Optional[bool] = None,
|
882 |
+
return_dict: Optional[bool] = None,
|
883 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
884 |
+
r"""
|
885 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
886 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
887 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
888 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
889 |
+
"""
|
890 |
+
return_dict = (
|
891 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
892 |
+
)
|
893 |
+
|
894 |
+
outputs = self.dinov2(
|
895 |
+
pixel_values,
|
896 |
+
head_mask=head_mask,
|
897 |
+
output_attentions=output_attentions,
|
898 |
+
output_hidden_states=output_hidden_states,
|
899 |
+
return_dict=return_dict,
|
900 |
+
)
|
901 |
+
|
902 |
+
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
|
903 |
+
|
904 |
+
cls_token = sequence_output[:, 0]
|
905 |
+
patch_tokens = sequence_output[:, 1:]
|
906 |
+
|
907 |
+
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
|
908 |
+
|
909 |
+
logits = self.classifier(linear_input)
|
910 |
+
|
911 |
+
loss = None
|
912 |
+
if labels is not None:
|
913 |
+
# move labels to correct device to enable model parallelism
|
914 |
+
labels = labels.to(logits.device)
|
915 |
+
if self.config.problem_type is None:
|
916 |
+
if self.num_labels == 1:
|
917 |
+
self.config.problem_type = "regression"
|
918 |
+
elif self.num_labels > 1 and (
|
919 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
920 |
+
):
|
921 |
+
self.config.problem_type = "single_label_classification"
|
922 |
+
else:
|
923 |
+
self.config.problem_type = "multi_label_classification"
|
924 |
+
|
925 |
+
if self.config.problem_type == "regression":
|
926 |
+
loss_fct = MSELoss()
|
927 |
+
if self.num_labels == 1:
|
928 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
929 |
+
else:
|
930 |
+
loss = loss_fct(logits, labels)
|
931 |
+
elif self.config.problem_type == "single_label_classification":
|
932 |
+
loss_fct = CrossEntropyLoss()
|
933 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
934 |
+
elif self.config.problem_type == "multi_label_classification":
|
935 |
+
loss_fct = BCEWithLogitsLoss()
|
936 |
+
loss = loss_fct(logits, labels)
|
937 |
+
|
938 |
+
if not return_dict:
|
939 |
+
output = (logits,) + outputs[2:]
|
940 |
+
return ((loss,) + output) if loss is not None else output
|
941 |
+
|
942 |
+
return ImageClassifierOutput(
|
943 |
+
loss=loss,
|
944 |
+
logits=logits,
|
945 |
+
hidden_states=outputs.hidden_states,
|
946 |
+
attentions=outputs.attentions,
|
947 |
+
)
|
948 |
+
|
949 |
+
|
950 |
+
@add_start_docstrings(
|
951 |
+
"""
|
952 |
+
Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
|
953 |
+
""",
|
954 |
+
DINOV2_START_DOCSTRING,
|
955 |
+
)
|
956 |
+
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
|
957 |
+
def __init__(self, config):
|
958 |
+
super().__init__(config)
|
959 |
+
super()._init_backbone(config)
|
960 |
+
|
961 |
+
self.num_features = [
|
962 |
+
config.hidden_size for _ in range(config.num_hidden_layers + 1)
|
963 |
+
]
|
964 |
+
self.embeddings = Dinov2Embeddings(config)
|
965 |
+
self.encoder = Dinov2Encoder(config)
|
966 |
+
|
967 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
968 |
+
|
969 |
+
# Initialize weights and apply final processing
|
970 |
+
self.post_init()
|
971 |
+
|
972 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
973 |
+
return self.embeddings.patch_embeddings
|
974 |
+
|
975 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
976 |
+
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
977 |
+
def forward(
|
978 |
+
self,
|
979 |
+
pixel_values: torch.Tensor,
|
980 |
+
output_hidden_states: Optional[bool] = None,
|
981 |
+
output_attentions: Optional[bool] = None,
|
982 |
+
return_dict: Optional[bool] = None,
|
983 |
+
) -> BackboneOutput:
|
984 |
+
"""
|
985 |
+
Returns:
|
986 |
+
|
987 |
+
Examples:
|
988 |
+
|
989 |
+
```python
|
990 |
+
>>> from transformers import AutoImageProcessor, AutoBackbone
|
991 |
+
>>> import torch
|
992 |
+
>>> from PIL import Image
|
993 |
+
>>> import requests
|
994 |
+
|
995 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
996 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
997 |
+
|
998 |
+
>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
999 |
+
>>> model = AutoBackbone.from_pretrained(
|
1000 |
+
... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
|
1001 |
+
... )
|
1002 |
+
|
1003 |
+
>>> inputs = processor(image, return_tensors="pt")
|
1004 |
+
|
1005 |
+
>>> outputs = model(**inputs)
|
1006 |
+
>>> feature_maps = outputs.feature_maps
|
1007 |
+
>>> list(feature_maps[-1].shape)
|
1008 |
+
[1, 768, 16, 16]
|
1009 |
+
```"""
|
1010 |
+
return_dict = (
|
1011 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1012 |
+
)
|
1013 |
+
output_hidden_states = (
|
1014 |
+
output_hidden_states
|
1015 |
+
if output_hidden_states is not None
|
1016 |
+
else self.config.output_hidden_states
|
1017 |
+
)
|
1018 |
+
output_attentions = (
|
1019 |
+
output_attentions
|
1020 |
+
if output_attentions is not None
|
1021 |
+
else self.config.output_attentions
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
embedding_output = self.embeddings(pixel_values)
|
1025 |
+
|
1026 |
+
outputs = self.encoder(
|
1027 |
+
embedding_output,
|
1028 |
+
output_hidden_states=True,
|
1029 |
+
output_attentions=output_attentions,
|
1030 |
+
return_dict=return_dict,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
1034 |
+
|
1035 |
+
feature_maps = ()
|
1036 |
+
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
1037 |
+
if stage in self.out_features:
|
1038 |
+
if self.config.apply_layernorm:
|
1039 |
+
hidden_state = self.layernorm(hidden_state)
|
1040 |
+
if self.config.reshape_hidden_states:
|
1041 |
+
batch_size, _, height, width = pixel_values.shape
|
1042 |
+
patch_size = self.config.patch_size
|
1043 |
+
hidden_state = hidden_state[:, 1:, :].reshape(
|
1044 |
+
batch_size, width // patch_size, height // patch_size, -1
|
1045 |
+
)
|
1046 |
+
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
1047 |
+
feature_maps += (hidden_state,)
|
1048 |
+
|
1049 |
+
if not return_dict:
|
1050 |
+
if output_hidden_states:
|
1051 |
+
output = (feature_maps,) + outputs[1:]
|
1052 |
+
else:
|
1053 |
+
output = (feature_maps,) + outputs[2:]
|
1054 |
+
return output
|
1055 |
+
|
1056 |
+
return BackboneOutput(
|
1057 |
+
feature_maps=feature_maps,
|
1058 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1059 |
+
attentions=outputs.attentions if output_attentions else None,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
|
1063 |
+
class CustomPatchEmbeddings(nn.Module):
|
1064 |
+
"""
|
1065 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
1066 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
1067 |
+
Transformer.
|
1068 |
+
"""
|
1069 |
+
|
1070 |
+
def __init__(
|
1071 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1072 |
+
):
|
1073 |
+
super().__init__()
|
1074 |
+
|
1075 |
+
image_size = (
|
1076 |
+
image_size
|
1077 |
+
if isinstance(image_size, collections.abc.Iterable)
|
1078 |
+
else (image_size, image_size)
|
1079 |
+
)
|
1080 |
+
patch_size = (
|
1081 |
+
patch_size
|
1082 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
1083 |
+
else (patch_size, patch_size)
|
1084 |
+
)
|
1085 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
1086 |
+
image_size[0] // patch_size[0]
|
1087 |
+
)
|
1088 |
+
self.image_size = image_size
|
1089 |
+
self.patch_size = patch_size
|
1090 |
+
self.num_channels = num_channels
|
1091 |
+
self.num_patches = num_patches
|
1092 |
+
|
1093 |
+
self.projection = nn.Conv2d(
|
1094 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
1098 |
+
num_channels = pixel_values.shape[1]
|
1099 |
+
if num_channels != self.num_channels:
|
1100 |
+
raise ValueError(
|
1101 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
1102 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
1103 |
+
)
|
1104 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
1105 |
+
return embeddings
|
1106 |
+
|
1107 |
+
|
1108 |
+
class CustomEmbeddings(nn.Module):
|
1109 |
+
"""
|
1110 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
1111 |
+
"""
|
1112 |
+
|
1113 |
+
def __init__(
|
1114 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1115 |
+
) -> None:
|
1116 |
+
super().__init__()
|
1117 |
+
|
1118 |
+
self.image_size = image_size
|
1119 |
+
self.patch_size = patch_size
|
1120 |
+
self.num_channels = num_channels
|
1121 |
+
self.hidden_size = hidden_size
|
1122 |
+
|
1123 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
1124 |
+
|
1125 |
+
self.patch_embeddings = CustomPatchEmbeddings(
|
1126 |
+
image_size, patch_size, num_channels, hidden_size
|
1127 |
+
)
|
1128 |
+
num_patches = self.patch_embeddings.num_patches
|
1129 |
+
self.position_embeddings = nn.Parameter(
|
1130 |
+
torch.randn(1, num_patches + 1, self.hidden_size)
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
def interpolate_pos_encoding(
|
1134 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
1135 |
+
) -> torch.Tensor:
|
1136 |
+
"""
|
1137 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
1138 |
+
resolution images.
|
1139 |
+
|
1140 |
+
Source:
|
1141 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
1142 |
+
"""
|
1143 |
+
|
1144 |
+
num_patches = embeddings.shape[1] - 1
|
1145 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
1146 |
+
if num_patches == num_positions and height == width:
|
1147 |
+
return self.position_embeddings
|
1148 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
1149 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
1150 |
+
dim = embeddings.shape[-1]
|
1151 |
+
height = height // self.patch_size
|
1152 |
+
width = width // self.patch_size
|
1153 |
+
# we add a small number to avoid floating point error in the interpolation
|
1154 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
1155 |
+
height, width = height + 0.1, width + 0.1
|
1156 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
1157 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
1158 |
+
)
|
1159 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
1160 |
+
patch_pos_embed = nn.functional.interpolate(
|
1161 |
+
patch_pos_embed,
|
1162 |
+
scale_factor=(
|
1163 |
+
height / math.sqrt(num_positions),
|
1164 |
+
width / math.sqrt(num_positions),
|
1165 |
+
),
|
1166 |
+
mode="bicubic",
|
1167 |
+
align_corners=False,
|
1168 |
+
)
|
1169 |
+
if (
|
1170 |
+
int(height) != patch_pos_embed.shape[-2]
|
1171 |
+
or int(width) != patch_pos_embed.shape[-1]
|
1172 |
+
):
|
1173 |
+
raise ValueError(
|
1174 |
+
"Width or height does not match with the interpolated position embeddings"
|
1175 |
+
)
|
1176 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
1177 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
1178 |
+
|
1179 |
+
def forward(
|
1180 |
+
self,
|
1181 |
+
pixel_values: torch.Tensor,
|
1182 |
+
) -> torch.Tensor:
|
1183 |
+
batch_size, _, height, width = pixel_values.shape
|
1184 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
1185 |
+
embeddings = patch_embeddings
|
1186 |
+
|
1187 |
+
# add the [CLS] token to the embedded patch tokens
|
1188 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
1189 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
1190 |
+
|
1191 |
+
# add positional encoding to each token
|
1192 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
1193 |
+
embeddings, height, width
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
return embeddings
|
sf3d/models/tokenizers/image.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.models.tokenizers.dinov2 import Dinov2Model
|
11 |
+
from sf3d.models.transformers.attention import Modulation
|
12 |
+
from sf3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
|
15 |
+
class DINOV2SingleImageTokenizer(BaseModule):
|
16 |
+
@dataclass
|
17 |
+
class Config(BaseModule.Config):
|
18 |
+
pretrained_model_name_or_path: str = "facebook/dinov2-large"
|
19 |
+
width: int = 512
|
20 |
+
height: int = 512
|
21 |
+
modulation_cond_dim: int = 768
|
22 |
+
|
23 |
+
cfg: Config
|
24 |
+
|
25 |
+
def configure(self) -> None:
|
26 |
+
self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
|
27 |
+
|
28 |
+
for p in self.model.parameters():
|
29 |
+
p.requires_grad_(False)
|
30 |
+
self.model.eval()
|
31 |
+
|
32 |
+
self.model.set_gradient_checkpointing(False)
|
33 |
+
|
34 |
+
# add modulation
|
35 |
+
modulations = []
|
36 |
+
for layer in self.model.encoder.layer:
|
37 |
+
norm1_modulation = Modulation(
|
38 |
+
self.model.config.hidden_size,
|
39 |
+
self.cfg.modulation_cond_dim,
|
40 |
+
zero_init=True,
|
41 |
+
single_layer=True,
|
42 |
+
)
|
43 |
+
norm2_modulation = Modulation(
|
44 |
+
self.model.config.hidden_size,
|
45 |
+
self.cfg.modulation_cond_dim,
|
46 |
+
zero_init=True,
|
47 |
+
single_layer=True,
|
48 |
+
)
|
49 |
+
layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
|
50 |
+
modulations += [norm1_modulation, norm2_modulation]
|
51 |
+
self.modulations = nn.ModuleList(modulations)
|
52 |
+
|
53 |
+
self.register_buffer(
|
54 |
+
"image_mean",
|
55 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
56 |
+
persistent=False,
|
57 |
+
)
|
58 |
+
self.register_buffer(
|
59 |
+
"image_std",
|
60 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
61 |
+
persistent=False,
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
images: Float[Tensor, "B *N C H W"],
|
67 |
+
modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
|
68 |
+
**kwargs,
|
69 |
+
) -> Float[Tensor, "B *N Ct Nt"]:
|
70 |
+
model = self.model
|
71 |
+
|
72 |
+
packed = False
|
73 |
+
if images.ndim == 4:
|
74 |
+
packed = True
|
75 |
+
images = images.unsqueeze(1)
|
76 |
+
if modulation_cond is not None:
|
77 |
+
assert modulation_cond.ndim == 2
|
78 |
+
modulation_cond = modulation_cond.unsqueeze(1)
|
79 |
+
|
80 |
+
batch_size, n_input_views = images.shape[:2]
|
81 |
+
images = (images - self.image_mean) / self.image_std
|
82 |
+
out = model(
|
83 |
+
rearrange(images, "B N C H W -> (B N) C H W"),
|
84 |
+
modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
|
85 |
+
if modulation_cond is not None
|
86 |
+
else None,
|
87 |
+
)
|
88 |
+
local_features = out.last_hidden_state
|
89 |
+
local_features = local_features.permute(0, 2, 1)
|
90 |
+
local_features = rearrange(
|
91 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
92 |
+
)
|
93 |
+
if packed:
|
94 |
+
local_features = local_features.squeeze(1)
|
95 |
+
|
96 |
+
return local_features
|
97 |
+
|
98 |
+
def detokenize(self, *args, **kwargs):
|
99 |
+
raise NotImplementedError
|
sf3d/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.models.utils import BaseModule
|
11 |
+
|
12 |
+
|
13 |
+
class TriplaneLearnablePositionalEmbedding(BaseModule):
|
14 |
+
@dataclass
|
15 |
+
class Config(BaseModule.Config):
|
16 |
+
plane_size: int = 96
|
17 |
+
num_channels: int = 1024
|
18 |
+
|
19 |
+
cfg: Config
|
20 |
+
|
21 |
+
def configure(self) -> None:
|
22 |
+
self.embeddings = nn.Parameter(
|
23 |
+
torch.randn(
|
24 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
25 |
+
dtype=torch.float32,
|
26 |
+
)
|
27 |
+
* 1
|
28 |
+
/ math.sqrt(self.cfg.num_channels)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
|
32 |
+
return rearrange(
|
33 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
34 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
35 |
+
)
|
36 |
+
|
37 |
+
def detokenize(
|
38 |
+
self, tokens: Float[Tensor, "B Ct Nt"]
|
39 |
+
) -> Float[Tensor, "B 3 Ct Hp Wp"]:
|
40 |
+
batch_size, Ct, Nt = tokens.shape
|
41 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
42 |
+
assert Ct == self.cfg.num_channels
|
43 |
+
return rearrange(
|
44 |
+
tokens,
|
45 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
46 |
+
Np=3,
|
47 |
+
Hp=self.cfg.plane_size,
|
48 |
+
Wp=self.cfg.plane_size,
|
49 |
+
)
|
sf3d/models/transformers/attention.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class Modulation(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
embedding_dim: int,
|
9 |
+
condition_dim: int,
|
10 |
+
zero_init: bool = False,
|
11 |
+
single_layer: bool = False,
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.silu = nn.SiLU()
|
15 |
+
if single_layer:
|
16 |
+
self.linear1 = nn.Identity()
|
17 |
+
else:
|
18 |
+
self.linear1 = nn.Linear(condition_dim, condition_dim)
|
19 |
+
|
20 |
+
self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
|
21 |
+
|
22 |
+
# Only zero init the last linear layer
|
23 |
+
if zero_init:
|
24 |
+
nn.init.zeros_(self.linear2.weight)
|
25 |
+
nn.init.zeros_(self.linear2.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
|
28 |
+
emb = self.linear2(self.silu(self.linear1(condition)))
|
29 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
30 |
+
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
31 |
+
return x
|
sf3d/models/transformers/backbone.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from sf3d.models.utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class GEGLU(nn.Module):
|
12 |
+
r"""
|
13 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
dim_in (`int`): The number of channels in the input.
|
17 |
+
dim_out (`int`): The number of channels in the output.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, dim_in: int, dim_out: int):
|
21 |
+
super().__init__()
|
22 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
23 |
+
|
24 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
25 |
+
if gate.device.type != "mps":
|
26 |
+
return F.gelu(gate)
|
27 |
+
# mps: gelu is not implemented for float16
|
28 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
29 |
+
|
30 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
31 |
+
args = ()
|
32 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
33 |
+
return hidden_states * self.gelu(gate)
|
34 |
+
|
35 |
+
|
36 |
+
class CrossAttention(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim,
|
40 |
+
kv_dim=None,
|
41 |
+
num_heads=16,
|
42 |
+
qkv_bias=False,
|
43 |
+
attn_drop=0.0,
|
44 |
+
proj_drop=0.0,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
self.scale = head_dim**-0.5
|
50 |
+
kv_dim = dim if not kv_dim else kv_dim
|
51 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
52 |
+
self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
53 |
+
self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
54 |
+
self.attn_drop = attn_drop
|
55 |
+
self.proj = nn.Linear(dim, dim)
|
56 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
57 |
+
|
58 |
+
def forward(self, x_q, x_kv):
|
59 |
+
B, N_q, C = x_q.shape
|
60 |
+
B, N_kv, _ = x_kv.shape
|
61 |
+
# [B, N_q, C] -> [B, N_q, H, C/H]
|
62 |
+
q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
|
63 |
+
# [B, N_kv, C] -> [B, N_kv, H, C/H]
|
64 |
+
k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
65 |
+
v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
66 |
+
|
67 |
+
# attention
|
68 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
69 |
+
q.permute(0, 2, 1, 3),
|
70 |
+
k.permute(0, 2, 1, 3),
|
71 |
+
v.permute(0, 2, 1, 3),
|
72 |
+
attn_mask=None,
|
73 |
+
dropout_p=self.attn_drop,
|
74 |
+
scale=self.scale,
|
75 |
+
).permute(0, 2, 1, 3)
|
76 |
+
|
77 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
78 |
+
x = x.reshape(B, N_q, C)
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
84 |
+
class FeedForward(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
dim: int,
|
88 |
+
dim_out: Optional[int] = None,
|
89 |
+
mult: int = 4,
|
90 |
+
dropout: float = 0.0,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
inner_dim = int(dim * mult)
|
94 |
+
dim_out = dim_out if dim_out is not None else dim
|
95 |
+
act_fn = GEGLU(dim, inner_dim)
|
96 |
+
self.net = nn.ModuleList([])
|
97 |
+
self.net.append(act_fn)
|
98 |
+
self.net.append(nn.Dropout(dropout))
|
99 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
100 |
+
|
101 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
102 |
+
for module in self.net:
|
103 |
+
x = module(x)
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class BasicBlock(nn.Module):
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
dim: int,
|
111 |
+
kv_dim: Optional[int] = None,
|
112 |
+
num_heads: int = 16,
|
113 |
+
qkv_bias: bool = False,
|
114 |
+
attn_drop: float = 0.0,
|
115 |
+
proj_drop: float = 0.0,
|
116 |
+
ff_drop: float = 0.0,
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.norm1 = nn.LayerNorm(dim)
|
120 |
+
self.attn1 = CrossAttention(
|
121 |
+
dim,
|
122 |
+
kv_dim=dim,
|
123 |
+
num_heads=num_heads,
|
124 |
+
qkv_bias=qkv_bias,
|
125 |
+
attn_drop=attn_drop,
|
126 |
+
proj_drop=proj_drop,
|
127 |
+
)
|
128 |
+
self.norm2 = nn.LayerNorm(dim)
|
129 |
+
self.attn2 = CrossAttention(
|
130 |
+
dim,
|
131 |
+
kv_dim=kv_dim,
|
132 |
+
num_heads=num_heads,
|
133 |
+
qkv_bias=qkv_bias,
|
134 |
+
attn_drop=attn_drop,
|
135 |
+
proj_drop=proj_drop,
|
136 |
+
)
|
137 |
+
self.norm3 = nn.LayerNorm(dim)
|
138 |
+
self.ff = FeedForward(dim, dropout=ff_drop)
|
139 |
+
|
140 |
+
def forward(self, z, x):
|
141 |
+
z_norm = self.norm1(z)
|
142 |
+
z = z + self.attn1(z_norm, z_norm)
|
143 |
+
# TODO: do we need to have the second attention when x is None?
|
144 |
+
z_norm = self.norm2(z)
|
145 |
+
z = z + self.attn2(z_norm, x if x is not None else z_norm)
|
146 |
+
z_norm = self.norm3(z)
|
147 |
+
z = z + self.ff(z_norm)
|
148 |
+
return z
|
149 |
+
|
150 |
+
|
151 |
+
class SingleStreamTransformer(BaseModule):
|
152 |
+
@dataclass
|
153 |
+
class Config(BaseModule.Config):
|
154 |
+
num_attention_heads: int = 16
|
155 |
+
attention_head_dim: int = 88
|
156 |
+
in_channels: Optional[int] = None
|
157 |
+
out_channels: Optional[int] = None
|
158 |
+
num_layers: int = 16
|
159 |
+
dropout: float = 0.0
|
160 |
+
norm_num_groups: int = 32
|
161 |
+
cross_attention_dim: Optional[int] = None
|
162 |
+
attention_bias: bool = False
|
163 |
+
|
164 |
+
cfg: Config
|
165 |
+
|
166 |
+
def configure(self) -> None:
|
167 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
168 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
169 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
170 |
+
|
171 |
+
# Define input layers
|
172 |
+
self.norm = torch.nn.GroupNorm(
|
173 |
+
num_groups=self.cfg.norm_num_groups,
|
174 |
+
num_channels=self.cfg.in_channels,
|
175 |
+
eps=1e-6,
|
176 |
+
affine=True,
|
177 |
+
)
|
178 |
+
self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)
|
179 |
+
|
180 |
+
# Define transformers blocks
|
181 |
+
self.transformer_blocks = nn.ModuleList(
|
182 |
+
[
|
183 |
+
BasicBlock(
|
184 |
+
inner_dim,
|
185 |
+
kv_dim=self.cfg.cross_attention_dim,
|
186 |
+
num_heads=self.num_attention_heads,
|
187 |
+
qkv_bias=self.cfg.attention_bias,
|
188 |
+
proj_drop=self.cfg.dropout,
|
189 |
+
ff_drop=self.cfg.dropout,
|
190 |
+
)
|
191 |
+
for d in range(self.cfg.num_layers)
|
192 |
+
]
|
193 |
+
)
|
194 |
+
|
195 |
+
# 4. Define output layers
|
196 |
+
self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)
|
197 |
+
|
198 |
+
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
199 |
+
residual = hidden_states
|
200 |
+
hidden_states = self.norm(hidden_states)
|
201 |
+
hidden_states = hidden_states.permute(0, 2, 1)
|
202 |
+
hidden_states = self.proj_in(hidden_states)
|
203 |
+
for block in self.transformer_blocks:
|
204 |
+
hidden_states = block(hidden_states, encoder_hidden_states)
|
205 |
+
hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
|
206 |
+
# TODO: do we really need to add the residual?
|
207 |
+
hidden_states = hidden_states + residual
|
208 |
+
return hidden_states
|
209 |
+
|
210 |
+
|
211 |
+
class FuseBlock(nn.Module):
|
212 |
+
"""
|
213 |
+
Fuse X in to Z with cross attention
|
214 |
+
"""
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
dim_z: int,
|
219 |
+
dim_x: int,
|
220 |
+
num_heads: int = 16,
|
221 |
+
qkv_bias: bool = False,
|
222 |
+
attn_drop: float = 0.0,
|
223 |
+
proj_drop: float = 0.0,
|
224 |
+
ff_drop: float = 0.0,
|
225 |
+
norm_x_input: bool = True,
|
226 |
+
):
|
227 |
+
super().__init__()
|
228 |
+
self.norm_x_input = norm_x_input
|
229 |
+
if self.norm_x_input:
|
230 |
+
self.norm_x = nn.LayerNorm(dim_x)
|
231 |
+
self.attn = CrossAttention(
|
232 |
+
dim_z,
|
233 |
+
kv_dim=dim_x,
|
234 |
+
num_heads=num_heads,
|
235 |
+
qkv_bias=qkv_bias,
|
236 |
+
attn_drop=attn_drop,
|
237 |
+
proj_drop=proj_drop,
|
238 |
+
)
|
239 |
+
self.norm_z1 = nn.LayerNorm(dim_z)
|
240 |
+
self.norm_z2 = nn.LayerNorm(dim_z)
|
241 |
+
self.ff = FeedForward(dim_z, dropout=ff_drop)
|
242 |
+
|
243 |
+
def forward(self, z, x):
|
244 |
+
# TODO: do we need to normalize x?
|
245 |
+
z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
|
246 |
+
z = z + self.ff(self.norm_z2(z))
|
247 |
+
return z
|
248 |
+
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def get_triplane_attention_mask(res):
|
252 |
+
N = 3 * res * res
|
253 |
+
attn_mask = torch.zeros(3, res, res, 3, res, res)
|
254 |
+
|
255 |
+
i, j = torch.meshgrid(torch.arange(res), torch.arange(res))
|
256 |
+
|
257 |
+
attn_mask[0, i, j, 1, i, :] = 1.0
|
258 |
+
attn_mask[0, i, j, 2, j, :] = 1.0
|
259 |
+
attn_mask[1, i, j, 0, i, :] = 1.0
|
260 |
+
attn_mask[1, i, j, 2, :, j] = 1.0
|
261 |
+
attn_mask[2, i, j, 0, :, i] = 1.0
|
262 |
+
attn_mask[2, i, j, 1, :, j] = 1.0
|
263 |
+
attn_mask = attn_mask.bool()
|
264 |
+
|
265 |
+
attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
|
266 |
+
attn_bias.masked_fill_(attn_mask, 0.0)
|
267 |
+
attn_bias.masked_fill_(~attn_mask, float("-inf"))
|
268 |
+
|
269 |
+
return attn_bias.reshape(N, N)
|
270 |
+
|
271 |
+
|
272 |
+
class TriplaneAttention(nn.Module):
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
dim: int,
|
276 |
+
resolution: int,
|
277 |
+
num_heads: int = 16,
|
278 |
+
qkv_bias: bool = False,
|
279 |
+
attn_drop: float = 0.0,
|
280 |
+
proj_drop: float = 0.0,
|
281 |
+
full_attention: bool = False,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
self.num_heads = num_heads
|
285 |
+
head_dim = dim // num_heads
|
286 |
+
self.scale = head_dim**-0.5
|
287 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
288 |
+
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
289 |
+
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
|
290 |
+
self.attn_drop = attn_drop
|
291 |
+
self.proj = nn.Linear(dim, dim)
|
292 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
293 |
+
|
294 |
+
self.resolution = resolution
|
295 |
+
self.full_attention = full_attention
|
296 |
+
self.attn_mask = (
|
297 |
+
get_triplane_attention_mask(resolution) if not full_attention else None
|
298 |
+
)
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
B, N, C = x.shape
|
302 |
+
# [B, N, C] -> [B, N, H, C/H]
|
303 |
+
q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
304 |
+
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
305 |
+
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
306 |
+
|
307 |
+
# detokenize the planes
|
308 |
+
assert N == self.resolution**2 * 3
|
309 |
+
attn_bias = (
|
310 |
+
self.attn_mask.to(q)
|
311 |
+
.unsqueeze(0)
|
312 |
+
.unsqueeze(0)
|
313 |
+
.expand(B, self.num_heads, -1, -1)
|
314 |
+
if not self.full_attention
|
315 |
+
else None
|
316 |
+
)
|
317 |
+
|
318 |
+
# full attention
|
319 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
320 |
+
q.permute(0, 2, 1, 3),
|
321 |
+
k.permute(0, 2, 1, 3),
|
322 |
+
v.permute(0, 2, 1, 3),
|
323 |
+
attn_mask=attn_bias,
|
324 |
+
dropout_p=self.attn_drop,
|
325 |
+
scale=self.scale,
|
326 |
+
).permute(0, 2, 1, 3)
|
327 |
+
|
328 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
329 |
+
x = x.reshape(B, N, C)
|
330 |
+
x = self.proj(x)
|
331 |
+
x = self.proj_drop(x)
|
332 |
+
return x
|
333 |
+
|
334 |
+
|
335 |
+
class TwoStreamBlock(nn.Module):
|
336 |
+
def __init__(
|
337 |
+
self,
|
338 |
+
dim_latent: int,
|
339 |
+
dim_input: int,
|
340 |
+
num_basic_blocks: int = 4,
|
341 |
+
num_heads: int = 16,
|
342 |
+
qkv_bias: bool = False,
|
343 |
+
attn_drop: float = 0.0,
|
344 |
+
proj_drop: float = 0.0,
|
345 |
+
ff_drop: float = 0.0,
|
346 |
+
norm_x_input: bool = True,
|
347 |
+
dim_cross: Optional[int] = None,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
|
351 |
+
# Define the fuse block that fuse the input into the latent
|
352 |
+
self.fuse_block_in = FuseBlock(
|
353 |
+
dim_latent,
|
354 |
+
dim_input,
|
355 |
+
num_heads=num_heads,
|
356 |
+
qkv_bias=qkv_bias,
|
357 |
+
attn_drop=attn_drop,
|
358 |
+
proj_drop=proj_drop,
|
359 |
+
ff_drop=ff_drop,
|
360 |
+
norm_x_input=norm_x_input,
|
361 |
+
)
|
362 |
+
|
363 |
+
# Define the transformer block that process the latent
|
364 |
+
self.transformer_block = nn.ModuleList(
|
365 |
+
[
|
366 |
+
BasicBlock(
|
367 |
+
dim_latent,
|
368 |
+
kv_dim=dim_cross,
|
369 |
+
num_heads=num_heads,
|
370 |
+
qkv_bias=qkv_bias,
|
371 |
+
proj_drop=proj_drop,
|
372 |
+
ff_drop=ff_drop,
|
373 |
+
)
|
374 |
+
for _ in range(num_basic_blocks)
|
375 |
+
]
|
376 |
+
)
|
377 |
+
|
378 |
+
# Define the fuse block that fuse the latent into the input
|
379 |
+
self.fuse_block_out = FuseBlock(
|
380 |
+
dim_input,
|
381 |
+
dim_latent,
|
382 |
+
num_heads=num_heads,
|
383 |
+
qkv_bias=qkv_bias,
|
384 |
+
attn_drop=attn_drop,
|
385 |
+
proj_drop=proj_drop,
|
386 |
+
ff_drop=ff_drop,
|
387 |
+
norm_x_input=norm_x_input,
|
388 |
+
)
|
389 |
+
|
390 |
+
def forward(self, latent, input, cross_input):
|
391 |
+
latent = self.fuse_block_in(latent, input)
|
392 |
+
for block in self.transformer_block:
|
393 |
+
latent = block(latent, cross_input)
|
394 |
+
input = self.fuse_block_out(input, latent)
|
395 |
+
return latent, input
|
396 |
+
|
397 |
+
|
398 |
+
class TwoStreamInterleaveTransformer(BaseModule):
|
399 |
+
@dataclass
|
400 |
+
class Config(BaseModule.Config):
|
401 |
+
num_attention_heads: int = 16
|
402 |
+
attention_head_dim: int = 64
|
403 |
+
raw_triplane_channels: int = 1024
|
404 |
+
triplane_channels: int = 1024
|
405 |
+
raw_image_channels: int = 1024
|
406 |
+
num_latents: int = 1792
|
407 |
+
num_blocks: int = 4
|
408 |
+
num_basic_blocks: int = 3
|
409 |
+
dropout: float = 0.0
|
410 |
+
latent_init_std: float = 0.02
|
411 |
+
norm_num_groups: int = 32
|
412 |
+
attention_bias: bool = False
|
413 |
+
norm_x_input: bool = False
|
414 |
+
cross_attention_dim: int = 1024
|
415 |
+
mix_latent: bool = True
|
416 |
+
|
417 |
+
cfg: Config
|
418 |
+
|
419 |
+
def configure(self) -> None:
|
420 |
+
self.mix_latent = self.cfg.mix_latent
|
421 |
+
|
422 |
+
# Define the dimensions
|
423 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
424 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
425 |
+
self.num_latents = self.cfg.num_latents
|
426 |
+
self.latent_dim = self.num_attention_heads * self.attention_head_dim
|
427 |
+
|
428 |
+
# Define input layers
|
429 |
+
if self.cfg.norm_num_groups > 0:
|
430 |
+
self.norm_triplane = torch.nn.GroupNorm(
|
431 |
+
num_groups=self.cfg.norm_num_groups,
|
432 |
+
num_channels=self.cfg.raw_triplane_channels,
|
433 |
+
eps=1e-6,
|
434 |
+
affine=True,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
|
438 |
+
self.proj_triplane = nn.Linear(
|
439 |
+
self.cfg.raw_triplane_channels, self.cfg.triplane_channels
|
440 |
+
)
|
441 |
+
if self.mix_latent:
|
442 |
+
self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
|
443 |
+
self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
|
444 |
+
self.norm_latent = nn.LayerNorm(self.latent_dim)
|
445 |
+
self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)
|
446 |
+
|
447 |
+
# Define the latents
|
448 |
+
self.latent_init = nn.Parameter(
|
449 |
+
torch.zeros(1, self.num_latents, self.latent_dim)
|
450 |
+
)
|
451 |
+
nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)
|
452 |
+
|
453 |
+
# Define the transformer blocks
|
454 |
+
self.main_blocks = nn.ModuleList(
|
455 |
+
[
|
456 |
+
TwoStreamBlock(
|
457 |
+
self.latent_dim,
|
458 |
+
self.cfg.triplane_channels,
|
459 |
+
num_basic_blocks=self.cfg.num_basic_blocks,
|
460 |
+
num_heads=self.num_attention_heads,
|
461 |
+
qkv_bias=self.cfg.attention_bias,
|
462 |
+
proj_drop=self.cfg.dropout,
|
463 |
+
ff_drop=self.cfg.dropout,
|
464 |
+
norm_x_input=self.cfg.norm_x_input,
|
465 |
+
dim_cross=self.cfg.cross_attention_dim,
|
466 |
+
)
|
467 |
+
for _ in range(self.cfg.num_blocks)
|
468 |
+
]
|
469 |
+
)
|
470 |
+
|
471 |
+
# 4. Define output layers
|
472 |
+
self.proj_out = nn.Linear(
|
473 |
+
self.cfg.triplane_channels, self.cfg.raw_triplane_channels
|
474 |
+
)
|
475 |
+
|
476 |
+
def forward(self, hidden_states, encoder_hidden_states, **kwargs):
|
477 |
+
# hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
|
478 |
+
# encoder_hidden_states: [B, N_image, image_dim] is the image tokens
|
479 |
+
if isinstance(self.norm_triplane, nn.GroupNorm):
|
480 |
+
triplane_tokens = self.norm_triplane(hidden_states)
|
481 |
+
triplane_tokens = triplane_tokens.permute(
|
482 |
+
0, 2, 1
|
483 |
+
) # [B, N_triplane, triplane_dim]
|
484 |
+
elif isinstance(self.norm_triplane, nn.LayerNorm):
|
485 |
+
triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
|
486 |
+
else:
|
487 |
+
raise ValueError("Unknown normalization layer")
|
488 |
+
triplane_tokens = self.proj_triplane(triplane_tokens)
|
489 |
+
if self.mix_latent:
|
490 |
+
image_tokens = self.norm_image(
|
491 |
+
encoder_hidden_states
|
492 |
+
) # [B, N_image, image_dim]
|
493 |
+
image_tokens = self.proj_image(image_tokens)
|
494 |
+
init_latents = self.latent_init.expand(
|
495 |
+
hidden_states.shape[0], -1, -1
|
496 |
+
) # [B, N_latent_init, latent_dim]
|
497 |
+
init_latents = self.norm_latent(init_latents)
|
498 |
+
init_latents = self.proj_latent(init_latents)
|
499 |
+
if self.mix_latent:
|
500 |
+
latent_tokens = torch.cat(
|
501 |
+
[image_tokens, init_latents], dim=1
|
502 |
+
) # [B, N_latent, latent_dim]
|
503 |
+
else:
|
504 |
+
latent_tokens = init_latents
|
505 |
+
|
506 |
+
# forward the main blocks
|
507 |
+
for block in self.main_blocks:
|
508 |
+
latent_tokens, triplane_tokens = block(
|
509 |
+
latent_tokens, triplane_tokens, encoder_hidden_states
|
510 |
+
)
|
511 |
+
|
512 |
+
# project the triplane tokens back to the original dimension
|
513 |
+
triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
|
514 |
+
triplane_tokens = triplane_tokens + hidden_states
|
515 |
+
return triplane_tokens
|
sf3d/models/utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import importlib
|
3 |
+
import math
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from jaxtyping import Bool, Float, Int, Num
|
13 |
+
from omegaconf import DictConfig, OmegaConf
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
|
17 |
+
class BaseModule(nn.Module):
|
18 |
+
@dataclass
|
19 |
+
class Config:
|
20 |
+
pass
|
21 |
+
|
22 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
26 |
+
) -> None:
|
27 |
+
super().__init__()
|
28 |
+
self.cfg = parse_structured(self.Config, cfg)
|
29 |
+
self.configure(*args, **kwargs)
|
30 |
+
|
31 |
+
def configure(self, *args, **kwargs) -> None:
|
32 |
+
raise NotImplementedError
|
33 |
+
|
34 |
+
|
35 |
+
def find_class(cls_string):
|
36 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
37 |
+
cls_name = cls_string.split(".")[-1]
|
38 |
+
module = importlib.import_module(module_string, package=None)
|
39 |
+
cls = getattr(module, cls_name)
|
40 |
+
return cls
|
41 |
+
|
42 |
+
|
43 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
44 |
+
# Check if cfg.keys are in fields
|
45 |
+
cfg_ = cfg.copy()
|
46 |
+
keys = list(cfg_.keys())
|
47 |
+
|
48 |
+
field_names = {f.name for f in dataclasses.fields(fields)}
|
49 |
+
for key in keys:
|
50 |
+
# This is helpful when swapping out modules from CLI
|
51 |
+
if key not in field_names:
|
52 |
+
print(f"Ignoring {key} as it's not supported by {fields}")
|
53 |
+
cfg_.pop(key)
|
54 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
|
55 |
+
return scfg
|
56 |
+
|
57 |
+
|
58 |
+
EPS_DTYPE = {
|
59 |
+
torch.float16: 1e-4,
|
60 |
+
torch.bfloat16: 1e-4,
|
61 |
+
torch.float32: 1e-7,
|
62 |
+
torch.float64: 1e-8,
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def dot(x, y, dim=-1):
|
67 |
+
return torch.sum(x * y, dim, keepdim=True)
|
68 |
+
|
69 |
+
|
70 |
+
def reflect(x, n):
|
71 |
+
return x - 2 * dot(x, n) * n
|
72 |
+
|
73 |
+
|
74 |
+
def normalize(x, dim=-1, eps=None):
|
75 |
+
if eps is None:
|
76 |
+
eps = EPS_DTYPE[x.dtype]
|
77 |
+
return F.normalize(x, dim=dim, p=2, eps=eps)
|
78 |
+
|
79 |
+
|
80 |
+
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
|
81 |
+
# One pad for determinant
|
82 |
+
tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
|
83 |
+
det_tri = torch.det(tri_sq)
|
84 |
+
tri_rev = torch.cat(
|
85 |
+
(tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
|
86 |
+
)
|
87 |
+
tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
|
88 |
+
return tri_sq
|
89 |
+
|
90 |
+
|
91 |
+
def triangle_intersection_2d(
|
92 |
+
t1: Float[Tensor, "*B 3 2"],
|
93 |
+
t2: Float[Tensor, "*B 3 2"],
|
94 |
+
eps=1e-12,
|
95 |
+
) -> Float[Tensor, "*B"]: # noqa: F821
|
96 |
+
"""Returns True if triangles collide, False otherwise"""
|
97 |
+
|
98 |
+
def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
|
99 |
+
logdetx = torch.logdet(x.double())
|
100 |
+
if eps is None:
|
101 |
+
return ~torch.isfinite(logdetx)
|
102 |
+
return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
|
103 |
+
|
104 |
+
t1s = tri_winding(t1)
|
105 |
+
t2s = tri_winding(t2)
|
106 |
+
|
107 |
+
# Assume the triangles do not collide in the begging
|
108 |
+
ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
|
109 |
+
for i in range(3):
|
110 |
+
edge = torch.roll(t1s, i, dims=1)[:, :2, :]
|
111 |
+
# Check if all points of triangle 2 lay on the external side of edge E.
|
112 |
+
# If this is the case the triangle do not collide
|
113 |
+
upd = (
|
114 |
+
chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
|
115 |
+
& chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
|
116 |
+
& chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
|
117 |
+
)
|
118 |
+
# Here no collision is still True due to inversion
|
119 |
+
ret = ret | upd
|
120 |
+
|
121 |
+
for i in range(3):
|
122 |
+
edge = torch.roll(t2s, i, dims=1)[:, :2, :]
|
123 |
+
|
124 |
+
upd = (
|
125 |
+
chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
|
126 |
+
& chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
|
127 |
+
& chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
|
128 |
+
)
|
129 |
+
# Here no collision is still True due to inversion
|
130 |
+
ret = ret | upd
|
131 |
+
|
132 |
+
return ~ret # Do the inversion
|
133 |
+
|
134 |
+
|
135 |
+
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
136 |
+
|
137 |
+
|
138 |
+
def scale_tensor(
|
139 |
+
dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
|
140 |
+
):
|
141 |
+
if inp_scale is None:
|
142 |
+
inp_scale = (0, 1)
|
143 |
+
if tgt_scale is None:
|
144 |
+
tgt_scale = (0, 1)
|
145 |
+
if isinstance(tgt_scale, Tensor):
|
146 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
147 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
148 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
149 |
+
return dat
|
150 |
+
|
151 |
+
|
152 |
+
def dilate_fill(img, mask, iterations=10):
|
153 |
+
oldMask = mask.float()
|
154 |
+
oldImg = img
|
155 |
+
|
156 |
+
mask_kernel = torch.ones(
|
157 |
+
(1, 1, 3, 3),
|
158 |
+
dtype=oldMask.dtype,
|
159 |
+
device=oldMask.device,
|
160 |
+
)
|
161 |
+
|
162 |
+
for i in range(iterations):
|
163 |
+
newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)
|
164 |
+
|
165 |
+
# Fill the extension with mean color of old valid regions
|
166 |
+
img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
|
167 |
+
mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
168 |
+
new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
169 |
+
|
170 |
+
# Average color of the valid region
|
171 |
+
mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
|
172 |
+
2
|
173 |
+
)
|
174 |
+
# Extend it to the new region
|
175 |
+
fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)
|
176 |
+
|
177 |
+
mask_conv = F.conv2d(
|
178 |
+
newMask, mask_kernel, padding=1
|
179 |
+
) # Get the sum for each kernel patch
|
180 |
+
newImg = F.fold(
|
181 |
+
fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
|
182 |
+
) / mask_conv.clamp(1)
|
183 |
+
|
184 |
+
diffMask = newMask - oldMask
|
185 |
+
|
186 |
+
oldMask = newMask
|
187 |
+
oldImg = torch.lerp(oldImg, newImg, diffMask)
|
188 |
+
|
189 |
+
return oldImg
|
190 |
+
|
191 |
+
|
192 |
+
def float32_to_uint8_np(
|
193 |
+
x: Float[np.ndarray, "*B H W C"],
|
194 |
+
dither: bool = True,
|
195 |
+
dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
|
196 |
+
dither_strength: float = 1.0,
|
197 |
+
) -> Int[np.ndarray, "*B H W C"]:
|
198 |
+
if dither:
|
199 |
+
dither = (
|
200 |
+
dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
|
201 |
+
)
|
202 |
+
if dither_mask is not None:
|
203 |
+
dither = dither * dither_mask
|
204 |
+
return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
|
205 |
+
return np.clip(np.floor((256.0 * x)), 0, 255).astype(torch.uint8)
|
206 |
+
|
207 |
+
|
208 |
+
def convert_data(data):
|
209 |
+
if data is None:
|
210 |
+
return None
|
211 |
+
elif isinstance(data, np.ndarray):
|
212 |
+
return data
|
213 |
+
elif isinstance(data, torch.Tensor):
|
214 |
+
if data.dtype in [torch.float16, torch.bfloat16]:
|
215 |
+
data = data.float()
|
216 |
+
return data.detach().cpu().numpy()
|
217 |
+
elif isinstance(data, list):
|
218 |
+
return [convert_data(d) for d in data]
|
219 |
+
elif isinstance(data, dict):
|
220 |
+
return {k: convert_data(v) for k, v in data.items()}
|
221 |
+
else:
|
222 |
+
raise TypeError(
|
223 |
+
"Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
|
224 |
+
type(data),
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
class ImageProcessor:
|
229 |
+
def convert_and_resize(
|
230 |
+
self,
|
231 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
232 |
+
size: int,
|
233 |
+
):
|
234 |
+
if isinstance(image, PIL.Image.Image):
|
235 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
236 |
+
elif isinstance(image, np.ndarray):
|
237 |
+
if image.dtype == np.uint8:
|
238 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
239 |
+
else:
|
240 |
+
image = torch.from_numpy(image)
|
241 |
+
elif isinstance(image, torch.Tensor):
|
242 |
+
pass
|
243 |
+
|
244 |
+
batched = image.ndim == 4
|
245 |
+
|
246 |
+
if not batched:
|
247 |
+
image = image[None, ...]
|
248 |
+
image = F.interpolate(
|
249 |
+
image.permute(0, 3, 1, 2),
|
250 |
+
(size, size),
|
251 |
+
mode="bilinear",
|
252 |
+
align_corners=False,
|
253 |
+
antialias=True,
|
254 |
+
).permute(0, 2, 3, 1)
|
255 |
+
if not batched:
|
256 |
+
image = image[0]
|
257 |
+
return image
|
258 |
+
|
259 |
+
def __call__(
|
260 |
+
self,
|
261 |
+
image: Union[
|
262 |
+
PIL.Image.Image,
|
263 |
+
np.ndarray,
|
264 |
+
torch.FloatTensor,
|
265 |
+
List[PIL.Image.Image],
|
266 |
+
List[np.ndarray],
|
267 |
+
List[torch.FloatTensor],
|
268 |
+
],
|
269 |
+
size: int,
|
270 |
+
) -> Any:
|
271 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
272 |
+
image = self.convert_and_resize(image, size)
|
273 |
+
else:
|
274 |
+
if not isinstance(image, list):
|
275 |
+
image = [image]
|
276 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
277 |
+
image = torch.stack(image, dim=0)
|
278 |
+
return image
|
279 |
+
|
280 |
+
|
281 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
282 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
283 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
284 |
+
intrinsic[0, 0] = focal_length
|
285 |
+
intrinsic[1, 1] = focal_length
|
286 |
+
intrinsic[0, 2] = W / 2.0
|
287 |
+
intrinsic[1, 2] = H / 2.0
|
288 |
+
|
289 |
+
if bs > 0:
|
290 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
291 |
+
|
292 |
+
return torch.from_numpy(intrinsic)
|
sf3d/system.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Any, List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import trimesh
|
9 |
+
from einops import rearrange
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from jaxtyping import Float
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from PIL import Image
|
14 |
+
from safetensors.torch import load_model
|
15 |
+
from torch import Tensor
|
16 |
+
|
17 |
+
from sf3d.models.isosurface import MarchingTetrahedraHelper
|
18 |
+
from sf3d.models.mesh import Mesh
|
19 |
+
from sf3d.models.utils import (
|
20 |
+
BaseModule,
|
21 |
+
ImageProcessor,
|
22 |
+
convert_data,
|
23 |
+
dilate_fill,
|
24 |
+
dot,
|
25 |
+
find_class,
|
26 |
+
float32_to_uint8_np,
|
27 |
+
normalize,
|
28 |
+
scale_tensor,
|
29 |
+
)
|
30 |
+
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
|
31 |
+
|
32 |
+
from .texture_baker import TextureBaker
|
33 |
+
|
34 |
+
|
35 |
+
class SF3D(BaseModule):
|
36 |
+
@dataclass
|
37 |
+
class Config(BaseModule.Config):
|
38 |
+
cond_image_size: int
|
39 |
+
isosurface_resolution: int
|
40 |
+
isosurface_threshold: float = 10.0
|
41 |
+
radius: float = 1.0
|
42 |
+
background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
|
43 |
+
default_fovy_deg: float = 40.0
|
44 |
+
default_distance: float = 1.6
|
45 |
+
|
46 |
+
camera_embedder_cls: str = ""
|
47 |
+
camera_embedder: dict = field(default_factory=dict)
|
48 |
+
|
49 |
+
image_tokenizer_cls: str = ""
|
50 |
+
image_tokenizer: dict = field(default_factory=dict)
|
51 |
+
|
52 |
+
tokenizer_cls: str = ""
|
53 |
+
tokenizer: dict = field(default_factory=dict)
|
54 |
+
|
55 |
+
backbone_cls: str = ""
|
56 |
+
backbone: dict = field(default_factory=dict)
|
57 |
+
|
58 |
+
post_processor_cls: str = ""
|
59 |
+
post_processor: dict = field(default_factory=dict)
|
60 |
+
|
61 |
+
decoder_cls: str = ""
|
62 |
+
decoder: dict = field(default_factory=dict)
|
63 |
+
|
64 |
+
image_estimator_cls: str = ""
|
65 |
+
image_estimator: dict = field(default_factory=dict)
|
66 |
+
|
67 |
+
global_estimator_cls: str = ""
|
68 |
+
global_estimator: dict = field(default_factory=dict)
|
69 |
+
|
70 |
+
cfg: Config
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def from_pretrained(
|
74 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
75 |
+
):
|
76 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
77 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
78 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
79 |
+
else:
|
80 |
+
config_path = hf_hub_download(
|
81 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
82 |
+
)
|
83 |
+
weight_path = hf_hub_download(
|
84 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
85 |
+
)
|
86 |
+
|
87 |
+
cfg = OmegaConf.load(config_path)
|
88 |
+
OmegaConf.resolve(cfg)
|
89 |
+
model = cls(cfg)
|
90 |
+
load_model(model, weight_path)
|
91 |
+
return model
|
92 |
+
|
93 |
+
@property
|
94 |
+
def device(self):
|
95 |
+
return next(self.parameters()).device
|
96 |
+
|
97 |
+
def configure(self):
|
98 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
99 |
+
self.cfg.image_tokenizer
|
100 |
+
)
|
101 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
102 |
+
self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
|
103 |
+
self.cfg.camera_embedder
|
104 |
+
)
|
105 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
106 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
107 |
+
self.cfg.post_processor
|
108 |
+
)
|
109 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
110 |
+
self.image_estimator = find_class(self.cfg.image_estimator_cls)(
|
111 |
+
self.cfg.image_estimator
|
112 |
+
)
|
113 |
+
self.global_estimator = find_class(self.cfg.global_estimator_cls)(
|
114 |
+
self.cfg.global_estimator
|
115 |
+
)
|
116 |
+
|
117 |
+
self.bbox: Float[Tensor, "2 3"]
|
118 |
+
self.register_buffer(
|
119 |
+
"bbox",
|
120 |
+
torch.as_tensor(
|
121 |
+
[
|
122 |
+
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
123 |
+
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
124 |
+
],
|
125 |
+
dtype=torch.float32,
|
126 |
+
),
|
127 |
+
)
|
128 |
+
self.isosurface_helper = MarchingTetrahedraHelper(
|
129 |
+
self.cfg.isosurface_resolution,
|
130 |
+
os.path.join(
|
131 |
+
os.path.dirname(__file__),
|
132 |
+
"..",
|
133 |
+
"load",
|
134 |
+
"tets",
|
135 |
+
f"{self.cfg.isosurface_resolution}_tets.npz",
|
136 |
+
),
|
137 |
+
)
|
138 |
+
|
139 |
+
self.baker = TextureBaker()
|
140 |
+
self.image_processor = ImageProcessor()
|
141 |
+
|
142 |
+
def triplane_to_meshes(
|
143 |
+
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
|
144 |
+
) -> list[Mesh]:
|
145 |
+
meshes = []
|
146 |
+
for i in range(triplanes.shape[0]):
|
147 |
+
triplane = triplanes[i]
|
148 |
+
grid_vertices = scale_tensor(
|
149 |
+
self.isosurface_helper.grid_vertices.to(triplanes.device),
|
150 |
+
self.isosurface_helper.points_range,
|
151 |
+
self.bbox,
|
152 |
+
)
|
153 |
+
|
154 |
+
values = self.query_triplane(grid_vertices, triplane)
|
155 |
+
decoded = self.decoder(values, include=["vertex_offset", "density"])
|
156 |
+
sdf = decoded["density"] - self.cfg.isosurface_threshold
|
157 |
+
|
158 |
+
deform = decoded["vertex_offset"].squeeze(0)
|
159 |
+
|
160 |
+
mesh: Mesh = self.isosurface_helper(
|
161 |
+
sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
|
162 |
+
)
|
163 |
+
mesh.v_pos = scale_tensor(
|
164 |
+
mesh.v_pos, self.isosurface_helper.points_range, self.bbox
|
165 |
+
)
|
166 |
+
|
167 |
+
meshes.append(mesh)
|
168 |
+
|
169 |
+
return meshes
|
170 |
+
|
171 |
+
def query_triplane(
|
172 |
+
self,
|
173 |
+
positions: Float[Tensor, "*B N 3"],
|
174 |
+
triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
|
175 |
+
) -> Float[Tensor, "*B N F"]:
|
176 |
+
batched = positions.ndim == 3
|
177 |
+
if not batched:
|
178 |
+
# no batch dimension
|
179 |
+
triplanes = triplanes[None, ...]
|
180 |
+
positions = positions[None, ...]
|
181 |
+
assert triplanes.ndim == 5 and positions.ndim == 3
|
182 |
+
|
183 |
+
positions = scale_tensor(
|
184 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
185 |
+
)
|
186 |
+
|
187 |
+
indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
|
188 |
+
(positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
|
189 |
+
dim=-3,
|
190 |
+
).to(triplanes.dtype)
|
191 |
+
out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
|
192 |
+
rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
|
193 |
+
rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
|
194 |
+
align_corners=True,
|
195 |
+
mode="bilinear",
|
196 |
+
)
|
197 |
+
out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
|
198 |
+
|
199 |
+
return out
|
200 |
+
|
201 |
+
def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
|
202 |
+
# if batch[rgb_cond] is only one view, add a view dimension
|
203 |
+
if len(batch["rgb_cond"].shape) == 4:
|
204 |
+
batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
|
205 |
+
batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
|
206 |
+
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
207 |
+
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
208 |
+
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
209 |
+
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
210 |
+
|
211 |
+
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
212 |
+
camera_embeds = self.camera_embedder(**batch)
|
213 |
+
|
214 |
+
input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
|
215 |
+
rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
|
216 |
+
modulation_cond=camera_embeds,
|
217 |
+
)
|
218 |
+
|
219 |
+
input_image_tokens = rearrange(
|
220 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
|
221 |
+
)
|
222 |
+
|
223 |
+
tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
|
224 |
+
|
225 |
+
tokens = self.backbone(
|
226 |
+
tokens,
|
227 |
+
encoder_hidden_states=input_image_tokens,
|
228 |
+
modulation_cond=None,
|
229 |
+
)
|
230 |
+
|
231 |
+
direct_codes = self.tokenizer.detokenize(tokens)
|
232 |
+
scene_codes = self.post_processor(direct_codes)
|
233 |
+
return scene_codes, direct_codes
|
234 |
+
|
235 |
+
def run_image(
|
236 |
+
self,
|
237 |
+
image: Image,
|
238 |
+
bake_resolution: int,
|
239 |
+
estimate_illumination: bool = False,
|
240 |
+
) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
|
241 |
+
if image.mode != "RGBA":
|
242 |
+
raise ValueError("Image must be in RGBA mode")
|
243 |
+
img_cond = (
|
244 |
+
torch.from_numpy(
|
245 |
+
np.asarray(
|
246 |
+
image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
|
247 |
+
).astype(np.float32)
|
248 |
+
/ 255.0
|
249 |
+
)
|
250 |
+
.float()
|
251 |
+
.clip(0, 1)
|
252 |
+
.to(self.device)
|
253 |
+
)
|
254 |
+
mask_cond = img_cond[:, :, -1:]
|
255 |
+
rgb_cond = torch.lerp(
|
256 |
+
torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
|
257 |
+
img_cond[:, :, :3],
|
258 |
+
mask_cond,
|
259 |
+
)
|
260 |
+
|
261 |
+
c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
|
262 |
+
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
263 |
+
self.cfg.default_fovy_deg,
|
264 |
+
self.cfg.cond_image_size,
|
265 |
+
self.cfg.cond_image_size,
|
266 |
+
)
|
267 |
+
|
268 |
+
batch = {
|
269 |
+
"rgb_cond": rgb_cond,
|
270 |
+
"mask_cond": mask_cond,
|
271 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
272 |
+
"intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
|
273 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
|
274 |
+
}
|
275 |
+
|
276 |
+
meshes, global_dict = self.generate_mesh(
|
277 |
+
batch, bake_resolution, estimate_illumination
|
278 |
+
)
|
279 |
+
return meshes[0], global_dict
|
280 |
+
|
281 |
+
def generate_mesh(
|
282 |
+
self,
|
283 |
+
batch,
|
284 |
+
bake_resolution: int,
|
285 |
+
estimate_illumination: bool = False,
|
286 |
+
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
287 |
+
batch["rgb_cond"] = self.image_processor(
|
288 |
+
batch["rgb_cond"], self.cfg.cond_image_size
|
289 |
+
)
|
290 |
+
batch["mask_cond"] = self.image_processor(
|
291 |
+
batch["mask_cond"], self.cfg.cond_image_size
|
292 |
+
)
|
293 |
+
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
294 |
+
|
295 |
+
global_dict = {}
|
296 |
+
if self.image_estimator is not None:
|
297 |
+
global_dict.update(
|
298 |
+
self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
|
299 |
+
)
|
300 |
+
if self.global_estimator is not None and estimate_illumination:
|
301 |
+
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
302 |
+
|
303 |
+
with torch.no_grad():
|
304 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
305 |
+
meshes = self.triplane_to_meshes(scene_codes)
|
306 |
+
|
307 |
+
rets = []
|
308 |
+
for i, mesh in enumerate(meshes):
|
309 |
+
# Check for empty mesh
|
310 |
+
if mesh.v_pos.shape[0] == 0:
|
311 |
+
rets.append(trimesh.Trimesh())
|
312 |
+
continue
|
313 |
+
|
314 |
+
mesh.unwrap_uv()
|
315 |
+
|
316 |
+
# Build textures
|
317 |
+
rast = self.baker.rasterize(
|
318 |
+
mesh.v_tex, mesh.t_pos_idx, bake_resolution
|
319 |
+
)
|
320 |
+
bake_mask = self.baker.get_mask(rast)
|
321 |
+
|
322 |
+
pos_bake = self.baker.interpolate(
|
323 |
+
mesh.v_pos,
|
324 |
+
rast,
|
325 |
+
mesh.t_pos_idx,
|
326 |
+
mesh.v_tex,
|
327 |
+
)
|
328 |
+
gb_pos = pos_bake[bake_mask]
|
329 |
+
|
330 |
+
tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
|
331 |
+
decoded = self.decoder(
|
332 |
+
tri_query, exclude=["density", "vertex_offset"]
|
333 |
+
)
|
334 |
+
|
335 |
+
nrm = self.baker.interpolate(
|
336 |
+
mesh.v_nrm,
|
337 |
+
rast,
|
338 |
+
mesh.t_pos_idx,
|
339 |
+
mesh.v_tex,
|
340 |
+
)
|
341 |
+
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
342 |
+
decoded["normal"] = gb_nrm
|
343 |
+
|
344 |
+
# Check if any keys in global_dict start with decoded_
|
345 |
+
for k, v in global_dict.items():
|
346 |
+
if k.startswith("decoder_"):
|
347 |
+
decoded[k.replace("decoder_", "")] = v[i]
|
348 |
+
|
349 |
+
mat_out = {
|
350 |
+
"albedo": decoded["features"],
|
351 |
+
"roughness": decoded["roughness"],
|
352 |
+
"metallic": decoded["metallic"],
|
353 |
+
"normal": normalize(decoded["perturb_normal"]),
|
354 |
+
"bump": None,
|
355 |
+
}
|
356 |
+
|
357 |
+
for k, v in mat_out.items():
|
358 |
+
if v is None:
|
359 |
+
continue
|
360 |
+
if v.shape[0] == 1:
|
361 |
+
# Skip and directly add a single value
|
362 |
+
mat_out[k] = v[0]
|
363 |
+
else:
|
364 |
+
f = torch.zeros(
|
365 |
+
bake_resolution,
|
366 |
+
bake_resolution,
|
367 |
+
v.shape[-1],
|
368 |
+
dtype=v.dtype,
|
369 |
+
device=v.device,
|
370 |
+
)
|
371 |
+
if v.shape == f.shape:
|
372 |
+
continue
|
373 |
+
if k == "normal":
|
374 |
+
# Use un-normalized tangents here so that larger smaller tris
|
375 |
+
# Don't effect the tangents that much
|
376 |
+
tng = self.baker.interpolate(
|
377 |
+
mesh.v_tng,
|
378 |
+
rast,
|
379 |
+
mesh.t_pos_idx,
|
380 |
+
mesh.v_tex,
|
381 |
+
)
|
382 |
+
gb_tng = tng[bake_mask]
|
383 |
+
gb_tng = F.normalize(gb_tng, dim=-1)
|
384 |
+
gb_btng = F.normalize(
|
385 |
+
torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
|
386 |
+
)
|
387 |
+
normal = F.normalize(mat_out["normal"], dim=-1)
|
388 |
+
|
389 |
+
bump = torch.cat(
|
390 |
+
# Check if we have to flip some things
|
391 |
+
(
|
392 |
+
dot(normal, gb_tng),
|
393 |
+
dot(normal, gb_btng),
|
394 |
+
dot(normal, gb_nrm).clip(
|
395 |
+
0.3, 1
|
396 |
+
), # Never go below 0.3. This would indicate a flipped (or close to one) normal
|
397 |
+
),
|
398 |
+
-1,
|
399 |
+
)
|
400 |
+
bump = (bump * 0.5 + 0.5).clamp(0, 1)
|
401 |
+
|
402 |
+
f[bake_mask] = bump.view(-1, 3)
|
403 |
+
mat_out["bump"] = f
|
404 |
+
else:
|
405 |
+
f[bake_mask] = v.view(-1, v.shape[-1])
|
406 |
+
mat_out[k] = f
|
407 |
+
|
408 |
+
def uv_padding(arr):
|
409 |
+
if arr.ndim == 1:
|
410 |
+
return arr
|
411 |
+
return (
|
412 |
+
dilate_fill(
|
413 |
+
arr.permute(2, 0, 1)[None, ...],
|
414 |
+
bake_mask.unsqueeze(0).unsqueeze(0),
|
415 |
+
iterations=bake_resolution // 150,
|
416 |
+
)
|
417 |
+
.squeeze(0)
|
418 |
+
.permute(1, 2, 0)
|
419 |
+
)
|
420 |
+
|
421 |
+
verts_np = convert_data(mesh.v_pos)
|
422 |
+
faces = convert_data(mesh.t_pos_idx)
|
423 |
+
uvs = convert_data(mesh.v_tex)
|
424 |
+
|
425 |
+
basecolor_tex = Image.fromarray(
|
426 |
+
float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
|
427 |
+
).convert("RGB")
|
428 |
+
basecolor_tex.format = "JPEG"
|
429 |
+
|
430 |
+
metallic = mat_out["metallic"].squeeze().cpu().item()
|
431 |
+
roughness = mat_out["roughness"].squeeze().cpu().item()
|
432 |
+
|
433 |
+
if "bump" in mat_out and mat_out["bump"] is not None:
|
434 |
+
bump_np = convert_data(uv_padding(mat_out["bump"]))
|
435 |
+
bump_up = np.ones_like(bump_np)
|
436 |
+
bump_up[..., :2] = 0.5
|
437 |
+
bump_up[..., 2:] = 1
|
438 |
+
bump_tex = Image.fromarray(
|
439 |
+
float32_to_uint8_np(
|
440 |
+
bump_np,
|
441 |
+
dither=True,
|
442 |
+
# Do not dither if something is perfectly flat
|
443 |
+
dither_mask=np.all(
|
444 |
+
bump_np == bump_up, axis=-1, keepdims=True
|
445 |
+
).astype(np.float32),
|
446 |
+
)
|
447 |
+
).convert("RGB")
|
448 |
+
bump_tex.format = (
|
449 |
+
"JPEG" # PNG would be better but the assets are larger
|
450 |
+
)
|
451 |
+
else:
|
452 |
+
bump_tex = None
|
453 |
+
|
454 |
+
material = trimesh.visual.material.PBRMaterial(
|
455 |
+
baseColorTexture=basecolor_tex,
|
456 |
+
roughnessFactor=roughness,
|
457 |
+
metallicFactor=metallic,
|
458 |
+
normalTexture=bump_tex,
|
459 |
+
)
|
460 |
+
|
461 |
+
tmesh = trimesh.Trimesh(
|
462 |
+
vertices=verts_np,
|
463 |
+
faces=faces,
|
464 |
+
visual=trimesh.visual.texture.TextureVisuals(
|
465 |
+
uv=uvs, material=material
|
466 |
+
),
|
467 |
+
)
|
468 |
+
rot = trimesh.transformations.rotation_matrix(
|
469 |
+
np.radians(-90), [1, 0, 0]
|
470 |
+
)
|
471 |
+
tmesh.apply_transform(rot)
|
472 |
+
tmesh.apply_transform(
|
473 |
+
trimesh.transformations.rotation_matrix(
|
474 |
+
np.radians(90), [0, 1, 0]
|
475 |
+
)
|
476 |
+
)
|
477 |
+
|
478 |
+
tmesh.invert()
|
479 |
+
|
480 |
+
rets.append(tmesh)
|
481 |
+
|
482 |
+
return rets, global_dict
|
sf3d/texture_baker.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import slangtorch
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Bool, Float
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
|
10 |
+
class TextureBaker(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
self.baker = slangtorch.loadModule(
|
14 |
+
os.path.join(os.path.dirname(__file__), "texture_baker.slang")
|
15 |
+
)
|
16 |
+
|
17 |
+
def rasterize(
|
18 |
+
self,
|
19 |
+
uv: Float[Tensor, "Nv 2"],
|
20 |
+
face_indices: Float[Tensor, "Nf 3"],
|
21 |
+
bake_resolution: int,
|
22 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
|
23 |
+
if not face_indices.is_cuda or not uv.is_cuda:
|
24 |
+
raise ValueError("All input tensors must be on cuda")
|
25 |
+
|
26 |
+
face_indices = face_indices.to(torch.int32)
|
27 |
+
uv = uv.to(torch.float32)
|
28 |
+
|
29 |
+
rast_result = torch.empty(
|
30 |
+
bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
|
31 |
+
)
|
32 |
+
|
33 |
+
block_size = 16
|
34 |
+
grid_size = bake_resolution // block_size
|
35 |
+
self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
|
36 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
37 |
+
)
|
38 |
+
|
39 |
+
return rast_result
|
40 |
+
|
41 |
+
def get_mask(
|
42 |
+
self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
|
43 |
+
) -> Bool[Tensor, "bake_resolution bake_resolution"]:
|
44 |
+
return rast[..., -1] >= 0
|
45 |
+
|
46 |
+
def interpolate(
|
47 |
+
self,
|
48 |
+
attr: Float[Tensor, "Nv 3"],
|
49 |
+
rast: Float[Tensor, "bake_resolution bake_resolution 4"],
|
50 |
+
face_indices: Float[Tensor, "Nf 3"],
|
51 |
+
uv: Float[Tensor, "Nv 2"],
|
52 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
53 |
+
# Make sure all input tensors are on torch
|
54 |
+
if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
|
55 |
+
raise ValueError("All input tensors must be on cuda")
|
56 |
+
|
57 |
+
attr = attr.to(torch.float32)
|
58 |
+
face_indices = face_indices.to(torch.int32)
|
59 |
+
uv = uv.to(torch.float32)
|
60 |
+
|
61 |
+
pos_bake = torch.zeros(
|
62 |
+
rast.shape[0],
|
63 |
+
rast.shape[1],
|
64 |
+
3,
|
65 |
+
device=attr.device,
|
66 |
+
dtype=attr.dtype,
|
67 |
+
)
|
68 |
+
|
69 |
+
block_size = 16
|
70 |
+
grid_size = rast.shape[0] // block_size
|
71 |
+
self.baker.interpolate(
|
72 |
+
attr=attr, indices=face_indices, rast=rast, output=pos_bake
|
73 |
+
).launchRaw(
|
74 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
75 |
+
)
|
76 |
+
|
77 |
+
return pos_bake
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
attr: Float[Tensor, "Nv 3"],
|
82 |
+
uv: Float[Tensor, "Nv 2"],
|
83 |
+
face_indices: Float[Tensor, "Nf 3"],
|
84 |
+
bake_resolution: int,
|
85 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
86 |
+
rast = self.rasterize(uv, face_indices, bake_resolution)
|
87 |
+
return self.interpolate(attr, rast, face_indices, uv)
|
sf3d/texture_baker.slang
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// xy: 2D test position
|
2 |
+
// v1: vertex position 1
|
3 |
+
// v2: vertex position 2
|
4 |
+
// v3: vertex position 3
|
5 |
+
//
|
6 |
+
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
|
7 |
+
{
|
8 |
+
// Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
|
9 |
+
// If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
|
10 |
+
float2 v1v2 = v2 - v1;
|
11 |
+
float2 v1v3 = v3 - v1;
|
12 |
+
float2 xyv1 = xy - v1;
|
13 |
+
|
14 |
+
float d00 = dot(v1v2, v1v2);
|
15 |
+
float d01 = dot(v1v2, v1v3);
|
16 |
+
float d11 = dot(v1v3, v1v3);
|
17 |
+
float d20 = dot(xyv1, v1v2);
|
18 |
+
float d21 = dot(xyv1, v1v3);
|
19 |
+
|
20 |
+
float denom = d00 * d11 - d01 * d01;
|
21 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
22 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
23 |
+
u = 1.0 - v - w;
|
24 |
+
|
25 |
+
return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
|
26 |
+
}
|
27 |
+
|
28 |
+
[AutoPyBindCUDA]
|
29 |
+
[CUDAKernel]
|
30 |
+
void interpolate(
|
31 |
+
TensorView<float3> attr,
|
32 |
+
TensorView<int3> indices,
|
33 |
+
TensorView<float4> rast,
|
34 |
+
TensorView<float3> output)
|
35 |
+
{
|
36 |
+
// Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
|
37 |
+
|
38 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
39 |
+
|
40 |
+
if (dispatch_id.x > output.size(0) || dispatch_id.y > output.size(1))
|
41 |
+
return;
|
42 |
+
|
43 |
+
float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
|
44 |
+
int triangle_idx = int(barycentric.w);
|
45 |
+
|
46 |
+
if (triangle_idx < 0) {
|
47 |
+
output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
|
48 |
+
return;
|
49 |
+
}
|
50 |
+
|
51 |
+
float3 v1 = attr[indices[triangle_idx].x];
|
52 |
+
float3 v2 = attr[indices[triangle_idx].y];
|
53 |
+
float3 v3 = attr[indices[triangle_idx].z];
|
54 |
+
|
55 |
+
output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
|
56 |
+
}
|
57 |
+
|
58 |
+
[AutoPyBindCUDA]
|
59 |
+
[CUDAKernel]
|
60 |
+
void bake_uv(
|
61 |
+
TensorView<float2> uv,
|
62 |
+
TensorView<int3> indices,
|
63 |
+
TensorView<float4> output)
|
64 |
+
{
|
65 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
66 |
+
|
67 |
+
if (dispatch_id.y > output.size(0) || dispatch_id.x > output.size(1))
|
68 |
+
return;
|
69 |
+
|
70 |
+
// We index x,y but the orginal coords are HW. So swap them
|
71 |
+
float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
|
72 |
+
// Normalize to [0, 1]
|
73 |
+
pixel_coord /= float2(output.size(1), output.size(0));
|
74 |
+
pixel_coord = clamp(pixel_coord, 0.0, 1.0);
|
75 |
+
// Flip x-axis
|
76 |
+
pixel_coord.y = 1 - pixel_coord.y;
|
77 |
+
|
78 |
+
for (int i = 0; i < indices.size(0); i++) {
|
79 |
+
float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
|
80 |
+
float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
|
81 |
+
float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
|
82 |
+
|
83 |
+
float u, v, w;
|
84 |
+
bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
|
85 |
+
|
86 |
+
if (hit){
|
87 |
+
output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
|
88 |
+
return;
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
|
93 |
+
}
|
sf3d/utils.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import rembg
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
import sf3d.models.utils as sf3d_utils
|
9 |
+
|
10 |
+
|
11 |
+
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
|
12 |
+
intrinsic = sf3d_utils.get_intrinsic_from_fov(
|
13 |
+
np.deg2rad(fov_deg),
|
14 |
+
H=cond_height,
|
15 |
+
W=cond_width,
|
16 |
+
)
|
17 |
+
intrinsic_normed_cond = intrinsic.clone()
|
18 |
+
intrinsic_normed_cond[..., 0, 2] /= cond_width
|
19 |
+
intrinsic_normed_cond[..., 1, 2] /= cond_height
|
20 |
+
intrinsic_normed_cond[..., 0, 0] /= cond_width
|
21 |
+
intrinsic_normed_cond[..., 1, 1] /= cond_height
|
22 |
+
|
23 |
+
return intrinsic, intrinsic_normed_cond
|
24 |
+
|
25 |
+
|
26 |
+
def default_cond_c2w(distance: float):
|
27 |
+
c2w_cond = torch.as_tensor(
|
28 |
+
[
|
29 |
+
[0, 0, 1, distance],
|
30 |
+
[1, 0, 0, 0],
|
31 |
+
[0, 1, 0, 0],
|
32 |
+
[0, 0, 0, 1],
|
33 |
+
]
|
34 |
+
).float()
|
35 |
+
return c2w_cond
|
36 |
+
|
37 |
+
|
38 |
+
def remove_background(
|
39 |
+
image: Image,
|
40 |
+
rembg_session: Any = None,
|
41 |
+
force: bool = False,
|
42 |
+
**rembg_kwargs,
|
43 |
+
) -> Image:
|
44 |
+
do_remove = True
|
45 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
46 |
+
do_remove = False
|
47 |
+
do_remove = do_remove or force
|
48 |
+
if do_remove:
|
49 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
def resize_foreground(
|
54 |
+
image: Image,
|
55 |
+
ratio: float,
|
56 |
+
) -> Image:
|
57 |
+
image = np.array(image)
|
58 |
+
assert image.shape[-1] == 4
|
59 |
+
alpha = np.where(image[..., 3] > 0)
|
60 |
+
y1, y2, x1, x2 = (
|
61 |
+
alpha[0].min(),
|
62 |
+
alpha[0].max(),
|
63 |
+
alpha[1].min(),
|
64 |
+
alpha[1].max(),
|
65 |
+
)
|
66 |
+
# crop the foreground
|
67 |
+
fg = image[y1:y2, x1:x2]
|
68 |
+
# pad to square
|
69 |
+
size = max(fg.shape[0], fg.shape[1])
|
70 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
71 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
72 |
+
new_image = np.pad(
|
73 |
+
fg,
|
74 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
75 |
+
mode="constant",
|
76 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
77 |
+
)
|
78 |
+
|
79 |
+
# compute padding according to the ratio
|
80 |
+
new_size = int(new_image.shape[0] / ratio)
|
81 |
+
# pad to size, double side
|
82 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
83 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
84 |
+
new_image = np.pad(
|
85 |
+
new_image,
|
86 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
87 |
+
mode="constant",
|
88 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
89 |
+
)
|
90 |
+
new_image = Image.fromarray(new_image, mode="RGBA")
|
91 |
+
return new_image
|
stable_fast.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
from functools import lru_cache
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import rembg
|
10 |
+
import torch
|
11 |
+
from gradio_litmodel3d import LitModel3D
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
import sf3d.utils as sf3d_utils
|
15 |
+
from sf3d.system import SF3D
|
16 |
+
|
17 |
+
rembg_session = rembg.new_session()
|
18 |
+
|
19 |
+
COND_WIDTH = 512
|
20 |
+
COND_HEIGHT = 512
|
21 |
+
COND_DISTANCE = 1.6
|
22 |
+
COND_FOVY_DEG = 40
|
23 |
+
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
24 |
+
|
25 |
+
# Cached. Doesn't change
|
26 |
+
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
|
27 |
+
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
28 |
+
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
model = SF3D.from_pretrained(
|
33 |
+
"stabilityai/stable-fast-3d",
|
34 |
+
config_name="config.yaml",
|
35 |
+
weight_name="model.safetensors",
|
36 |
+
)
|
37 |
+
model.eval().cuda()
|
38 |
+
|
39 |
+
example_files = [
|
40 |
+
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
41 |
+
]
|
42 |
+
|
43 |
+
|
44 |
+
def run_model(input_image):
|
45 |
+
start = time.time()
|
46 |
+
with torch.no_grad():
|
47 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
48 |
+
model_batch = create_batch(input_image)
|
49 |
+
model_batch = {k: v.cuda() for k, v in model_batch.items()}
|
50 |
+
trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
|
51 |
+
trimesh_mesh = trimesh_mesh[0]
|
52 |
+
|
53 |
+
# Create new tmp file
|
54 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
55 |
+
|
56 |
+
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
57 |
+
|
58 |
+
print("Generation took:", time.time() - start, "s")
|
59 |
+
|
60 |
+
return tmp_file.name
|
61 |
+
|
62 |
+
|
63 |
+
def create_batch(input_image: Image) -> dict[str, Any]:
|
64 |
+
img_cond = (
|
65 |
+
torch.from_numpy(
|
66 |
+
np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
|
67 |
+
/ 255.0
|
68 |
+
)
|
69 |
+
.float()
|
70 |
+
.clip(0, 1)
|
71 |
+
)
|
72 |
+
mask_cond = img_cond[:, :, -1:]
|
73 |
+
rgb_cond = torch.lerp(
|
74 |
+
torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
|
75 |
+
)
|
76 |
+
|
77 |
+
batch_elem = {
|
78 |
+
"rgb_cond": rgb_cond,
|
79 |
+
"mask_cond": mask_cond,
|
80 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
81 |
+
"intrinsic_cond": intrinsic.unsqueeze(0),
|
82 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
83 |
+
}
|
84 |
+
# Add batch dim
|
85 |
+
batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
86 |
+
return batched
|
87 |
+
|
88 |
+
|
89 |
+
@lru_cache
|
90 |
+
def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
91 |
+
base = np.zeros((squares, squares)) + min_value
|
92 |
+
base[1::2, ::2] = 1
|
93 |
+
base[::2, 1::2] = 1
|
94 |
+
|
95 |
+
repeat_mult = size // squares
|
96 |
+
return (
|
97 |
+
base.repeat(repeat_mult, axis=0)
|
98 |
+
.repeat(repeat_mult, axis=1)[:, :, None]
|
99 |
+
.repeat(3, axis=-1)
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
def remove_background(input_image: Image) -> Image:
|
104 |
+
return rembg.remove(input_image, session=rembg_session)
|
105 |
+
|
106 |
+
|
107 |
+
def resize_foreground(
|
108 |
+
image: Image,
|
109 |
+
ratio: float,
|
110 |
+
) -> Image:
|
111 |
+
image = np.array(image)
|
112 |
+
assert image.shape[-1] == 4
|
113 |
+
alpha = np.where(image[..., 3] > 0)
|
114 |
+
y1, y2, x1, x2 = (
|
115 |
+
alpha[0].min(),
|
116 |
+
alpha[0].max(),
|
117 |
+
alpha[1].min(),
|
118 |
+
alpha[1].max(),
|
119 |
+
)
|
120 |
+
# crop the foreground
|
121 |
+
fg = image[y1:y2, x1:x2]
|
122 |
+
# pad to square
|
123 |
+
size = max(fg.shape[0], fg.shape[1])
|
124 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
125 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
126 |
+
new_image = np.pad(
|
127 |
+
fg,
|
128 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
129 |
+
mode="constant",
|
130 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
131 |
+
)
|
132 |
+
|
133 |
+
# compute padding according to the ratio
|
134 |
+
new_size = int(new_image.shape[0] / ratio)
|
135 |
+
# pad to size, double side
|
136 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
137 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
138 |
+
new_image = np.pad(
|
139 |
+
new_image,
|
140 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
141 |
+
mode="constant",
|
142 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
143 |
+
)
|
144 |
+
new_image = Image.fromarray(new_image, mode="RGBA").resize(
|
145 |
+
(COND_WIDTH, COND_HEIGHT)
|
146 |
+
)
|
147 |
+
return new_image
|
148 |
+
|
149 |
+
|
150 |
+
def square_crop(input_image: Image) -> Image:
|
151 |
+
# Perform a center square crop
|
152 |
+
min_size = min(input_image.size)
|
153 |
+
left = (input_image.size[0] - min_size) // 2
|
154 |
+
top = (input_image.size[1] - min_size) // 2
|
155 |
+
right = (input_image.size[0] + min_size) // 2
|
156 |
+
bottom = (input_image.size[1] + min_size) // 2
|
157 |
+
return input_image.crop((left, top, right, bottom)).resize(
|
158 |
+
(COND_WIDTH, COND_HEIGHT)
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
def show_mask_img(input_image: Image) -> Image:
|
163 |
+
img_numpy = np.array(input_image)
|
164 |
+
alpha = img_numpy[:, :, 3] / 255.0
|
165 |
+
chkb = checkerboard(32, 512) * 255
|
166 |
+
new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
|
167 |
+
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
168 |
+
|
169 |
+
|
170 |
+
def run_button(run_btn, input_image, background_state, foreground_ratio):
|
171 |
+
if run_btn == "Run":
|
172 |
+
glb_file: str = run_model(background_state)
|
173 |
+
|
174 |
+
return (
|
175 |
+
gr.update(),
|
176 |
+
gr.update(),
|
177 |
+
gr.update(),
|
178 |
+
gr.update(),
|
179 |
+
gr.update(value=glb_file, visible=True),
|
180 |
+
gr.update(visible=True),
|
181 |
+
)
|
182 |
+
elif run_btn == "Remove Background":
|
183 |
+
rem_removed = remove_background(input_image)
|
184 |
+
|
185 |
+
sqr_crop = square_crop(rem_removed)
|
186 |
+
fr_res = resize_foreground(sqr_crop, foreground_ratio)
|
187 |
+
|
188 |
+
return (
|
189 |
+
gr.update(value="Run", visible=True),
|
190 |
+
sqr_crop,
|
191 |
+
fr_res,
|
192 |
+
gr.update(value=show_mask_img(fr_res), visible=True),
|
193 |
+
gr.update(value=None, visible=False),
|
194 |
+
gr.update(visible=False),
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
def requires_bg_remove(image, fr):
|
199 |
+
if image is None:
|
200 |
+
return (
|
201 |
+
gr.update(visible=False, value="Run"),
|
202 |
+
None,
|
203 |
+
None,
|
204 |
+
gr.update(value=None, visible=False),
|
205 |
+
gr.update(visible=False),
|
206 |
+
gr.update(visible=False),
|
207 |
+
)
|
208 |
+
alpha_channel = np.array(image.getchannel("A"))
|
209 |
+
min_alpha = alpha_channel.min()
|
210 |
+
|
211 |
+
if min_alpha == 0:
|
212 |
+
print("Already has alpha")
|
213 |
+
sqr_crop = square_crop(image)
|
214 |
+
fr_res = resize_foreground(sqr_crop, fr)
|
215 |
+
return (
|
216 |
+
gr.update(value="Run", visible=True),
|
217 |
+
sqr_crop,
|
218 |
+
fr_res,
|
219 |
+
gr.update(value=show_mask_img(fr_res), visible=True),
|
220 |
+
gr.update(visible=False),
|
221 |
+
gr.update(visible=False),
|
222 |
+
)
|
223 |
+
return (
|
224 |
+
gr.update(value="Remove Background", visible=True),
|
225 |
+
None,
|
226 |
+
None,
|
227 |
+
gr.update(value=None, visible=False),
|
228 |
+
gr.update(visible=False),
|
229 |
+
gr.update(visible=False),
|
230 |
+
)
|
231 |
+
|
232 |
+
|
233 |
+
def update_foreground_ratio(img_proc, fr):
|
234 |
+
foreground_res = resize_foreground(img_proc, fr)
|
235 |
+
return (
|
236 |
+
foreground_res,
|
237 |
+
gr.update(value=show_mask_img(foreground_res)),
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
with gr.Blocks() as demo:
|
242 |
+
img_proc_state = gr.State()
|
243 |
+
background_remove_state = gr.State()
|
244 |
+
gr.Markdown("""
|
245 |
+
# SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement
|
246 |
+
|
247 |
+
**SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
|
248 |
+
This demo allows you to upload an image and generate a 3D mesh model from it.
|
249 |
+
|
250 |
+
**Tips**
|
251 |
+
1. If the image already has an alpha channel, you can skip the background removal step.
|
252 |
+
2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
|
253 |
+
3. You can upload your own HDR environment map to light the 3D model.
|
254 |
+
""")
|
255 |
+
with gr.Row(variant="panel"):
|
256 |
+
with gr.Column():
|
257 |
+
with gr.Row():
|
258 |
+
input_img = gr.Image(
|
259 |
+
type="pil", label="Input Image", sources="upload", image_mode="RGBA"
|
260 |
+
)
|
261 |
+
preview_removal = gr.Image(
|
262 |
+
label="Preview Background Removal",
|
263 |
+
type="pil",
|
264 |
+
image_mode="RGB",
|
265 |
+
interactive=False,
|
266 |
+
visible=False,
|
267 |
+
)
|
268 |
+
|
269 |
+
foreground_ratio = gr.Slider(
|
270 |
+
label="Foreground Ratio",
|
271 |
+
minimum=0.5,
|
272 |
+
maximum=1.0,
|
273 |
+
value=0.85,
|
274 |
+
step=0.05,
|
275 |
+
)
|
276 |
+
|
277 |
+
foreground_ratio.change(
|
278 |
+
update_foreground_ratio,
|
279 |
+
inputs=[img_proc_state, foreground_ratio],
|
280 |
+
outputs=[background_remove_state, preview_removal],
|
281 |
+
)
|
282 |
+
|
283 |
+
run_btn = gr.Button("Run", variant="primary", visible=False)
|
284 |
+
|
285 |
+
with gr.Column():
|
286 |
+
output_3d = LitModel3D(
|
287 |
+
label="3D Model",
|
288 |
+
visible=False,
|
289 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
290 |
+
tonemapping="aces",
|
291 |
+
contrast=1.0,
|
292 |
+
scale=1.0,
|
293 |
+
)
|
294 |
+
with gr.Column(visible=False, scale=1.0) as hdr_row:
|
295 |
+
gr.Markdown("""## HDR Environment Map
|
296 |
+
|
297 |
+
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
|
298 |
+
""")
|
299 |
+
|
300 |
+
with gr.Row():
|
301 |
+
hdr_illumination_file = gr.File(
|
302 |
+
label="HDR Env Map", file_types=[".hdr"], file_count="single"
|
303 |
+
)
|
304 |
+
example_hdris = [
|
305 |
+
os.path.join("demo_files/hdri", f)
|
306 |
+
for f in os.listdir("demo_files/hdri")
|
307 |
+
]
|
308 |
+
hdr_illumination_example = gr.Examples(
|
309 |
+
examples=example_hdris,
|
310 |
+
inputs=hdr_illumination_file,
|
311 |
+
)
|
312 |
+
|
313 |
+
hdr_illumination_file.change(
|
314 |
+
lambda x: gr.update(env_map=x.name if x is not None else None),
|
315 |
+
inputs=hdr_illumination_file,
|
316 |
+
outputs=[output_3d],
|
317 |
+
)
|
318 |
+
|
319 |
+
examples = gr.Examples(
|
320 |
+
examples=example_files,
|
321 |
+
inputs=input_img,
|
322 |
+
)
|
323 |
+
|
324 |
+
input_img.change(
|
325 |
+
requires_bg_remove,
|
326 |
+
inputs=[input_img, foreground_ratio],
|
327 |
+
outputs=[
|
328 |
+
run_btn,
|
329 |
+
img_proc_state,
|
330 |
+
background_remove_state,
|
331 |
+
preview_removal,
|
332 |
+
output_3d,
|
333 |
+
hdr_row,
|
334 |
+
],
|
335 |
+
)
|
336 |
+
|
337 |
+
run_btn.click(
|
338 |
+
run_button,
|
339 |
+
inputs=[
|
340 |
+
run_btn,
|
341 |
+
input_img,
|
342 |
+
background_remove_state,
|
343 |
+
foreground_ratio,
|
344 |
+
],
|
345 |
+
outputs=[
|
346 |
+
run_btn,
|
347 |
+
img_proc_state,
|
348 |
+
background_remove_state,
|
349 |
+
preview_removal,
|
350 |
+
output_3d,
|
351 |
+
hdr_row,
|
352 |
+
],
|
353 |
+
)
|
354 |
+
|
355 |
+
demo.launch()
|