ZhiyuanthePony committed
Commit f876753 · 1 Parent(s): e9a4e66
app.py CHANGED
@@ -1,7 +1,160 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import torch
  import gradio as gr
+ from typing import *
+ from collections import deque
+ from diffusers import StableDiffusionPipeline
+
+ from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline
+ from triplaneturbo_executable.utils.mesh_exporter import export_obj
+
+ # Initialize global variables
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ ADAPTER_PATH = "pretrained/triplane_turbo_sd_v1.pth"
+ PIPELINE = None  # Will hold our pipeline instance
+ OBJ_FILE_QUEUE = deque(maxlen=100)  # Queue to store OBJ file paths
+
+ def download_model():
+     """Download the pretrained model if it does not exist"""
+     if not os.path.exists(ADAPTER_PATH):
+         print("Downloading pretrained models from huggingface")
+         os.system(
+             f"huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
+             --include \"triplane_turbo_sd_v1.pth\" \
+             --local-dir ./pretrained \
+             --local-dir-use-symlinks False"
+         )
+
+ def initialize_pipeline():
+     """Initialize the pipeline once and keep it in memory"""
+     global PIPELINE
+     if PIPELINE is None:
+         print("Initializing pipeline...")
+         PIPELINE = TriplaneTurboTextTo3DPipeline.from_pretrained(ADAPTER_PATH)
+         PIPELINE.to(DEVICE)
+         print("Pipeline initialized!")
+     return PIPELINE
+
+ def generate_3d_mesh(prompt: str) -> Tuple[str, str]:
+     """Generate a 3D mesh from a text prompt"""
+     global PIPELINE, OBJ_FILE_QUEUE
+
+     # Use the global pipeline instance
+     pipeline = initialize_pipeline()
+
+     # Use a fixed seed value
+     seed = 42
+
+     # Generate mesh
+     output = pipeline(
+         prompt=prompt,
+         num_results_per_prompt=1,
+         generator=torch.Generator(device=DEVICE).manual_seed(seed),
+     )
+
+     # Save mesh
+     output_dir = "outputs"
+     os.makedirs(output_dir, exist_ok=True)
+
+     mesh_path = None
+     for i, mesh in enumerate(output["mesh"]):
+         vertices = mesh.v_pos
+
+         # 1. First rotate -90 degrees around X-axis to make the model face up
+         vertices = torch.stack([
+             vertices[:, 0],   # x remains unchanged
+             vertices[:, 2],   # y = z
+             -vertices[:, 1]   # z = -y
+         ], dim=1)
+
+         # 2. Then rotate 90 degrees around Y-axis to make the model face the observer
+         vertices = torch.stack([
+             -vertices[:, 2],  # x = -z
+             vertices[:, 1],   # y remains unchanged
+             vertices[:, 0]    # z = x
+         ], dim=1)
+
+         mesh.v_pos = vertices
+
+         # If the mesh has normals, they need to be rotated in the same way
+         if mesh.v_nrm is not None:
+             normals = mesh.v_nrm
+             # 1. Rotate -90 degrees around X-axis
+             normals = torch.stack([
+                 normals[:, 0],
+                 normals[:, 2],
+                 -normals[:, 1]
+             ], dim=1)
+             # 2. Rotate 90 degrees around Y-axis
+             normals = torch.stack([
+                 -normals[:, 2],
+                 normals[:, 1],
+                 normals[:, 0]
+             ], dim=1)
+             mesh._v_nrm = normals
+
+         name = f"{prompt.replace(' ', '_')}"
+         save_paths = export_obj(mesh, f"{output_dir}/{name}.obj")
+         mesh_path = save_paths[0]
+
+         # Add the new file path to the queue
+         OBJ_FILE_QUEUE.append(mesh_path)
+
+         # If the queue is at max length, remove the oldest file
+         if len(OBJ_FILE_QUEUE) == OBJ_FILE_QUEUE.maxlen:
+             old_file = OBJ_FILE_QUEUE[0]  # Get the oldest file (it will be dropped from the queue automatically)
+             if os.path.exists(old_file):
+                 try:
+                     os.remove(old_file)
+                 except OSError as e:
+                     print(f"Error deleting file {old_file}: {e}")
+
+     return mesh_path, mesh_path  # Return the path twice: once for the 3D preview, once for download
+
+ def main():
+     # Download the model if needed
+     download_model()
+
+     # Initialize the pipeline at startup
+     initialize_pipeline()
+
+     # Create the Gradio interface
+     iface = gr.Interface(
+         fn=generate_3d_mesh,
+         inputs=[
+             gr.Textbox(
+                 label="Text Prompt",
+                 placeholder="Enter your text description...",
+                 value="Armor dress style of outsiderzone fantasy helmet"
+             )
+         ],
+         outputs=[
+             gr.Model3D(
+                 label="Generated 3D Mesh",
+                 camera_position=(90, 90, 3),
+                 clear_color=(0.5, 0.5, 0.5, 1),
+             ),
+             gr.File(label="Download OBJ file")
+         ],
+         title="Text to 3D Mesh Generation with TriplaneTurbo",
+         description="Demo of the paper Progressive Rendering Distillation: Adapting Stable Diffusion for Instant Text-to-Mesh Generation beyond 3D Training Data [CVPR 2025] <br><a href='https://github.com/theEricMa/TriplaneTurbo' style='color: #2196F3;'>https://github.com/theEricMa/TriplaneTurbo</a>",
+         examples=[
+             ["Armor dress style of outsiderzone fantasy helmet"],
+             ["Gandalf the grey riding a camel in a rock concert, victorian newspaper article, hyperrealistic"],
+             ["A DSLR photo of a bald eagle"],
+             ["A goblin riding a lawnmower in a hospital, victorian newspaper article, 4k hd"],
+             ["An imperial stormtrooper, highly detailed"],
+         ],
+         allow_flagging="never",
+     )
+
+     # Launch the interface
+     iface.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_error=True,
+     )
+
+ if __name__ == "__main__":
+     main()
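Note on the coordinate fix-up above: the same two-step rotation (first -90 degrees about the X-axis, then a rotation about the Y-axis) is applied to both vertices and normals, and it is duplicated verbatim in app.py and example.py. Below is a minimal sketch of how it could be factored into a shared helper; the name `rotate_mesh_to_front` is hypothetical and not part of this commit, and it only assumes that `mesh.v_pos` / `mesh.v_nrm` are (N, 3) tensors as used above.

import torch

def rotate_mesh_to_front(mesh):
    """Hypothetical helper replicating the rotation applied in app.py/example.py."""
    def _rotate(p):
        # 1. Rotate -90 degrees around the X-axis: (x, y, z) -> (x, z, -y)
        p = torch.stack([p[:, 0], p[:, 2], -p[:, 1]], dim=1)
        # 2. Rotate around the Y-axis: (x, y, z) -> (-z, y, x)
        return torch.stack([-p[:, 2], p[:, 1], p[:, 0]], dim=1)

    mesh.v_pos = _rotate(mesh.v_pos)
    if mesh.v_nrm is not None:
        mesh._v_nrm = _rotate(mesh.v_nrm)
    return mesh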
example.py ADDED
@@ -0,0 +1,99 @@
+ import os
+ import torch
+ import argparse
+ from typing import *
+ from diffusers import StableDiffusionPipeline
+ from collections import deque
+
+ from triplaneturbo_executable.utils.mesh_exporter import export_obj
+ from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline, TriplaneTurboTextTo3DPipelineConfig
+
+
+
+ # Initialize configuration and parameters
+ prompt = "a beautiful girl"
+ output_dir = "examples/output"
+ adapter_name_or_path = "pretrained/triplane_turbo_sd_v1.pth"
+ num_results_per_prompt = 1
+ seed = 42
+ device = "cuda"
+ max_obj_files = 100
+
+ # Download the pretrained models if they do not exist
+ if not os.path.exists(adapter_name_or_path):
+     print(f"Downloading pretrained models from huggingface")
+     os.system(
+         f"huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
+         --include \"triplane_turbo_sd_v1.pth\" \
+         --local-dir ./pretrained \
+         --local-dir-use-symlinks False"
+     )
+
+
+ # Initialize the TriplaneTurbo pipeline
+ triplane_turbo_pipeline = TriplaneTurboTextTo3DPipeline.from_pretrained(adapter_name_or_path)
+ triplane_turbo_pipeline.to(device)
+
+ # Run the pipeline
+ output = triplane_turbo_pipeline(
+     prompt=prompt,
+     num_results_per_prompt=num_results_per_prompt,
+     generator=torch.Generator(device=device).manual_seed(seed),
+     device=device,
+ )
+
+ # Initialize a deque with a maximum length of 100 to store obj file paths
+ obj_file_queue = deque(maxlen=max_obj_files)
+
+ # Save mesh
+ os.makedirs(output_dir, exist_ok=True)
+ for i, mesh in enumerate(output["mesh"]):
+     vertices = mesh.v_pos
+
+     # 1. First rotate -90 degrees around X-axis to make the model face up
+     vertices = torch.stack([
+         vertices[:, 0],   # x remains unchanged
+         vertices[:, 2],   # y = z
+         -vertices[:, 1]   # z = -y
+     ], dim=1)
+
+     # 2. Then rotate 90 degrees around Y-axis to make the model face the observer
+     vertices = torch.stack([
+         -vertices[:, 2],  # x = -z
+         vertices[:, 1],   # y remains unchanged
+         vertices[:, 0]    # z = x
+     ], dim=1)
+
+     mesh.v_pos = vertices
+
+     # If the mesh has normals, they need to be rotated in the same way
+     if mesh.v_nrm is not None:
+         normals = mesh.v_nrm
+         # 1. Rotate -90 degrees around X-axis
+         normals = torch.stack([
+             normals[:, 0],
+             normals[:, 2],
+             -normals[:, 1]
+         ], dim=1)
+         # 2. Rotate 90 degrees around Y-axis
+         normals = torch.stack([
+             -normals[:, 2],
+             normals[:, 1],
+             normals[:, 0]
+         ], dim=1)
+         mesh._v_nrm = normals
+
+     # Save the obj file and add its path to the queue
+     name = f"{prompt.replace(' ', '_')}_{seed}_{i}"
+     save_paths = export_obj(mesh, f"{output_dir}/{name}.obj")
+     obj_file_queue.append(save_paths[0])
+
+     # If an old file needs to be removed (queue is at max length)
+     # and the file exists, delete it
+     if len(obj_file_queue) == max_obj_files and os.path.exists(obj_file_queue[0]):
+         old_file = obj_file_queue[0]
+         try:
+             os.remove(old_file)
+         except OSError as e:
+             print(f"Error deleting file {old_file}: {e}")
+
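For context on the clean-up logic above: `collections.deque(maxlen=...)` silently drops its oldest entry when a new path is appended to a full queue, which is why the script also deletes old .obj files from disk explicitly rather than relying on the queue alone. A standalone toy sketch of that standard-library behavior (illustrative only, not part of this commit):

from collections import deque

q = deque(maxlen=3)
for path in ["a.obj", "b.obj", "c.obj", "d.obj"]:
    if len(q) == q.maxlen:
        evicted = q[0]  # oldest path, about to fall out of the queue on the next append
        print(f"would delete {evicted} from disk here")
    q.append(path)

print(list(q))  # ['b.obj', 'c.obj', 'd.obj']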
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ omegaconf==2.3.0
+ jaxtyping
+ typeguard
+ diffusers==0.25
+ transformers==4.28.1
+ accelerate
+ imageio>=2.28.0
+ imageio[ffmpeg]
+ git+https://github.com/NVlabs/nvdiffrast.git
+ libigl
+ trimesh[easy]
+ networkx
+ pysdf
+ PyMCubes
+ wandb
+ torchmetrics
+ huggingface_hub==0.24.7
+ numpy==1.26.4
+ gradio==2.9.4
+
+ # # 3d gaussian
+ # plyfile
+
+ # diffmc
+ diso
+ einops
triplaneturbo_executable/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .pipelines.triplaneturbo_text_to_3d import (
+     TriplaneTurboTextTo3DPipeline,
+     TriplaneTurboTextTo3DPipelineConfig
+ )
+
+ __all__ = [
+     "TriplaneTurboTextTo3DPipeline",
+     "TriplaneTurboTextTo3DPipelineConfig"
+ ]
triplaneturbo_executable/extern/sd_dual_triplane_modules.py ADDED
@@ -0,0 +1,981 @@
1
+ import re
2
+ import torch
3
+ import torch.nn as nn
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Union, Tuple
6
+
7
+ from diffusers.models.attention_processor import Attention
8
+ from diffusers import (
9
+ DDPMScheduler,
10
+ UNet2DConditionModel,
11
+ AutoencoderKL
12
+ )
13
+ from diffusers.loaders import AttnProcsLayers
14
+
15
+
16
+ class LoRALinearLayerwBias(nn.Module):
17
+ r"""
18
+ A linear layer that is used with LoRA; it can optionally include a bias term.
19
+
20
+ Parameters:
21
+ in_features (`int`):
22
+ Number of input features.
23
+ out_features (`int`):
24
+ Number of output features.
25
+ rank (`int`, `optional`, defaults to 4):
26
+ The rank of the LoRA layer.
27
+ network_alpha (`float`, `optional`, defaults to `None`):
28
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
29
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
30
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
31
+ device (`torch.device`, `optional`, defaults to `None`):
32
+ The device to use for the layer's weights.
33
+ dtype (`torch.dtype`, `optional`, defaults to `None`):
34
+ The dtype to use for the layer's weights.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_features: int,
40
+ out_features: int,
41
+ rank: int = 4,
42
+ network_alpha: Optional[float] = None,
43
+ device: Optional[Union[torch.device, str]] = None,
44
+ dtype: Optional[torch.dtype] = None,
45
+ with_bias: bool = False
46
+ ):
47
+ super().__init__()
48
+
49
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
50
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
51
+ if with_bias:
52
+ self.bias = nn.Parameter(torch.zeros([1, 1, out_features], device=device, dtype=dtype))
53
+ self.with_bias = with_bias
54
+
55
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
56
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
57
+ self.network_alpha = network_alpha
58
+ self.rank = rank
59
+ self.out_features = out_features
60
+ self.in_features = in_features
61
+
62
+ nn.init.normal_(self.down.weight, std=1 / rank)
63
+ nn.init.zeros_(self.up.weight)
64
+
65
+
66
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
67
+ orig_dtype = hidden_states.dtype
68
+ dtype = self.down.weight.dtype
69
+
70
+ down_hidden_states = self.down(hidden_states.to(dtype))
71
+ up_hidden_states = self.up(down_hidden_states)
72
+ if self.with_bias:
73
+ up_hidden_states = up_hidden_states + self.bias
74
+
75
+ if self.network_alpha is not None:
76
+ up_hidden_states *= self.network_alpha / self.rank
77
+
78
+ return up_hidden_states.to(orig_dtype)
79
+
80
+ class TriplaneLoRAConv2dLayer(nn.Module):
81
+ r"""
82
+ A convolutional layer that is used with LoRA.
83
+
84
+ Parameters:
85
+ in_features (`int`):
86
+ Number of input features.
87
+ out_features (`int`):
88
+ Number of output features.
89
+ rank (`int`, `optional`, defaults to 4):
90
+ The rank of the LoRA layer.
91
+ kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
92
+ The kernel size of the convolution.
93
+ stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
94
+ The stride of the convolution.
95
+ padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
96
+ The padding of the convolution.
97
+ network_alpha (`float`, `optional`, defaults to `None`):
98
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
99
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
100
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ in_features: int,
106
+ out_features: int,
107
+ rank: int = 4,
108
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
109
+ stride: Union[int, Tuple[int, int]] = (1, 1),
110
+ padding: Union[int, Tuple[int, int], str] = 0,
111
+ network_alpha: Optional[float] = None,
112
+ with_bias: bool = False,
113
+ locon_type: str = "hexa_v1", #hexa_v2, vanilla_v1, vanilla_v2
114
+ ):
115
+ super().__init__()
116
+
117
+ assert locon_type in ["hexa_v1", "hexa_v2", "vanilla_v1", "vanilla_v2"], "The LoCON type is not supported."
118
+ if locon_type == "hexa_v1":
119
+ self.down_xy_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
120
+ self.down_xz_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
121
+ self.down_yz_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
122
+ self.down_xy_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
123
+ self.down_xz_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
124
+ self.down_yz_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
125
+ # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
126
+ # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
127
+ self.up_xy_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
128
+ self.up_xz_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
129
+ self.up_yz_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
130
+ self.up_xy_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
131
+ self.up_xz_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
132
+ self.up_yz_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
133
+
134
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
135
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
136
+
137
+ elif locon_type == "hexa_v2":
138
+ self.down_xy_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
139
+ self.down_xz_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
140
+ self.down_yz_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
141
+ self.down_xy_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
142
+ self.down_xz_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
143
+ self.down_yz_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False)
144
+
145
+ self.up_xy_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
146
+ self.up_xz_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
147
+ self.up_yz_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
148
+ self.up_xy_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
149
+ self.up_xz_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
150
+ self.up_yz_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
151
+
152
+ elif locon_type == "vanilla_v1":
153
+ self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
154
+ self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias)
155
+
156
+ elif locon_type == "vanilla_v2":
157
+ self.down = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1), padding=padding, bias=False)
158
+ self.up = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias)
159
+
160
+ self.network_alpha = network_alpha
161
+ self.rank = rank
162
+ self.locon_type = locon_type
163
+ self._init_weights()
164
+
165
+ def _init_weights(self):
166
+ for layer in [
167
+ "down_xy_geo", "down_xz_geo", "down_yz_geo", "down_xy_tex", "down_xz_tex", "down_yz_tex", # in case of hexa_vX
168
+ "up_xy", "up_xz", "up_yz", "up_xy_tex", "up_xz_tex", "up_yz_tex", # in case of hexa_vX
169
+ "down", "up" # in case of vanilla
170
+ ]:
171
+ if hasattr(self, layer):
172
+ # initialize the weights
173
+ if "down" in layer:
174
+ nn.init.normal_(getattr(self, layer).weight, std=1 / self.rank)
175
+ elif "up" in layer:
176
+ nn.init.zeros_(getattr(self, layer).weight)
177
+ # initialize the bias
178
+ if getattr(self, layer).bias is not None:
179
+ nn.init.zeros_(getattr(self, layer).bias)
180
+
181
+
182
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
183
+ orig_dtype = hidden_states.dtype
184
+ dtype = self.down_xy_geo.weight.dtype if "hexa" in self.locon_type else self.down.weight.dtype
185
+
186
+ if "hexa" in self.locon_type:
187
+ # xy plane
188
+ hidden_states_xy_geo = self.up_xy_geo(self.down_xy_geo(hidden_states[0::6].to(dtype)))
189
+ hidden_states_xy_tex = self.up_xy_tex(self.down_xy_tex(hidden_states[3::6].to(dtype)))
190
+
191
+ lora_hidden_states = torch.concat(
192
+ [torch.zeros_like(hidden_states_xy_tex)] * 6,
193
+ dim=0
194
+ )
195
+
196
+ lora_hidden_states[0::6] = hidden_states_xy_geo
197
+ lora_hidden_states[3::6] = hidden_states_xy_tex
198
+
199
+ # xz plane
200
+ lora_hidden_states[1::6] = self.up_xz_geo(self.down_xz_geo(hidden_states[1::6].to(dtype)))
201
+ lora_hidden_states[4::6] = self.up_xz_tex(self.down_xz_tex(hidden_states[4::6].to(dtype)))
202
+ # yz plane
203
+ lora_hidden_states[2::6] = self.up_yz_geo(self.down_yz_geo(hidden_states[2::6].to(dtype)))
204
+ lora_hidden_states[5::6] = self.up_yz_tex(self.down_yz_tex(hidden_states[5::6].to(dtype)))
205
+
206
+ elif "vanilla" in self.locon_type:
207
+ lora_hidden_states = self.up(self.down(hidden_states.to(dtype)))
208
+
209
+ if self.network_alpha is not None:
210
+ lora_hidden_states *= self.network_alpha / self.rank
211
+
212
+ return lora_hidden_states.to(orig_dtype)
213
+
214
+ class TriplaneSelfAttentionLoRAAttnProcessor(nn.Module):
215
+ """
216
+ Attention processor implementing Triplane self-attention with LoRA.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ hidden_size: int,
222
+ rank: int = 4,
223
+ network_alpha: Optional[float] = None,
224
+ with_bias: bool = False,
225
+ lora_type: str = "hexa_v1", # vanilla,
226
+ ):
227
+ super().__init__()
228
+
229
+ assert lora_type in ["hexa_v1", "vanilla", "none", "basic"], "The LoRA type is not supported."
230
+
231
+ self.hidden_size = hidden_size
232
+ self.rank = rank
233
+ self.lora_type = lora_type
234
+
235
+ if lora_type in ["hexa_v1"]:
236
+ # lora for 1st plane geometry
237
+ self.to_q_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
238
+ self.to_k_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
239
+ self.to_v_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
240
+ self.to_out_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
241
+
242
+ # lora for 1st plane texture
243
+ self.to_q_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
244
+ self.to_k_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
245
+ self.to_v_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
246
+ self.to_out_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
247
+
248
+ # lora for 2nd plane geometry
249
+ self.to_q_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
250
+ self.to_k_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
251
+ self.to_v_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
252
+ self.to_out_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
253
+
254
+ # lora for 2nd plane texture
255
+ self.to_q_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
256
+ self.to_k_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
257
+ self.to_v_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
258
+ self.to_out_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
259
+
260
+ # lora for 3rd plane geometry
261
+ self.to_q_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
262
+ self.to_k_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
263
+ self.to_v_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
264
+ self.to_out_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
265
+
266
+ # lora for 3rd plane texture
267
+ self.to_q_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
268
+ self.to_k_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
269
+ self.to_v_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
270
+ self.to_out_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
271
+
272
+ elif lora_type in ["vanilla", "basic"]:
273
+ self.to_q_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
274
+ self.to_k_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
275
+ self.to_v_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
276
+ self.to_out_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
277
+
278
+ def __call__(
279
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
280
+ ):
281
+ assert encoder_hidden_states is None, "The encoder_hidden_states should be None."
282
+
283
+ residual = hidden_states
284
+
285
+ if attn.spatial_norm is not None:
286
+ hidden_states = attn.spatial_norm(hidden_states, temb)
287
+
288
+ input_ndim = hidden_states.ndim
289
+
290
+ if input_ndim == 4:
291
+ batch_size, channel, height, width = hidden_states.shape
292
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
293
+
294
+ batch_size, sequence_length, _ = (
295
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
296
+ )
297
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
298
+
299
+ if attn.group_norm is not None:
300
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
301
+
302
+
303
+ ############################################################################################################
304
+ # query
305
+ if self.lora_type in ["hexa_v1",]:
306
+ query = attn.to_q(hidden_states)
307
+ _query_new = torch.zeros_like(query)
308
+ # lora for xy plane geometry
309
+ _query_new[0::6] = self.to_q_xy_lora_geo(hidden_states[0::6])
310
+ # lora for xy plane texture
311
+ _query_new[3::6] = self.to_q_xy_lora_tex(hidden_states[3::6])
312
+ # lora for xz plane geometry
313
+ _query_new[1::6] = self.to_q_xz_lora_geo(hidden_states[1::6])
314
+ # lora for xz plane texture
315
+ _query_new[4::6] = self.to_q_xz_lora_tex(hidden_states[4::6])
316
+ # lora for yz plane geometry
317
+ _query_new[2::6] = self.to_q_yz_lora_geo(hidden_states[2::6])
318
+ # lora for yz plane texture
319
+ _query_new[5::6] = self.to_q_yz_lora_tex(hidden_states[5::6])
320
+ query = query + scale * _query_new
321
+
322
+ # # speed up inference
323
+ # query[0::6] += self.to_q_xy_lora_geo(hidden_states[0::6]) * scale
324
+ # query[3::6] += self.to_q_xy_lora_tex(hidden_states[3::6]) * scale
325
+ # query[1::6] += self.to_q_xz_lora_geo(hidden_states[1::6]) * scale
326
+ # query[4::6] += self.to_q_xz_lora_tex(hidden_states[4::6]) * scale
327
+ # query[2::6] += self.to_q_yz_lora_geo(hidden_states[2::6]) * scale
328
+ # query[5::6] += self.to_q_yz_lora_tex(hidden_states[5::6]) * scale
329
+
330
+ elif self.lora_type in ["vanilla", "basic"]:
331
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
332
+ elif self.lora_type in ["none"]:
333
+ query = attn.to_q(hidden_states)
334
+ else:
335
+ raise NotImplementedError("The LoRA type is not supported for the query in HplaneSelfAttentionLoRAAttnProcessor.")
336
+
337
+ ############################################################################################################
338
+
339
+ if encoder_hidden_states is None:
340
+ encoder_hidden_states = hidden_states
341
+ elif attn.norm_cross:
342
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
343
+
344
+ ############################################################################################################
345
+ # key and value
346
+ if self.lora_type in ["hexa_v1",]:
347
+ key = attn.to_k(encoder_hidden_states)
348
+ _key_new = torch.zeros_like(key)
349
+ # lora for xy plane geometry
350
+ _key_new[0::6] = self.to_k_xy_lora_geo(encoder_hidden_states[0::6])
351
+ # lora for xy plane texture
352
+ _key_new[3::6] = self.to_k_xy_lora_tex(encoder_hidden_states[3::6])
353
+ # lora for xz plane geometry
354
+ _key_new[1::6] = self.to_k_xz_lora_geo(encoder_hidden_states[1::6])
355
+ # lora for xz plane texture
356
+ _key_new[4::6] = self.to_k_xz_lora_tex(encoder_hidden_states[4::6])
357
+ # lora for yz plane geometry
358
+ _key_new[2::6] = self.to_k_yz_lora_geo(encoder_hidden_states[2::6])
359
+ # lora for yz plane texture
360
+ _key_new[5::6] = self.to_k_yz_lora_tex(encoder_hidden_states[5::6])
361
+ key = key + scale * _key_new
362
+
363
+ # # speed up inference
364
+ # key[0::6] += self.to_k_xy_lora_geo(encoder_hidden_states[0::6]) * scale
365
+ # key[3::6] += self.to_k_xy_lora_tex(encoder_hidden_states[3::6]) * scale
366
+ # key[1::6] += self.to_k_xz_lora_geo(encoder_hidden_states[1::6]) * scale
367
+ # key[4::6] += self.to_k_xz_lora_tex(encoder_hidden_states[4::6]) * scale
368
+ # key[2::6] += self.to_k_yz_lora_geo(encoder_hidden_states[2::6]) * scale
369
+ # key[5::6] += self.to_k_yz_lora_tex(encoder_hidden_states[5::6]) * scale
370
+
371
+ value = attn.to_v(encoder_hidden_states)
372
+ _value_new = torch.zeros_like(value)
373
+ # lora for xy plane geometry
374
+ _value_new[0::6] = self.to_v_xy_lora_geo(encoder_hidden_states[0::6])
375
+ # lora for xy plane texture
376
+ _value_new[3::6] = self.to_v_xy_lora_tex(encoder_hidden_states[3::6])
377
+ # lora for xz plane geometry
378
+ _value_new[1::6] = self.to_v_xz_lora_geo(encoder_hidden_states[1::6])
379
+ # lora for xz plane texture
380
+ _value_new[4::6] = self.to_v_xz_lora_tex(encoder_hidden_states[4::6])
381
+ # lora for yz plane geometry
382
+ _value_new[2::6] = self.to_v_yz_lora_geo(encoder_hidden_states[2::6])
383
+ # lora for yz plane texture
384
+ _value_new[5::6] = self.to_v_yz_lora_tex(encoder_hidden_states[5::6])
385
+ value = value + scale * _value_new
386
+
387
+ # # speed up inference
388
+ # value[0::6] += self.to_v_xy_lora_geo(encoder_hidden_states[0::6]) * scale
389
+ # value[3::6] += self.to_v_xy_lora_tex(encoder_hidden_states[3::6]) * scale
390
+ # value[1::6] += self.to_v_xz_lora_geo(encoder_hidden_states[1::6]) * scale
391
+ # value[4::6] += self.to_v_xz_lora_tex(encoder_hidden_states[4::6]) * scale
392
+ # value[2::6] += self.to_v_yz_lora_geo(encoder_hidden_states[2::6]) * scale
393
+ # value[5::6] += self.to_v_yz_lora_tex(encoder_hidden_states[5::6]) * scale
394
+
395
+ elif self.lora_type in ["vanilla", "basic"]:
396
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
397
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
398
+
399
+ elif self.lora_type in ["none", ]:
400
+ key = attn.to_k(encoder_hidden_states)
401
+ value = attn.to_v(encoder_hidden_states)
402
+
403
+ else:
404
+ raise NotImplementedError("The LoRA type is not supported for the key and value in HplaneSelfAttentionLoRAAttnProcessor.")
405
+
406
+ ############################################################################################################
407
+ # attention scores
408
+
409
+ # in self-attention, query of each plane should be used to calculate the attention scores of all planes
410
+ if self.lora_type in ["hexa_v1", "vanilla",]:
411
+ query = attn.head_to_batch_dim(
412
+ query.view(batch_size // 6, sequence_length * 6, self.hidden_size)
413
+ )
414
+ key = attn.head_to_batch_dim(
415
+ key.view(batch_size // 6, sequence_length * 6, self.hidden_size)
416
+ )
417
+ value = attn.head_to_batch_dim(
418
+ value.view(batch_size // 6, sequence_length * 6, self.hidden_size)
419
+ )
420
+ # calculate the attention scores
421
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
422
+ hidden_states = torch.bmm(attention_probs, value)
423
+ hidden_states = attn.batch_to_head_dim(hidden_states)
424
+ # split the hidden states into 6 planes
425
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.hidden_size)
426
+ elif self.lora_type in ["none", "basic"]:
427
+ query = attn.head_to_batch_dim(query)
428
+ key = attn.head_to_batch_dim(key)
429
+ value = attn.head_to_batch_dim(value)
430
+ # calculate the attention scores
431
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
432
+ hidden_states = torch.bmm(attention_probs, value)
433
+ hidden_states = attn.batch_to_head_dim(hidden_states)
434
+ else:
435
+ raise NotImplementedError("The LoRA type is not supported for attention scores calculation in HplaneSelfAttentionLoRAAttnProcessor.")
436
+
437
+ ############################################################################################################
438
+ # linear proj
439
+ if self.lora_type in ["hexa_v1", ]:
440
+ hidden_states = attn.to_out[0](hidden_states)
441
+ _hidden_states_new = torch.zeros_like(hidden_states)
442
+ # lora for xy plane geometry
443
+ _hidden_states_new[0::6] = self.to_out_xy_lora_geo(hidden_states[0::6])
444
+ # lora for xy plane texture
445
+ _hidden_states_new[3::6] = self.to_out_xy_lora_tex(hidden_states[3::6])
446
+ # lora for xz plane geometry
447
+ _hidden_states_new[1::6] = self.to_out_xz_lora_geo(hidden_states[1::6])
448
+ # lora for xz plane texture
449
+ _hidden_states_new[4::6] = self.to_out_xz_lora_tex(hidden_states[4::6])
450
+ # lora for yz plane geometry
451
+ _hidden_states_new[2::6] = self.to_out_yz_lora_geo(hidden_states[2::6])
452
+ # lora for yz plane texture
453
+ _hidden_states_new[5::6] = self.to_out_yz_lora_tex(hidden_states[5::6])
454
+ hidden_states = hidden_states + scale * _hidden_states_new
455
+
456
+ # # speed up inference
457
+ # hidden_states[0::6] += self.to_out_xy_lora_geo(hidden_states[0::6]) * scale
458
+ # hidden_states[3::6] += self.to_out_xy_lora_tex(hidden_states[3::6]) * scale
459
+ # hidden_states[1::6] += self.to_out_xz_lora_geo(hidden_states[1::6]) * scale
460
+ # hidden_states[4::6] += self.to_out_xz_lora_tex(hidden_states[4::6]) * scale
461
+ # hidden_states[2::6] += self.to_out_yz_lora_geo(hidden_states[2::6]) * scale
462
+ # hidden_states[5::6] += self.to_out_yz_lora_tex(hidden_states[5::6]) * scale
463
+
464
+ elif self.lora_type in ["vanilla", "basic"]:
465
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
466
+ elif self.lora_type in ["none",]:
467
+ hidden_states = attn.to_out[0](hidden_states)
468
+ else:
469
+ raise NotImplementedError("The LoRA type is not supported for the to_out layer in HplaneSelfAttentionLoRAAttnProcessor.")
470
+
471
+ # dropout
472
+ hidden_states = attn.to_out[1](hidden_states)
473
+ ############################################################################################################
474
+
475
+ if input_ndim == 4:
476
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
477
+
478
+ if attn.residual_connection:
479
+ hidden_states = hidden_states + residual
480
+
481
+ hidden_states = hidden_states / attn.rescale_output_factor
482
+
483
+ return hidden_states
484
+
485
+ class TriplaneCrossAttentionLoRAAttnProcessor(nn.Module):
486
+ """
487
+ Attention processor implementing Triplane cross-attention with LoRA.
488
+ """
489
+
490
+ def __init__(
491
+ self,
492
+ hidden_size: int,
493
+ cross_attention_dim: int,
494
+ rank: int = 4,
495
+ network_alpha: Optional[float] = None,
496
+ with_bias: bool = False,
497
+ lora_type: str = "hexa_v1", # vanilla,
498
+ ):
499
+ super().__init__()
500
+
501
+ assert lora_type in ["hexa_v1", "vanilla", "none"], "The LoRA type is not supported."
502
+
503
+ self.hidden_size = hidden_size
504
+ self.rank = rank
505
+ self.lora_type = lora_type
506
+
507
+ if lora_type in ["hexa_v1"]:
508
+ # lora for 1st plane geometry
509
+ self.to_q_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
510
+ self.to_k_xy_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
511
+ self.to_v_xy_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
512
+ self.to_out_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
513
+
514
+ # lora for 1st plane texture
515
+ self.to_q_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
516
+ self.to_k_xy_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
517
+ self.to_v_xy_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
518
+ self.to_out_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
519
+
520
+ # lora for 2nd plane geometry
521
+ self.to_q_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
522
+ self.to_k_xz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
523
+ self.to_v_xz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
524
+ self.to_out_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
525
+
526
+ # lora for 2nd plane texture
527
+ self.to_q_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
528
+ self.to_k_xz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
529
+ self.to_v_xz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
530
+ self.to_out_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
531
+
532
+ # lora for 3rd plane geometry
533
+ self.to_q_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
534
+ self.to_k_yz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
535
+ self.to_v_yz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
536
+ self.to_out_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
537
+
538
+ # lora for 3rd plane texture
539
+ self.to_q_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
540
+ self.to_k_yz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
541
+ self.to_v_yz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
542
+ self.to_out_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
543
+
544
+ elif lora_type in ["vanilla"]:
545
+ # lora for all planes
546
+ self.to_q_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
547
+ self.to_k_lora = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
548
+ self.to_v_lora = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias)
549
+ self.to_out_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias)
550
+
551
+ def __call__(
552
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
553
+ ):
554
+ assert encoder_hidden_states is not None, "The encoder_hidden_states should not be None."
555
+
556
+ residual = hidden_states
557
+
558
+ if attn.spatial_norm is not None:
559
+ hidden_states = attn.spatial_norm(hidden_states, temb)
560
+
561
+ input_ndim = hidden_states.ndim
562
+
563
+ if input_ndim == 4:
564
+ batch_size, channel, height, width = hidden_states.shape
565
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
566
+
567
+ batch_size, sequence_length, _ = (
568
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
569
+ )
570
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
571
+
572
+ if attn.group_norm is not None:
573
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
574
+
575
+ ############################################################################################################
576
+ # query
577
+ if self.lora_type in ["hexa_v1",]:
578
+ query = attn.to_q(hidden_states)
579
+ _query_new = torch.zeros_like(query)
580
+ # lora for xy plane geometry
581
+ _query_new[0::6] = self.to_q_xy_lora_geo(hidden_states[0::6])
582
+ # lora for xy plane texture
583
+ _query_new[3::6] = self.to_q_xy_lora_tex(hidden_states[3::6])
584
+ # lora for xz plane geometry
585
+ _query_new[1::6] = self.to_q_xz_lora_geo(hidden_states[1::6])
586
+ # lora for xz plane texture
587
+ _query_new[4::6] = self.to_q_xz_lora_tex(hidden_states[4::6])
588
+ # lora for yz plane geometry
589
+ _query_new[2::6] = self.to_q_yz_lora_geo(hidden_states[2::6])
590
+ # lora for yz plane texture
591
+ _query_new[5::6] = self.to_q_yz_lora_tex(hidden_states[5::6])
592
+ query = query + scale * _query_new
593
+
594
+ elif self.lora_type == "vanilla":
595
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
596
+
597
+ elif self.lora_type == "none":
598
+ query = attn.to_q(hidden_states)
599
+
600
+ query = attn.head_to_batch_dim(query)
601
+ ############################################################################################################
602
+
603
+ if encoder_hidden_states is None:
604
+ encoder_hidden_states = hidden_states
605
+ elif attn.norm_cross:
606
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
607
+
608
+ ############################################################################################################
609
+ # key and value
610
+ if self.lora_type in ["hexa_v1",]:
611
+ key = attn.to_k(encoder_hidden_states)
612
+ _key_new = torch.zeros_like(key)
613
+ # lora for xy plane geometry
614
+ _key_new[0::6] = self.to_k_xy_lora_geo(encoder_hidden_states[0::6])
615
+ # lora for xy plane texture
616
+ _key_new[3::6] = self.to_k_xy_lora_tex(encoder_hidden_states[3::6])
617
+ # lora for xz plane geometry
618
+ _key_new[1::6] = self.to_k_xz_lora_geo(encoder_hidden_states[1::6])
619
+ # lora for xz plane texture
620
+ _key_new[4::6] = self.to_k_xz_lora_tex(encoder_hidden_states[4::6])
621
+ # lora for yz plane geometry
622
+ _key_new[2::6] = self.to_k_yz_lora_geo(encoder_hidden_states[2::6])
623
+ # lora for yz plane texture
624
+ _key_new[5::6] = self.to_k_yz_lora_tex(encoder_hidden_states[5::6])
625
+ key = key + scale * _key_new
626
+
627
+ value = attn.to_v(encoder_hidden_states)
628
+ _value_new = torch.zeros_like(value)
629
+ # lora for xy plane geometry
630
+ _value_new[0::6] = self.to_v_xy_lora_geo(encoder_hidden_states[0::6])
631
+ # lora for xy plane texture
632
+ _value_new[3::6] = self.to_v_xy_lora_tex(encoder_hidden_states[3::6])
633
+ # lora for xz plane geometry
634
+ _value_new[1::6] = self.to_v_xz_lora_geo(encoder_hidden_states[1::6])
635
+ # lora for xz plane texture
636
+ _value_new[4::6] = self.to_v_xz_lora_tex(encoder_hidden_states[4::6])
637
+ # lora for yz plane geometry
638
+ _value_new[2::6] = self.to_v_yz_lora_geo(encoder_hidden_states[2::6])
639
+ # lora for yz plane texture
640
+ _value_new[5::6] = self.to_v_yz_lora_tex(encoder_hidden_states[5::6])
641
+ value = value + scale * _value_new
642
+
643
+ elif self.lora_type in ["vanilla",]:
644
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
645
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
646
+
647
+ elif self.lora_type in ["none",]:
648
+ key = attn.to_k(encoder_hidden_states)
649
+ value = attn.to_v(encoder_hidden_states)
650
+
651
+ key = attn.head_to_batch_dim(key)
652
+ value = attn.head_to_batch_dim(value)
653
+ ############################################################################################################
654
+
655
+ # calculate the attention scores
656
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
657
+ hidden_states = torch.bmm(attention_probs, value)
658
+ hidden_states = attn.batch_to_head_dim(hidden_states)
659
+
660
+
661
+ ############################################################################################################
662
+ # linear proj
663
+ if self.lora_type in ["hexa_v1", ]:
664
+ hidden_states = attn.to_out[0](hidden_states)
665
+ _hidden_states_new = torch.zeros_like(hidden_states)
666
+ # lora for xy plane geometry
667
+ _hidden_states_new[0::6] = self.to_out_xy_lora_geo(hidden_states[0::6])
668
+ # lora for xy plane texture
669
+ _hidden_states_new[3::6] = self.to_out_xy_lora_tex(hidden_states[3::6])
670
+ # lora for xz plane geometry
671
+ _hidden_states_new[1::6] = self.to_out_xz_lora_geo(hidden_states[1::6])
672
+ # lora for xz plane texture
673
+ _hidden_states_new[4::6] = self.to_out_xz_lora_tex(hidden_states[4::6])
674
+ # lora for yz plane geometry
675
+ _hidden_states_new[2::6] = self.to_out_yz_lora_geo(hidden_states[2::6])
676
+ # lora for yz plane texture
677
+ _hidden_states_new[5::6] = self.to_out_yz_lora_tex(hidden_states[5::6])
678
+ hidden_states = hidden_states + scale * _hidden_states_new
679
+ elif self.lora_type in ["vanilla",]:
680
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
681
+ elif self.lora_type in ["none",]:
682
+ hidden_states = attn.to_out[0](hidden_states)
683
+ else:
684
+ raise NotImplementedError("The LoRA type is not supported for the to_out layer in HplaneCrossAttentionLoRAAttnProcessor.")
685
+
686
+ # dropout
687
+ hidden_states = attn.to_out[1](hidden_states)
688
+ ############################################################################################################
689
+
690
+ if input_ndim == 4:
691
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
692
+
693
+ if attn.residual_connection:
694
+ hidden_states = hidden_states + residual
695
+
696
+ hidden_states = hidden_states / attn.rescale_output_factor
697
+
698
+ return hidden_states
699
+
700
+ @dataclass
701
+ class GeneratorConfig:
702
+ training_type: str = "self_lora_rank_16-cross_lora_rank_16-locon_rank_16"
703
+ output_dim: int = 32
704
+ self_lora_type: str = "hexa_v1"
705
+ cross_lora_type: str = "hexa_v1"
706
+ locon_type: str = "vanilla_v1"
707
+ vae_attn_type: str = "basic"
708
+ prompt_bias: bool = False
709
+
710
+ class OneStepTriplaneDualStableDiffusion(nn.Module):
711
+ """
712
+ One-step Triplane Stable Diffusion module.
713
+ """
714
+ def __init__(
715
+ self,
716
+ config: Union[dict, GeneratorConfig],
717
+ vae: AutoencoderKL,
718
+ unet: UNet2DConditionModel,
719
+ ):
720
+ super().__init__()
721
+ # Convert dict to GeneratorConfig if needed
722
+ self.cfg = GeneratorConfig(**config) if isinstance(config, dict) else config
723
+ self.output_dim = self.cfg.output_dim
724
+
725
+ # Load models
726
+ self.unet = unet
727
+ self.vae = vae
728
+
729
+ # Get device from one of the models
730
+ self.device = next(self.unet.parameters()).device
731
+
732
+ # Remove unused components
733
+ del vae.encoder
734
+ del vae.quant_conv
735
+
736
+ # Get training type from config
737
+ training_type = self.cfg.training_type
738
+
739
+ # save trainable parameters
740
+ if not "full" in training_type: # then parameter-efficient training
741
+
742
+ trainable_params = {}
743
+
744
+ assert "lora" in training_type or "locon" in training_type, "The training type is not supported."
745
+ @dataclass
746
+ class SubModules:
747
+ unet: UNet2DConditionModel
748
+ vae: AutoencoderKL
749
+
750
+ self.submodules = SubModules(
751
+ unet=unet.to(self.device),
752
+ vae=vae.to(self.device),
753
+ )
754
+
755
+ # freeze all the parameters
756
+ for param in self.unet.parameters():
757
+ param.requires_grad_(False)
758
+ for param in self.vae.parameters():
759
+ param.requires_grad_(False)
760
+
761
+ ############################################################
762
+ # overwrite the unet and vae with the customized processors
763
+
764
+ if "lora" in training_type:
765
+
766
+ # parse the rank from the training type, with the template "lora_rank_{}"
767
+ assert "self_lora_rank" in training_type, "The self_lora_rank is not specified."
768
+ rank = re.search(r"self_lora_rank_(\d+)", training_type).group(1)
769
+ self.self_lora_rank = int(rank)
770
+
771
+ assert "cross_lora_rank" in training_type, "The cross_lora_rank is not specified."
772
+ rank = re.search(r"cross_lora_rank_(\d+)", training_type).group(1)
773
+ self.cross_lora_rank = int(rank)
774
+
775
+ # if the finetuning is with bias
776
+ self.w_lora_bias = False
777
+ if "with_bias" in training_type:
778
+ self.w_lora_bias = True
779
+
780
+ # specify the attn_processor for unet
781
+ lora_attn_procs = self._set_attn_processor(
782
+ self.unet,
783
+ self_attn_name="attn1.processor",
784
+ self_lora_type=self.cfg.self_lora_type,
785
+ cross_lora_type=self.cfg.cross_lora_type
786
+ )
787
+ self.unet.set_attn_processor(lora_attn_procs)
788
+ # update the trainable parameters
789
+ trainable_params.update(self.unet.attn_processors)
790
+
791
+ # specify the attn_processor for vae
792
+ lora_attn_procs = self._set_attn_processor(
793
+ self.vae,
794
+ self_attn_name="processor",
795
+ self_lora_type=self.cfg.vae_attn_type, # hard-coded for vae
796
+ cross_lora_type="vanilla"
797
+ )
798
+ self.vae.set_attn_processor(lora_attn_procs)
799
+ # update the trainable parameters
800
+ trainable_params.update(self.vae.attn_processors)
801
+ else:
802
+ raise NotImplementedError("The training type is not supported.")
803
+
804
+ if "locon" in training_type:
805
+ # parse the rank from the training type, with the template "locon_rank_{}"
806
+ rank = re.search(r"locon_rank_(\d+)", training_type).group(1)
807
+ self.locon_rank = int(rank)
808
+
809
+ # if the finetuning is with bias
810
+ self.w_locon_bias = False
811
+ if "with_bias" in training_type:
812
+ self.w_locon_bias = True
813
+
814
+ # specify the conv_processor for unet
815
+ locon_procs = self._set_conv_processor(
816
+ self.unet,
817
+ locon_type=self.cfg.locon_type
818
+ )
819
+
820
+ # update the trainable parameters
821
+ trainable_params.update(locon_procs)
822
+
823
+ # specify the conv_processor for vae
824
+ locon_procs = self._set_conv_processor(
825
+ self.vae,
826
+ locon_type="vanilla_v1", # hard-coded for vae decoder
827
+ )
828
+ # update the trainable parameters
829
+ trainable_params.update(locon_procs)
830
+ else:
831
+ raise NotImplementedError("The training type is not supported.")
832
+
833
+ # overwrite the outconv
834
+ # conv_out_orig = self.vae.decoder.conv_out
835
+ conv_out_new = nn.Conv2d(
836
+ in_channels=128, # conv_out_orig.in_channels, hard-coded
837
+ out_channels=self.cfg.output_dim, kernel_size=3, padding=1
838
+ )
839
+
840
+ # update the trainable parameters
841
+ self.vae.decoder.conv_out = conv_out_new
842
+ trainable_params["vae.decoder.conv_out"] = conv_out_new
843
+
844
+ # save the trainable parameters
845
+ self.peft_layers = AttnProcsLayers(trainable_params).to(self.device)
846
+ self.peft_layers._load_state_dict_pre_hooks.clear()
847
+ self.peft_layers._state_dict_hooks.clear()
848
+
849
+ # hard-coded for now
850
+ self.num_planes = 6
851
+
852
+ if self.cfg.prompt_bias:
853
+ self.prompt_bias = nn.Parameter(torch.zeros(self.num_planes, 77, 1024))
854
+
855
+ @property
856
+ def unet(self):
857
+ return self.submodules.unet
858
+
859
+ @property
860
+ def vae(self):
861
+ return self.submodules.vae
862
+
863
+ def _set_conv_processor(
864
+ self,
865
+ module,
866
+ conv_name: str = "LoRACompatibleConv",
867
+ locon_type: str = "vanilla_v1",
868
+ ):
869
+ locon_procs = {}
870
+ for _name, _module in module.named_modules():
871
+ if _module.__class__.__name__ == conv_name:
872
+ # append the locon processor to the module
873
+ locon_proc = TriplaneLoRAConv2dLayer(
874
+ in_features=_module.in_channels,
875
+ out_features=_module.out_channels,
876
+ rank=self.locon_rank,
877
+ kernel_size=_module.kernel_size,
878
+ stride=_module.stride,
879
+ padding=_module.padding,
880
+ with_bias = self.w_locon_bias,
881
+ locon_type= locon_type,
882
+ )
883
+ # add the locon processor to the module
884
+ _module.lora_layer = locon_proc
885
+ # update the trainable parameters
886
+ key_name = f"{_name}.lora_layer"
887
+ locon_procs[key_name] = locon_proc
888
+ return locon_procs
889
+
890
+
891
+
892
+ def _set_attn_processor(
893
+ self,
894
+ module,
895
+ self_attn_name: str = "attn1.processor",
896
+ self_attn_procs = TriplaneSelfAttentionLoRAAttnProcessor,
897
+ self_lora_type: str = "hexa_v1",
898
+ cross_attn_procs = TriplaneCrossAttentionLoRAAttnProcessor,
899
+ cross_lora_type: str = "hexa_v1",
900
+ ):
901
+ lora_attn_procs = {}
902
+ for name in module.attn_processors.keys():
903
+
904
+ if name.startswith("mid_block"):
905
+ hidden_size = module.config.block_out_channels[-1]
906
+ elif name.startswith("up_blocks"):
907
+ block_id = int(name[len("up_blocks.")])
908
+ hidden_size = list(reversed(module.config.block_out_channels))[
909
+ block_id
910
+ ]
911
+ elif name.startswith("down_blocks"):
912
+ block_id = int(name[len("down_blocks.")])
913
+ hidden_size = module.config.block_out_channels[block_id]
914
+ elif name.startswith("decoder"):
915
+ # special case for decoder in SD
916
+ hidden_size = 512
917
+
918
+ if name.endswith(self_attn_name):
919
+ # it is self-attention
920
+ cross_attention_dim = None
921
+ lora_attn_procs[name] = self_attn_procs(
922
+ hidden_size, self.self_lora_rank, with_bias = self.w_lora_bias,
923
+ lora_type = self_lora_type
924
+ )
925
+ else:
926
+ # it is cross-attention
927
+ cross_attention_dim = module.config.cross_attention_dim
928
+ lora_attn_procs[name] = cross_attn_procs(
929
+ hidden_size, cross_attention_dim, self.cross_lora_rank, with_bias = self.w_lora_bias,
930
+ lora_type = cross_lora_type
931
+ )
932
+ return lora_attn_procs
933
+
934
+ def forward(
935
+ self,
936
+ text_embed,
937
+ styles,
938
+ ):
939
+ return None
940
+ def forward_denoise(
941
+ self,
942
+ text_embed,
943
+ noisy_input,
944
+ t,
945
+ ):
946
+
947
+ batch_size = text_embed.size(0)
948
+ noise_shape = noisy_input.size(-2)
949
+
950
+ if text_embed.ndim == 3:
951
+ # same text_embed for all planes
952
+ # text_embed = text_embed.repeat(self.num_planes, 1, 1) # wrong!!!
953
+ text_embed = text_embed.repeat_interleave(self.num_planes, dim=0)
954
+ elif text_embed.ndim == 4:
955
+ # different text_embed for each plane
956
+ text_embed = text_embed.view(batch_size * self.num_planes, *text_embed.shape[-2:])
957
+ else:
958
+ raise ValueError("The text_embed should be either 3D or 4D.")
959
+
960
+ if hasattr(self, "prompt_bias"):
961
+ text_embed = text_embed + self.prompt_bias.repeat(batch_size, 1, 1) * self.cfg.prompt_bias_lr_multiplier
962
+
963
+ noisy_input = noisy_input.view(-1, 4, noise_shape, noise_shape)
964
+ noise_pred = self.unet(
965
+ noisy_input,
966
+ t,
967
+ encoder_hidden_states=text_embed
968
+ ).sample
969
+
970
+
971
+ return noise_pred
972
+
973
+ def forward_decode(
974
+ self,
975
+ latents,
976
+ ):
977
+ latents = latents.view(-1, 4, *latents.shape[-2:])
978
+ triplane = self.vae.decode(latents).sample
979
+ triplane = triplane.view(-1, self.num_planes, self.cfg.output_dim, *triplane.shape[-2:])
980
+
981
+ return triplane
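
To make the tensor bookkeeping of forward_denoise and forward_decode explicit, here is a shape walk-through with illustrative sizes (a batch of 2 prompts, 6 planes, 32x32 SD latents; the 8x VAE upsampling factor is the standard Stable Diffusion value):

# Shape sketch only; the numbers are illustrative.
import torch

B, num_planes, output_dim = 2, 6, 32
text_embed = torch.randn(B, 77, 1024)                 # one prompt embedding per sample (3D case)
noisy_input = torch.randn(B, num_planes, 4, 32, 32)   # one SD latent per plane

# forward_denoise: text_embed -> (B*6, 77, 1024) via repeat_interleave,
#                  noisy_input -> (B*6, 4, 32, 32); the UNet returns noise_pred of the same shape.
# forward_decode:  latents (B*6, 4, 32, 32) -> VAE decoder -> (B*6, output_dim, 256, 256)
#                  -> regrouped as triplane (B, 6, output_dim, 256, 256).
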
triplaneturbo_executable/models/geometry/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sd_dual_triplanes import StableDiffusionTriplaneDualAttention, StableDiffusionTriplaneDualAttentionConfig
triplaneturbo_executable/models/geometry/sd_dual_triplanes.py ADDED
@@ -0,0 +1,394 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from jaxtyping import Float
10
+ from torch import Tensor
11
+ from typing import *
12
+
13
+ from ...utils.general_utils import contract_to_unisphere_custom, sample_from_planes
14
+ from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
15
+
16
+ from ..networks import get_mlp
17
+ from ...utils.general_utils import config_to_primitive
18
+ @dataclass
19
+ class StableDiffusionTriplaneDualAttentionConfig:
20
+ n_feature_dims: int = 3
21
+ space_generator_config: dict = field(
22
+ default_factory=lambda: {
23
+ "pretrained_model_name_or_path": "stable-diffusion-2-1-base",
24
+ "training_type": "self_lora_rank_16-cross_lora_rank_16-locon_rank_16",
25
+ "output_dim": 32,
26
+ "gradient_checkpoint": False,
27
+ "self_lora_type": "hexa_v1",
28
+ "cross_lora_type": "hexa_v1",
29
+ "locon_type": "vanilla_v1",
30
+
31
+ }
32
+ )
33
+
34
+ mlp_network_config: dict = field(
35
+ default_factory=lambda: {
36
+ "otype": "VanillaMLP",
37
+ "activation": "ReLU",
38
+ "output_activation": "none",
39
+ "n_neurons": 64,
40
+ "n_hidden_layers": 2,
41
+ }
42
+ )
43
+
44
+ backbone: str = "one_step_triplane_dual_stable_diffusion"
45
+ finite_difference_normal_eps: Union[
46
+ float, str
47
+ ] = 0.01 # in [float, "progressive"]
48
+ sdf_bias: Union[float, str] = 0.0
49
+ sdf_bias_params: Optional[Any] = None
50
+
51
+ isosurface_remove_outliers: bool = False
52
+ # rotate planes to match the conventional orientation of images generated by SD
53
+ # in right-handed coordinate system
54
+ # the xy plane should look like an image seen from a top-down / bottom-up view
55
+ # the xz plane should look like an image seen from a right-left / left-right view
56
+ # the yz plane should look like an image seen from a front-back / back-front view
57
+ rotate_planes: Optional[str] = None
58
+ split_channels: Optional[str] = None
59
+
60
+ geo_interpolate: str = "v1"
61
+ tex_interpolate: str = "v1"
62
+
63
+ isosurface_deformable_grid: bool = True
64
+
65
+
66
+ class StableDiffusionTriplaneDualAttention(nn.Module):
67
+ def __init__(
68
+ self,
69
+ config: StableDiffusionTriplaneDualAttentionConfig,
70
+ vae: AutoencoderKL,
71
+ unet: UNet2DConditionModel,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.cfg = StableDiffusionTriplaneDualAttentionConfig(**config) if isinstance(config, dict) else config
76
+
77
+ # set up the space generator
78
+ from ...extern.sd_dual_triplane_modules import OneStepTriplaneDualStableDiffusion as Generator
79
+ self.space_generator = Generator(
80
+ self.cfg.space_generator_config,
81
+ vae=vae,
82
+ unet=unet,
83
+ )
84
+
85
+ input_dim = self.space_generator.output_dim # feat_xy + feat_xz + feat_yz
86
+ assert self.cfg.split_channels in [None, "v1"]
87
+ if self.cfg.split_channels in ["v1"]: # split geometry and texture
88
+ input_dim = input_dim // 2
89
+
90
+ assert self.cfg.geo_interpolate in ["v1", "v2"]
91
+ if self.cfg.geo_interpolate in ["v2"]:
92
+ geo_input_dim = input_dim * 3 # concat[feat_xy, feat_xz, feat_yz]
93
+ else:
94
+ geo_input_dim = input_dim # feat_xy + feat_xz + feat_yz
95
+
96
+ assert self.cfg.tex_interpolate in ["v1", "v2"]
97
+ if self.cfg.tex_interpolate in ["v2"]:
98
+ tex_input_dim = input_dim * 3 # concat[feat_xy, feat_xz, feat_yz]
99
+ else:
100
+ tex_input_dim = input_dim # feat_xy + feat_xz + feat_yz
101
+
102
+ self.sdf_network = get_mlp(
103
+ geo_input_dim,
104
+ 1,
105
+ self.cfg.mlp_network_config,
106
+ )
107
+ if self.cfg.n_feature_dims > 0:
108
+
109
+ self.feature_network = get_mlp(
110
+ tex_input_dim,
111
+ self.cfg.n_feature_dims,
112
+ self.cfg.mlp_network_config,
113
+ )
114
+
115
+ if self.cfg.isosurface_deformable_grid:
116
+ self.deformation_network = get_mlp(
117
+ geo_input_dim,
118
+ 3,
119
+ self.cfg.mlp_network_config,
120
+ )
121
+
122
+ # hard-coded for now
123
+ self.unbounded = False
124
+ radius = 1.0
125
+
126
+ self.register_buffer(
127
+ "bbox",
128
+ torch.as_tensor(
129
+ [
130
+ [-radius, -radius, -radius],
131
+ [radius, radius, radius],
132
+ ],
133
+ dtype=torch.float32,
134
+ )
135
+ )
136
+
137
+ def initialize_shape(self) -> None:
138
+ # not used
139
+ pass
140
+
141
+ def get_shifted_sdf(
142
+ self,
143
+ points: Float[Tensor, "*N Di"],
144
+ sdf: Float[Tensor, "*N 1"]
145
+ ) -> Float[Tensor, "*N 1"]:
146
+ sdf_bias: Union[float, Float[Tensor, "*N 1"]]
147
+ if self.cfg.sdf_bias == "ellipsoid":
148
+ assert (
149
+ isinstance(self.cfg.sdf_bias_params, Sized)
150
+ and len(self.cfg.sdf_bias_params) == 3
151
+ )
152
+ size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
153
+ sdf_bias = ((points / size) ** 2).sum(
154
+ dim=-1, keepdim=True
155
+ ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
156
+ elif self.cfg.sdf_bias == "sphere":
157
+ assert isinstance(self.cfg.sdf_bias_params, float)
158
+ radius = self.cfg.sdf_bias_params
159
+ sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
160
+ elif isinstance(self.cfg.sdf_bias, float):
161
+ sdf_bias = self.cfg.sdf_bias
162
+ else:
163
+ raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
164
+ return sdf + sdf_bias
165
+
166
+ def generate_space_cache(
167
+ self,
168
+ styles: Float[Tensor, "B Z"],
169
+ text_embed: Float[Tensor, "B C"],
170
+ ) -> Any:
171
+ output = self.space_generator(
172
+ text_embed = text_embed,
173
+ styles = styles,
174
+ )
175
+ return output
176
+
177
+ def denoise(
178
+ self,
179
+ noisy_input: Any,
180
+ text_embed: Float[Tensor, "B C"],
181
+ timestep
182
+ ) -> Any:
183
+ output = self.space_generator.forward_denoise(
184
+ text_embed = text_embed,
185
+ noisy_input = noisy_input,
186
+ t = timestep
187
+ )
188
+ return output
189
+
190
+ def decode(
191
+ self,
192
+ latents: Any,
193
+ ) -> Any:
194
+ triplane = self.space_generator.forward_decode(
195
+ latents = latents
196
+ )
197
+ if self.cfg.split_channels is None:
198
+ return triplane
199
+ elif self.cfg.split_channels == "v1":
200
+ B, _, C, H, W = triplane.shape
201
+ # the geometry triplane uses the first output_dim // 2 channels
202
+ # the texture triplane uses the last output_dim // 2 channels
203
+ used_indices_geo = torch.tensor([True] * (self.space_generator.output_dim// 2) + [False] * (self.space_generator.output_dim // 2))
204
+ used_indices_tex = torch.tensor([False] * (self.space_generator.output_dim // 2) + [True] * (self.space_generator.output_dim // 2))
205
+ used_indices = torch.stack([used_indices_geo] * 3 + [used_indices_tex] * 3, dim=0).to(triplane.device)
206
+ return triplane[:, used_indices].view(B, 6, C//2, H, W)
207
+
208
+ def interpolate_encodings(
209
+ self,
210
+ points: Float[Tensor, "*N Di"],
211
+ space_cache: Float[Tensor, "B 3 C//3 H W"],
212
+ only_geo: bool = False,
213
+ ):
214
+ batch_size, n_points, n_dims = points.shape
215
+ # the following code is similar to EG3D / OpenLRM
216
+
217
+ assert self.cfg.rotate_planes in [None, "v1", "v2"]
218
+
219
+ if self.cfg.rotate_planes is None:
220
+ raise NotImplementedError("rotate_planes == None is not implemented yet.")
221
+
222
+ space_cache_rotated = torch.zeros_like(space_cache)
223
+ if self.cfg.rotate_planes == "v1":
224
+ # xy plane, diagonal-wise
225
+ space_cache_rotated[:, 0::3] = torch.transpose(
226
+ space_cache[:, 0::3], 3, 4
227
+ )
228
+ # xz plane, rotate 180° counterclockwise
229
+ space_cache_rotated[:, 1::3] = torch.rot90(
230
+ space_cache[:, 1::3], k=2, dims=(3, 4)
231
+ )
232
+ # zy plane, rotate 90° clockwise
233
+ space_cache_rotated[:, 2::3] = torch.rot90(
234
+ space_cache[:, 2::3], k=-1, dims=(3, 4)
235
+ )
236
+ elif self.cfg.rotate_planes == "v2":
237
+ # all are the same as v1, except for the xy plane
238
+ # xy plane, row-wise flip
239
+ space_cache_rotated[:, 0::3] = torch.flip(
240
+ space_cache[:, 0::3], dims=(4,)
241
+ )
242
+ # xz plane, rotate 180° counterclockwise
243
+ space_cache_rotated[:, 1::3] = torch.rot90(
244
+ space_cache[:, 1::3], k=2, dims=(3, 4)
245
+ )
246
+ # zy plane, rotate 90° clockwise
247
+ space_cache_rotated[:, 2::3] = torch.rot90(
248
+ space_cache[:, 2::3], k=-1, dims=(3, 4)
249
+ )
250
+
251
+
252
+ # the 0, 1, 2 axis of the space_cache_rotated is for geometry
253
+ geo_feat = sample_from_planes(
254
+ plane_features = space_cache_rotated[:, 0:3].contiguous(),
255
+ coordinates = points,
256
+ interpolate_feat = self.cfg.geo_interpolate
257
+ ).view(*points.shape[:-1],-1)
258
+
259
+ if only_geo:
260
+ return geo_feat
261
+ else:
262
+ # the 3, 4, 5 axis of the space_cache is for texture
263
+ tex_feat = sample_from_planes(
264
+ plane_features = space_cache_rotated[:, 3:6].contiguous(),
265
+ coordinates = points,
266
+ interpolate_feat = self.cfg.tex_interpolate
267
+ ).view(*points.shape[:-1],-1)
268
+
269
+ return geo_feat, tex_feat
270
+
271
+
272
+ def rescale_points(
273
+ self,
274
+ points: Float[Tensor, "*N Di"],
275
+ ):
276
+ # transform points from original space to [-1, 1]^3
277
+ points = contract_to_unisphere_custom(
278
+ points,
279
+ self.bbox,
280
+ self.unbounded
281
+ )
282
+ return points
283
+
284
+ def forward(
285
+ self,
286
+ points: Float[Tensor, "*N Di"],
287
+ space_cache: Any,
288
+ ) -> Dict[str, Float[Tensor, "..."]]:
289
+ batch_size, n_points, n_dims = points.shape
290
+
291
+ points_unscaled = points
292
+ points = self.rescale_points(points)
293
+
294
+ enc_geo, enc_tex = self.interpolate_encodings(points, space_cache)
295
+ sdf_orig = self.sdf_network(enc_geo).view(*points.shape[:-1], 1)
296
+ sdf = self.get_shifted_sdf(points_unscaled, sdf_orig)
297
+ output = {
298
+ "sdf": sdf.view(batch_size * n_points, 1), # reshape to [B*N, 1]
299
+ }
300
+ if self.cfg.n_feature_dims > 0:
301
+ features = self.feature_network(enc_tex).view(
302
+ *points.shape[:-1], self.cfg.n_feature_dims)
303
+ output.update(
304
+ {
305
+ "features": features.view(batch_size * n_points, self.cfg.n_feature_dims)
306
+ }
307
+ )
308
+ return output
309
+
310
+ def forward_sdf(
311
+ self,
312
+ points: Float[Tensor, "*N Di"],
313
+ space_cache: Float[Tensor, "B 3 C//3 H W"],
314
+ ) -> Float[Tensor, "*N 1"]:
315
+ batch_size = points.shape[0]
316
+ assert space_cache.shape[0] == batch_size, "points and space_cache should have the same batch size in forward_sdf"
317
+ points_unscaled = points
318
+
319
+ points = self.rescale_points(points)
320
+
321
+ # sample from planes
322
+ enc_geo = self.interpolate_encodings(
323
+ points.reshape(batch_size, -1, 3),
324
+ space_cache,
325
+ only_geo = True
326
+ ).reshape(*points.shape[:-1], -1)
327
+ sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
328
+
329
+ sdf = self.get_shifted_sdf(points_unscaled, sdf)
330
+ return sdf
331
+
332
+ def forward_field(
333
+ self,
334
+ points: Float[Tensor, "*N Di"],
335
+ space_cache: Float[Tensor, "B 3 C//3 H W"],
336
+ ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
337
+ batch_size = points.shape[0]
338
+ assert space_cache.shape[0] == batch_size, "points and space_cache should have the same batch size in forward_field"
339
+ points_unscaled = points
340
+
341
+ points = self.rescale_points(points)
342
+
343
+ # sample from planes
344
+ enc_geo = self.interpolate_encodings(points, space_cache, only_geo = True)
345
+ sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
346
+ sdf = self.get_shifted_sdf(points_unscaled, sdf)
347
+ deformation: Optional[Float[Tensor, "*N 3"]] = None
348
+ if self.cfg.isosurface_deformable_grid:
349
+ deformation = self.deformation_network(enc_geo).reshape(*points.shape[:-1], 3)
350
+ return sdf, deformation
351
+
352
+ def forward_level(
353
+ self, field: Float[Tensor, "*N 1"], threshold: float
354
+ ) -> Float[Tensor, "*N 1"]:
355
+ # TODO: is this function correct?
356
+ return field - threshold
357
+
358
+ def export(
359
+ self,
360
+ points: Float[Tensor, "*N Di"],
361
+ space_cache: Float[Tensor, "B 3 C//3 H W"],
362
+ **kwargs) -> Dict[str, Any]:
363
+
364
+ # TODO: is this function correct?
365
+ out: Dict[str, Any] = {}
366
+ if self.cfg.n_feature_dims == 0:
367
+ return out
368
+
369
+ orig_shape = points.shape
370
+ points = points.view(1, -1, 3)
371
+
372
+ # assume the batch size is 1
373
+ points_unscaled = points
374
+ points = self.rescale_points(points)
375
+
376
+ # sample from planes
377
+ _, enc_tex = self.interpolate_encodings(points, space_cache)
378
+ features = self.feature_network(enc_tex).view(
379
+ *points.shape[:-1], self.cfg.n_feature_dims
380
+ )
381
+ out.update(
382
+ {
383
+ "features": features.view(orig_shape[:-1] + (self.cfg.n_feature_dims,))
384
+ }
385
+ )
386
+ return out
387
+
388
+ def train(self, mode=True):
389
+ super().train(mode)
390
+ self.space_generator.train(mode)
391
+
392
+ def eval(self):
393
+ super().eval()
394
+ self.space_generator.eval()
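
A hedged usage sketch of the geometry module defined above: decode one-step latents into the split geometry/texture plane cache, then query SDF and raw color features at 3D points. Shapes are illustrative, and `geometry` is assumed to be an already constructed StableDiffusionTriplaneDualAttention with split_channels="v1".

# Sketch (not part of the repo files): querying the dual-triplane geometry.
import torch

latents = torch.randn(1 * 6, 4, 32, 32)        # one-step generator output, B=1, 6 planes
space_cache = geometry.decode(latents)         # (B, 6, output_dim // 2, H, W) for split_channels == "v1"

points = torch.rand(1, 4096, 3) * 2 - 1        # query points inside the [-1, 1]^3 bounding box
out = geometry.forward(points, space_cache)
sdf = out["sdf"]                               # (B*N, 1) signed distances with the sphere bias applied
features = out["features"]                     # (B*N, 3) raw color features, activated later by the material
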
triplaneturbo_executable/models/networks.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ..utils.general_utils import config_to_primitive
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Literal
7
+
8
+ def get_activation(name):
9
+ if name is None:
10
+ return lambda x: x
11
+ name = name.lower()
12
+ if name == "none":
13
+ return lambda x: x
14
+ elif name == "sigmoid-mipnerf":
15
+ return lambda x: torch.sigmoid(x) * (1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
16
+ else:
17
+ try:
18
+ return getattr(F, name)
19
+ except AttributeError:
20
+ raise ValueError(f"Unknown activation function: {name}")
21
+
22
+
23
+ class VanillaMLP(nn.Module):
24
+ def __init__(self, dim_in: int, dim_out: int, config: dict):
25
+ super().__init__()
26
+ # Convert dict to MLPConfig if needed
27
+ if isinstance(config, dict):
28
+ config = MLPConfig(**config)
29
+
30
+ self.n_neurons = config.n_neurons
31
+ self.n_hidden_layers = config.n_hidden_layers
32
+
33
+ layers = [
34
+ self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
35
+ self.make_activation(),
36
+ ]
37
+ for i in range(self.n_hidden_layers - 1):
38
+ layers += [
39
+ self.make_linear(
40
+ self.n_neurons, self.n_neurons, is_first=False, is_last=False
41
+ ),
42
+ self.make_activation(),
43
+ ]
44
+ layers += [
45
+ self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)
46
+ ]
47
+ self.layers = nn.Sequential(*layers)
48
+ self.output_activation = get_activation(config.output_activation)
49
+
50
+ def forward(self, x):
51
+ # disable autocast
52
+ # strange that the parameters will have empty gradients if autocast is enabled in AMP
53
+ with torch.cuda.amp.autocast(enabled=False):
54
+ x = self.layers(x)
55
+ x = self.output_activation(x)
56
+ return x
57
+
58
+ def make_linear(self, dim_in, dim_out, is_first, is_last):
59
+ layer = nn.Linear(dim_in, dim_out, bias=False)
60
+ return layer
61
+
62
+ def make_activation(self):
63
+ return nn.ReLU(inplace=True)
64
+
65
+ @dataclass
66
+ class MLPConfig:
67
+ otype: str = "VanillaMLP"
68
+ activation: str = "ReLU"
69
+ output_activation: str = "none"
70
+ n_neurons: int = 64
71
+ n_hidden_layers: int = 2
72
+
73
+ def get_mlp(input_dim: int, output_dim: int, config: dict) -> nn.Module:
74
+ """Create MLP network based on config"""
75
+ # Convert dict to MLPConfig
76
+ if isinstance(config, dict):
77
+ config = MLPConfig(**config)
78
+
79
+ if config.otype == "VanillaMLP":
80
+ network = VanillaMLP(input_dim, output_dim, config)
81
+ else:
82
+ raise ValueError(f"Unknown MLP type: {config.otype}")
83
+ return network
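
A short usage sketch for get_mlp, using the same config dict that appears in the geometry defaults above (input size and point count are illustrative):

# Sketch: build and run one of the small decoder heads.
import torch

mlp_config = {
    "otype": "VanillaMLP",
    "activation": "ReLU",
    "output_activation": "none",
    "n_neurons": 64,
    "n_hidden_layers": 2,
}
sdf_head = get_mlp(input_dim=32, output_dim=1, config=mlp_config)
features = torch.randn(4096, 32)               # per-point plane features
sdf = sdf_head(features)                       # (4096, 1)
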
triplaneturbo_executable/pipelines/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .triplaneturbo_text_to_3d import (
2
+ TriplaneTurboTextTo3DPipeline,
3
+ TriplaneTurboTextTo3DPipelineConfig
4
+ )
5
+
6
+ __all__ = [
7
+ "TriplaneTurboTextTo3DPipeline",
8
+ "TriplaneTurboTextTo3DPipelineConfig"
9
+ ]
triplaneturbo_executable/pipelines/base.py ADDED
@@ -0,0 +1,33 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from diffusers import DiffusionPipeline
5
+
6
+ class Pipeline(DiffusionPipeline):
7
+ """Base class for all pipelines."""
8
+
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def __call__(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+ def enable_xformers_memory_efficient_attention(self):
16
+ pass
17
+
18
+ def enable_model_cpu_offload(self):
19
+ pass
20
+
21
+ @property
22
+ def device(self) -> torch.device:
23
+ for model in self.models.values():
24
+ if hasattr(model, 'device'):
25
+ return model.device
26
+ for model in self.models.values():
27
+ if hasattr(model, 'parameters'):
28
+ return next(model.parameters()).device
29
+ raise RuntimeError("No device found.")
30
+
31
+ def to(self, device: torch.device) -> None:
32
+ for model in self.models.values():
33
+ model.to(device)
triplaneturbo_executable/pipelines/triplaneturbo_text_to_3d.py ADDED
@@ -0,0 +1,344 @@
1
+ import os
2
+ import re
3
+ import json
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ from typing import *
8
+ from dataclasses import dataclass, field
9
+ from diffusers import StableDiffusionPipeline
10
+
11
+ from .base import Pipeline
12
+ from ..models.geometry import StableDiffusionTriplaneDualAttention
13
+ from ..utils.mesh_exporter import isosurface, colorize_mesh, DiffMarchingCubeHelper
14
+
15
+ from diffusers.loaders import AttnProcsLayers
16
+ from ..models.networks import get_activation
17
+
18
+ @dataclass
19
+ class TriplaneTurboTextTo3DPipelineConfig:
20
+ """Configuration for TriplaneTurboTextTo3DPipeline"""
21
+ # Basic pipeline settings
22
+ base_model_name_or_path: str = "pretrained/stable-diffusion-2-1-base"
23
+
24
+ num_inference_steps: int = 4
25
+ num_results_per_prompt: int = 1
26
+ latent_channels: int = 4
27
+ latent_height: int = 64
28
+ latent_width: int = 64
29
+
30
+ # Training/sampling settings
31
+ num_steps_sampling: int = 4
32
+
33
+ # Geometry settings
34
+ radius: float = 1.0
35
+ normal_type: str = "analytic"
36
+ sdf_bias: str = "sphere"
37
+ sdf_bias_params: float = 0.5
38
+ rotate_planes: str = "v1"
39
+ split_channels: str = "v1"
40
+ geo_interpolate: str = "v1"
41
+ tex_interpolate: str = "v2"
42
+ n_feature_dims: int = 3
43
+
44
+ sample_scheduler: str = "ddim" # any of "ddpm", "ddim"
45
+
46
+ # Network settings
47
+ mlp_network_config: dict = field(
48
+ default_factory=lambda: {
49
+ "otype": "VanillaMLP",
50
+ "activation": "ReLU",
51
+ "output_activation": "none",
52
+ "n_neurons": 64,
53
+ "n_hidden_layers": 2,
54
+ }
55
+ )
56
+
57
+ # Adapter settings
58
+ space_generator_config: dict = field(
59
+ default_factory=lambda: {
60
+ "training_type": "self_lora_rank_16-cross_lora_rank_16-locon_rank_16" ,
61
+ "output_dim": 64, # 32 * 2 for v1
62
+ "self_lora_type": "hexa_v1",
63
+ "cross_lora_type": "vanilla",
64
+ "locon_type": "vanilla_v1",
65
+ "prompt_bias": False,
66
+ "vae_attn_type": "basic", # "basic", "vanilla"
67
+ }
68
+ )
69
+
70
+ isosurface_deformable_grid: bool = True
71
+ isosurface_resolution: int = 160
72
+ color_activation: str = "sigmoid-mipnerf"
73
+
74
+ @classmethod
75
+ def from_pretrained(cls, pretrained_path: str) -> "TriplaneTurboTextTo3DPipelineConfig":
76
+ """Load config from pretrained path"""
77
+ config_path = os.path.join(pretrained_path, "config.json")
78
+ if os.path.exists(config_path):
79
+ with open(config_path, "r") as f:
80
+ config_dict = json.load(f)
81
+ return cls(**config_dict)
82
+ else:
83
+ print(f"No config file found at {pretrained_path}, using default config")
84
+ return cls() # Return default config if no config file found
85
+
86
+ class TriplaneTurboTextTo3DPipeline(Pipeline):
87
+ """
88
+ A pipeline for converting text to 3D models using triplane representation.
89
+ """
90
+ config_name = "config.json"
91
+
92
+ def __init__(
93
+ self,
94
+ geometry: StableDiffusionTriplaneDualAttention,
95
+ material: Callable,
96
+ base_pipeline: StableDiffusionPipeline,
97
+ sample_scheduler: Callable,
98
+ isosurface_helper: Callable,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ self.geometry = geometry
103
+ self.material = material
104
+
105
+ self.base_pipeline = base_pipeline
106
+
107
+ self.sample_scheduler = sample_scheduler
108
+ self.isosurface_helper = isosurface_helper
109
+
110
+
111
+ self.models = {
112
+ "geometry": geometry,
113
+ "base_pipeline": base_pipeline,
114
+ }
115
+
116
+ @classmethod
117
+ def from_pretrained(
118
+ cls,
119
+ pretrained_model_name_or_path: str,
120
+ **kwargs,
121
+ ):
122
+ """
123
+ Load pretrained adapter weights, config and update pipeline components.
124
+
125
+ Args:
126
+ pretrained_model_name_or_path: Path to pretrained adapter weights
127
+ base_pipeline: Optional base pipeline instance
128
+ **kwargs: Additional arguments to override config values
129
+
130
+ Returns:
131
+ pipeline: Updated pipeline instance
132
+ """
133
+ # Load config from pretrained path
134
+ config = TriplaneTurboTextTo3DPipelineConfig.from_pretrained(
135
+ pretrained_model_name_or_path,
136
+ **kwargs,
137
+ )
138
+
139
+ # load base pipeline
140
+ base_pipeline = StableDiffusionPipeline.from_pretrained(
141
+ config.base_model_name_or_path,
142
+ **kwargs,
143
+ )
144
+
145
+ # load sample scheduler
146
+ if config.sample_scheduler == "ddim":
147
+ from diffusers import DDIMScheduler
148
+ sample_scheduler = DDIMScheduler.from_pretrained(
149
+ config.base_model_name_or_path,
150
+ subfolder="scheduler",
151
+ )
152
+ else:
153
+ raise ValueError(f"Unknown sample scheduler: {config.sample_scheduler}")
154
+
155
+ # load geometry
156
+ geometry = StableDiffusionTriplaneDualAttention(
157
+ config=config,
158
+ vae=base_pipeline.vae,
159
+ unet=base_pipeline.unet,
160
+ )
161
+
162
+ # no gradient for geometry
163
+ for param in geometry.parameters():
164
+ param.requires_grad = False
165
+
166
+ # and load adapter weights
167
+ if pretrained_model_name_or_path.endswith(".pth"):
168
+ state_dict = torch.load(pretrained_model_name_or_path)["state_dict"]
169
+ new_state_dict = {}
170
+ for key, value in state_dict.items():
171
+ new_key = key.replace("geometry.", "")
172
+ new_state_dict[new_key] = value
173
+ _, unused = geometry.load_state_dict(new_state_dict, strict=False)
174
+ if len(unused) > 0:
175
+ print(f"Unused keys: {unused}")
176
+ else:
177
+ raise ValueError(f"Unknown pretrained model name or path: {pretrained_model_name_or_path}")
178
+
179
+
180
+ # load material, convert to int
181
+ # material = lambda x: (256 * get_activation(config.color_activation)(x)).int()
182
+ material = get_activation(config.color_activation)
183
+
184
+ # Load geometry model
185
+ pipeline = cls(
186
+ base_pipeline=base_pipeline,
187
+ geometry=geometry,
188
+ sample_scheduler=sample_scheduler,
189
+ material=material,
190
+ isosurface_helper=DiffMarchingCubeHelper(
191
+ resolution=config.isosurface_resolution,
192
+ ),
193
+ **kwargs,
194
+ )
195
+ return pipeline
196
+
197
+
198
+ def encode_prompt(
199
+ self,
200
+ prompt: Union[str, List[str]],
201
+ device: str,
202
+ num_results_per_prompt: int = 1,
203
+ ) -> torch.FloatTensor:
204
+ """
205
+ Encodes the prompt into text encoder hidden states.
206
+
207
+ Args:
208
+ prompt: The prompt to encode.
209
+ device: The device to use for encoding.
210
+ num_results_per_prompt: Number of results to generate per prompt.
211
+ do_classifier_free_guidance: Whether to use classifier-free guidance.
212
+ negative_prompt: The negative prompt to encode.
213
+
214
+ Returns:
215
+ text_embeddings: Text embeddings tensor.
216
+ """
217
+ # Use base_pipeline to encode prompt
218
+ text_embeddings = self.base_pipeline.encode_prompt(
219
+ prompt=prompt,
220
+ device=device,
221
+ num_images_per_prompt=num_results_per_prompt,
222
+ do_classifier_free_guidance=False,
223
+ negative_prompt=None
224
+ )
225
+ return text_embeddings
226
+
227
+ @torch.no_grad()
228
+ def __call__(
229
+ self,
230
+ prompt: Union[str, List[str]],
231
+ num_inference_steps: int = 4,
232
+ num_results_per_prompt: int = 1,
233
+ generator: Optional[torch.Generator] = None,
234
+ latents: Optional[torch.FloatTensor] = None,
235
+ return_dict: bool = True,
236
+ colorize: bool = True,
237
+ **kwargs,
238
+ ):
239
+ # Implementation similar to Zero123Pipeline
240
+ # Reference code from: https://github.com/zero123/zero123-diffusers
241
+
242
+ # Validate inputs
243
+ if isinstance(prompt, str):
244
+ batch_size = 1
245
+ prompt = [prompt]
246
+ elif isinstance(prompt, list):
247
+ batch_size = len(prompt)
248
+ else:
249
+ raise ValueError(f"Prompt must be a string or list of strings, got {type(prompt)}")
250
+
251
+ # Get the device from the first available module
252
+
253
+ # Generate latents if not provided
254
+ if latents is None:
255
+ latents = torch.randn(
256
+ (batch_size * 6, 4, 32, 32), # hard-coded for now
257
+ generator=generator,
258
+ device=self.device,
259
+ )
260
+
261
+ # Process text prompt through geometry module
262
+ text_embed, _ = self.encode_prompt(prompt, self.device, num_results_per_prompt)
263
+
264
+ # Run diffusion process
265
+ # Set up timesteps for sampling
266
+ timesteps = self._set_timesteps(
267
+ self.sample_scheduler,
268
+ num_inference_steps
269
+ )
270
+
271
+
272
+ with torch.no_grad():
273
+ # Run diffusion process
274
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
275
+ # Scale model input
276
+ noisy_latent_input = self.sample_scheduler.scale_model_input(
277
+ latents,
278
+ t
279
+ )
280
+
281
+ # Predict noise/sample
282
+ pred = self.geometry.denoise(
283
+ noisy_input=noisy_latent_input,
284
+ text_embed=text_embed,
285
+ timestep=t.to(self.device),
286
+ )
287
+
288
+ # Update latents
289
+ results = self.sample_scheduler.step(pred, t, latents)
290
+ latents = results.prev_sample
291
+ latents_denoised = results.pred_original_sample
292
+
293
+ # Use final denoised latents
294
+ latents = latents_denoised
295
+
296
+ # Generate final 3D representation
297
+ space_cache = self.geometry.decode(latents)
298
+
299
+ # Extract mesh from space cache
300
+ mesh_list = isosurface(
301
+ space_cache,
302
+ self.geometry.forward_field,
303
+ self.isosurface_helper,
304
+ )
305
+
306
+ if colorize:
307
+ mesh_list = colorize_mesh(
308
+ space_cache,
309
+ self.geometry.export,
310
+ mesh_list,
311
+ activation=self.material,
312
+ )
313
+
314
+ # decide output type based on return_dict
315
+ if return_dict:
316
+ return {
317
+ "space_cache": space_cache,
318
+ "latents": latents,
319
+ "mesh": mesh_list,
320
+ }
321
+ else:
322
+ return mesh_list
323
+
324
+ def _set_timesteps(
325
+ self,
326
+ scheduler,
327
+ num_steps: int,
328
+ ):
329
+ """Set up timesteps for sampling.
330
+
331
+ Args:
332
+ scheduler: The scheduler to use for timestep generation
333
+ num_steps: Number of diffusion steps
334
+
335
+ Returns:
336
+ timesteps: Tensor of timesteps to use for sampling
337
+ """
338
+ scheduler.set_timesteps(num_steps)
339
+ timesteps_orig = scheduler.timesteps
340
+ # Shift timesteps to start from T
341
+ timesteps_delta = scheduler.config.num_train_timesteps - 1 - timesteps_orig.max()
342
+ timesteps = timesteps_orig + timesteps_delta
343
+ return timesteps
344
+
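
The shift in _set_timesteps makes the first denoising call land on the final training timestep regardless of the scheduler's stride. A worked example, assuming 1000 training timesteps, the default "leading" spacing, and 4 inference steps:

# Worked example of the timestep shift (assumed scheduler settings, for illustration).
import numpy as np

timesteps_orig = np.array([750, 500, 250, 0])   # what DDIMScheduler.set_timesteps(4) would yield
delta = (1000 - 1) - timesteps_orig.max()       # 249
print(timesteps_orig + delta)                   # [999 749 499 249] -> sampling starts at t = 999
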
triplaneturbo_executable/utils/__init__.py ADDED
File without changes
triplaneturbo_executable/utils/general_utils.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ from typing import *
6
+ from jaxtyping import Float
7
+ from omegaconf import OmegaConf
8
+
9
+ def config_to_primitive(config, resolve: bool = True) -> Any:
10
+ return OmegaConf.to_container(config, resolve=resolve)
11
+
12
+ def scale_tensor(
13
+ dat: Float[Tensor, "... D"],
14
+ inp_scale: Union[Tuple[float, float], Float[Tensor, "2 D"]],
15
+ tgt_scale: Union[Tuple[float, float], Float[Tensor, "2 D"]]
16
+ ):
17
+ if inp_scale is None:
18
+ inp_scale = (0, 1)
19
+ if tgt_scale is None:
20
+ tgt_scale = (0, 1)
21
+ if isinstance(tgt_scale, Tensor):
22
+ assert dat.shape[-1] == tgt_scale.shape[-1]
23
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
24
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
25
+ return dat
26
+
27
+ def contract_to_unisphere_custom(
28
+ x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False
29
+ ) -> Float[Tensor, "... 3"]:
30
+ if unbounded:
31
+ x = scale_tensor(x, bbox, (-1, 1))
32
+ x = x * 2 - 1 # aabb is at [-1, 1]
33
+ mag = x.norm(dim=-1, keepdim=True)
34
+ mask = mag.squeeze(-1) > 1
35
+ x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
36
+ x = x / 4 + 0.5 # [-inf, inf] is at [0, 1]
37
+ else:
38
+ x = scale_tensor(x, bbox, (-1, 1))
39
+ return x
40
+
41
+ # bug fix in https://github.com/NVlabs/eg3d/issues/67
42
+ planes = torch.tensor(
43
+ [
44
+ [
45
+ [1, 0, 0],
46
+ [0, 1, 0],
47
+ [0, 0, 1]
48
+ ],
49
+ [
50
+ [1, 0, 0],
51
+ [0, 0, 1],
52
+ [0, 1, 0]
53
+ ],
54
+ [
55
+ [0, 0, 1],
56
+ [0, 1, 0],
57
+ [1, 0, 0]
58
+ ]
59
+ ], dtype=torch.float32)
60
+
61
+
62
+ def grid_sample(input, grid):
63
+ # if grid.requires_grad and _should_use_custom_op():
64
+ # return grid_sample_2d(input, grid, padding_mode='zeros', align_corners=False)
65
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
66
+
67
+
68
+ def project_onto_planes(planes, coordinates):
69
+ """
70
+ Does a projection of a 3D point onto a batch of 2D planes,
71
+ returning 2D plane coordinates.
72
+
73
+ Takes plane axes of shape n_planes, 3, 3
74
+ # Takes coordinates of shape N, M, 3
75
+ # returns projections of shape N*n_planes, M, 2
76
+ """
77
+ N, M, C = coordinates.shape
78
+ n_planes, _, _ = planes.shape
79
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
80
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
81
+ projections = torch.bmm(coordinates, inv_planes)
82
+ return projections[..., :2]
83
+
84
+ def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat: Optional[str] = None):
85
+ assert padding_mode == 'zeros'
86
+ N, n_planes, C, H, W = plane_features.shape
87
+ _, M, _ = coordinates.shape
88
+ plane_features = plane_features.view(N*n_planes, C, H, W)
89
+
90
+ coordinates = (2/box_warp) * coordinates # add specific box bounds
91
+
92
+ if interpolate_feat in [None, "v1"]:
93
+ projected_coordinates = project_onto_planes(planes.to(coordinates), coordinates).unsqueeze(1)
94
+ output_features = grid_sample(plane_features, projected_coordinates.float())
95
+ output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
96
+ output_features = output_features.sum(dim=1, keepdim=True).reshape(N, M, C)
97
+
98
+ elif interpolate_feat in ["v2"]:
99
+ projected_coordinates = project_onto_planes(planes.to(coordinates), coordinates).unsqueeze(1)
100
+ output_features = grid_sample(plane_features, projected_coordinates.float())
101
+ output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
102
+ output_features = output_features.permute(0, 2, 1, 3).reshape(N, M, n_planes*C)
103
+
104
+ return output_features.contiguous()
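
To pin down the shape contract of project_onto_planes and sample_from_planes, a small sketch with illustrative sizes (the module-level `planes` tensor above supplies the three canonical plane bases):

# Shape sketch for tri-plane feature sampling (illustrative sizes).
import torch

N, M, C, H, W = 2, 1024, 32, 256, 256
plane_features = torch.randn(N, 3, C, H, W)     # three feature planes per sample
coordinates = torch.rand(N, M, 3) * 2 - 1       # points already inside [-1, 1]^3 (box_warp = 2)

feat_v1 = sample_from_planes(plane_features, coordinates, interpolate_feat="v1")
print(feat_v1.shape)                            # (N, M, C): the three plane samples are summed

feat_v2 = sample_from_planes(plane_features, coordinates, interpolate_feat="v2")
print(feat_v2.shape)                            # (N, M, 3*C): the three plane samples are concatenated
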
triplaneturbo_executable/utils/mesh.py ADDED
@@ -0,0 +1,288 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Any, Dict, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from jaxtyping import Float, Integer
11
+ from torch import Tensor
12
+
13
+ def dot(x, y):
14
+ return torch.sum(x * y, -1, keepdim=True)
15
+
16
+ class Mesh:
17
+ def __init__(
18
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
19
+ ) -> None:
20
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
21
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
22
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
23
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
24
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
25
+ self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
26
+ self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
27
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
28
+ self.extras: Dict[str, Any] = {}
29
+ for k, v in kwargs.items():
30
+ self.add_extra(k, v)
31
+
32
+ def add_extra(self, k, v) -> None:
33
+ self.extras[k] = v
34
+
35
+ def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]):
36
+
37
+ # use trimesh to first split the mesh into connected components
38
+ # then remove the components with less than n_face_threshold faces
39
+ import trimesh
40
+
41
+ # construct a trimesh object
42
+ mesh = trimesh.Trimesh(
43
+ vertices=self.v_pos.detach().cpu().numpy(),
44
+ faces=self.t_pos_idx.detach().cpu().numpy(),
45
+ )
46
+
47
+ # split the mesh into connected components
48
+ components = mesh.split(only_watertight=False)
49
+
50
+
51
+ n_faces_threshold: int
52
+ if isinstance(outlier_n_faces_threshold, float):
53
+ # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
54
+ n_faces_threshold = int(
55
+ max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
56
+ )
57
+ else:
58
+ # set the threshold directly to outlier_n_faces_threshold
59
+ n_faces_threshold = outlier_n_faces_threshold
60
+
61
+ # remove the components with less than n_face_threshold faces
62
+ components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
63
+
64
+ # merge the components
65
+ mesh = trimesh.util.concatenate(components)
66
+
67
+ # convert back to our mesh format
68
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
69
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
70
+
71
+ clean_mesh = Mesh(v_pos, t_pos_idx)
72
+ # keep the extras unchanged
73
+
74
+ return clean_mesh
75
+
76
+ @property
77
+ def requires_grad(self):
78
+ return self.v_pos.requires_grad
79
+
80
+ @property
81
+ def v_nrm(self):
82
+ if self._v_nrm is None:
83
+ self._v_nrm = self._compute_vertex_normal()
84
+ return self._v_nrm
85
+
86
+ @property
87
+ def v_tng(self):
88
+ if self._v_tng is None:
89
+ self._v_tng = self._compute_vertex_tangent()
90
+ return self._v_tng
91
+
92
+ @property
93
+ def v_tex(self):
94
+ if self._v_tex is None:
95
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
96
+ return self._v_tex
97
+
98
+ @property
99
+ def t_tex_idx(self):
100
+ if self._t_tex_idx is None:
101
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
102
+ return self._t_tex_idx
103
+
104
+ @property
105
+ def v_rgb(self):
106
+ return self._v_rgb
107
+
108
+ @property
109
+ def edges(self):
110
+ if self._edges is None:
111
+ self._edges = self._compute_edges()
112
+ return self._edges
113
+
114
+ def _compute_vertex_normal(self):
115
+ i0 = self.t_pos_idx[:, 0]
116
+ i1 = self.t_pos_idx[:, 1]
117
+ i2 = self.t_pos_idx[:, 2]
118
+
119
+ v0 = self.v_pos[i0, :]
120
+ v1 = self.v_pos[i1, :]
121
+ v2 = self.v_pos[i2, :]
122
+
123
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
124
+
125
+ # Splat face normals to vertices
126
+ v_nrm = torch.zeros_like(self.v_pos)
127
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
128
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
129
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
130
+
131
+ # Normalize, replace zero (degenerated) normals with some default value
132
+ v_nrm = torch.where(
133
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
134
+ )
135
+ v_nrm = F.normalize(v_nrm, dim=1)
136
+
137
+ if torch.is_anomaly_enabled():
138
+ assert torch.all(torch.isfinite(v_nrm))
139
+
140
+ return v_nrm
141
+
142
+ def _compute_vertex_tangent(self):
143
+ vn_idx = [None] * 3
144
+ pos = [None] * 3
145
+ tex = [None] * 3
146
+ for i in range(0, 3):
147
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
148
+ tex[i] = self.v_tex[self.t_tex_idx[:, i]]
149
+ # t_nrm_idx is always the same as t_pos_idx
150
+ vn_idx[i] = self.t_pos_idx[:, i]
151
+
152
+ tangents = torch.zeros_like(self.v_nrm)
153
+ tansum = torch.zeros_like(self.v_nrm)
154
+
155
+ # Compute tangent space for each triangle
156
+ uve1 = tex[1] - tex[0]
157
+ uve2 = tex[2] - tex[0]
158
+ pe1 = pos[1] - pos[0]
159
+ pe2 = pos[2] - pos[0]
160
+
161
+ nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
162
+ denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]
163
+
164
+ # Avoid division by zero for degenerated texture coordinates
165
+ tang = nom / torch.where(
166
+ denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
167
+ )
168
+
169
+ # Update all 3 vertices
170
+ for i in range(0, 3):
171
+ idx = vn_idx[i][:, None].repeat(1, 3)
172
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
173
+ tansum.scatter_add_(
174
+ 0, idx, torch.ones_like(tang)
175
+ ) # tansum[n_i] = tansum[n_i] + 1
176
+ tangents = tangents / tansum
177
+
178
+ # Normalize and make sure tangent is perpendicular to normal
179
+ tangents = F.normalize(tangents, dim=1)
180
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
181
+
182
+ if torch.is_anomaly_enabled():
183
+ assert torch.all(torch.isfinite(tangents))
184
+
185
+ return tangents
186
+
187
+ def _unwrap_uv(
188
+ self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
189
+ ):
190
+
191
+ import xatlas
192
+
193
+ atlas = xatlas.Atlas()
194
+ atlas.add_mesh(
195
+ self.v_pos.detach().cpu().numpy(),
196
+ self.t_pos_idx.cpu().numpy(),
197
+ )
198
+ co = xatlas.ChartOptions()
199
+ po = xatlas.PackOptions()
200
+ for k, v in xatlas_chart_options.items():
201
+ setattr(co, k, v)
202
+ for k, v in xatlas_pack_options.items():
203
+ setattr(po, k, v)
204
+ atlas.generate(co, po)
205
+ vmapping, indices, uvs = atlas.get_mesh(0)
206
+ vmapping = (
207
+ torch.from_numpy(
208
+ vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
209
+ )
210
+ .to(self.v_pos.device)
211
+ .long()
212
+ )
213
+ uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
214
+ indices = (
215
+ torch.from_numpy(
216
+ indices.astype(np.uint64, casting="same_kind").view(np.int64)
217
+ )
218
+ .to(self.v_pos.device)
219
+ .long()
220
+ )
221
+ return uvs, indices
222
+
223
+ def unwrap_uv(
224
+ self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
225
+ ):
226
+ self._v_tex, self._t_tex_idx = self._unwrap_uv(
227
+ xatlas_chart_options, xatlas_pack_options
228
+ )
229
+
230
+ def set_vertex_color(self, v_rgb):
231
+ assert v_rgb.shape[0] == self.v_pos.shape[0]
232
+ self._v_rgb = v_rgb
233
+
234
+ def _compute_edges(self):
235
+ # Compute edges
236
+ edges = torch.cat(
237
+ [
238
+ self.t_pos_idx[:, [0, 1]],
239
+ self.t_pos_idx[:, [1, 2]],
240
+ self.t_pos_idx[:, [2, 0]],
241
+ ],
242
+ dim=0,
243
+ )
244
+ edges = edges.sort()[0]
245
+ edges = torch.unique(edges, dim=0)
246
+ return edges
247
+
248
+ def normal_consistency(self) -> Float[Tensor, ""]:
249
+ edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
250
+ nc = (
251
+ 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
252
+ ).mean()
253
+ return nc
254
+
255
+ def _laplacian_uniform(self):
256
+ # from stable-dreamfusion
257
+ # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
258
+ verts, faces = self.v_pos, self.t_pos_idx
259
+
260
+ V = verts.shape[0]
261
+ F = faces.shape[0]
262
+
263
+ # Neighbor indices
264
+ ii = faces[:, [1, 2, 0]].flatten()
265
+ jj = faces[:, [2, 0, 1]].flatten()
266
+ adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
267
+ dim=1
268
+ )
269
+ adj_values = torch.ones(adj.shape[1]).to(verts)
270
+
271
+ # Diagonal indices
272
+ diag_idx = adj[0]
273
+
274
+ # Build the sparse matrix
275
+ idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
276
+ values = torch.cat((-adj_values, adj_values))
277
+
278
+ # The coalesce operation sums the duplicate indices, resulting in the
279
+ # correct diagonal
280
+ return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
281
+
282
+ def laplacian(self) -> Float[Tensor, ""]:
283
+ with torch.no_grad():
284
+ L = self._laplacian_uniform()
285
+ loss = L.mm(self.v_pos)
286
+ loss = loss.norm(dim=1)
287
+ loss = loss.mean()
288
+ return loss
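
A small sketch exercising the Mesh container above on a tetrahedron, touching the lazily computed vertex normals and the two regularizers (values are illustrative):

# Sketch: build a tiny Mesh and access its derived quantities.
import torch

v_pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
t_pos_idx = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)

normals = mesh.v_nrm                     # (4, 3) per-vertex normals, computed on first access
edges = mesh.edges                       # (6, 2) unique undirected edges
smoothness = mesh.laplacian()            # scalar uniform-Laplacian magnitude
consistency = mesh.normal_consistency()  # scalar normal-consistency penalty
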
triplaneturbo_executable/utils/mesh_exporter.py ADDED
@@ -0,0 +1,231 @@
1
+ from typing import Callable, Dict, List, Optional, Tuple, Any
2
+ from jaxtyping import Float
3
+ from torch import Tensor
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import os
9
+ import numpy as np
10
+ from .saving import SaverMixin
11
+
12
+ from ..utils.mesh import Mesh
13
+ from ..utils.general_utils import scale_tensor
14
+
15
+ @dataclass
16
+ class ExporterOutput:
17
+ save_name: str
18
+ save_type: str
19
+ params: Dict[str, Any]
20
+
21
+
22
+ class IsosurfaceHelper(nn.Module):
23
+ points_range: Tuple[float, float] = (0, 1)
24
+
25
+ @property
26
+ def grid_vertices(self) -> Float[Tensor, "N 3"]:
27
+ raise NotImplementedError
28
+
29
+ class DiffMarchingCubeHelper(IsosurfaceHelper):
30
+ def __init__(
31
+ self,
32
+ resolution: int,
33
+ point_range: Tuple[float, float] = (0, 1)
34
+ ) -> None:
35
+ super().__init__()
36
+ self.resolution = resolution
37
+ self.points_range = point_range
38
+
39
+ from diso import DiffMC
40
+ self.mc_func: Callable = DiffMC(dtype=torch.float32)
41
+ self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
42
+ self._dummy: Float[Tensor, "..."]
43
+ self.register_buffer(
44
+ "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
45
+ )
46
+
47
+ @property
48
+ def grid_vertices(self) -> Float[Tensor, "N3 3"]:
49
+ if self._grid_vertices is None:
50
+ # keep the vertices on CPU so that we can support very large resolution
51
+ x, y, z = (
52
+ torch.linspace(*self.points_range, self.resolution),
53
+ torch.linspace(*self.points_range, self.resolution),
54
+ torch.linspace(*self.points_range, self.resolution),
55
+ )
56
+ x, y, z = torch.meshgrid(x, y, z, indexing="ij")
57
+ verts = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
58
+ verts = verts * (self.points_range[1] - self.points_range[0]) + self.points_range[0]
59
+
60
+ self._grid_vertices = verts
61
+ return self._grid_vertices
62
+
63
+ def forward(
64
+ self,
65
+ level: Float[Tensor, "N3 1"],
66
+ deformation: Optional[Float[Tensor, "N3 3"]] = None,
67
+ isovalue=0.0,
68
+ ) -> Mesh:
69
+ level = level.view(self.resolution, self.resolution, self.resolution)
70
+ if deformation is not None:
71
+ deformation = deformation.view(self.resolution, self.resolution, self.resolution, 3)
72
+ v_pos, t_pos_idx = self.mc_func(level, deformation, isovalue=isovalue)
73
+ v_pos = v_pos * (self.points_range[1] - self.points_range[0]) + self.points_range[0]
74
+ # TODO: if the mesh is good
75
+ return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
76
+
77
+
78
+ def isosurface(
79
+ space_cache: Float[Tensor, "B ..."],
80
+ forward_field: Callable,
81
+ isosurface_helper: Callable,
82
+ ) -> List[Mesh]:
83
+
84
+ # the isosurface is dependent on the space cache
85
+ # randomly detach isosurface method if it is differentiable
86
+ # get the batchsize
87
+ if torch.is_tensor(space_cache): #space cache
88
+ batch_size = space_cache.shape[0]
89
+ elif isinstance(space_cache, Dict): #hyper net
90
+ # Dict[str, List[Float[Tensor, "B ..."]]]
91
+ for key in space_cache.keys():
92
+ batch_size = space_cache[key][0].shape[0]
93
+ break
94
+
95
+ # scale the points to [-1, 1]
96
+ points = scale_tensor(
97
+ isosurface_helper.grid_vertices.to(space_cache.device),
98
+ isosurface_helper.points_range,
99
+ [-1, 1], # hard coded isosurface_bbox
100
+ )
101
+ # get the sdf values
102
+ sdf_batch, deformation_batch = forward_field(
103
+ points[None, ...].expand(batch_size, -1, -1),
104
+ space_cache
105
+ )
106
+
107
+ # get the isosurface
108
+ mesh_list = []
109
+
110
+ # check if the sdf is empty
111
+ # for sdf, deformation in zip(sdf_batch, deformation_batch):
112
+ for index in range(sdf_batch.shape[0]):
113
+ sdf = sdf_batch[index]
114
+
115
+ # the deformation may be None
116
+ if deformation_batch is None:
117
+ deformation = None
118
+ else:
119
+ deformation = deformation_batch[index]
120
+
121
+ # special case when all sdf values are positive or negative, thus no isosurface
122
+ if torch.all(sdf > 0) or torch.all(sdf < 0):
123
+
124
+ print("All sdf values are positive or negative; no isosurface found, falling back to a unit-sphere SDF")
125
+ sdf = torch.norm(points, dim=-1) - 1
126
+
127
+ mesh = isosurface_helper(sdf, deformation)
128
+
129
+ mesh.v_pos = scale_tensor(
130
+ mesh.v_pos,
131
+ isosurface_helper.points_range,
132
+ [-1, 1], # hard coded isosurface_bbox
133
+ )
134
+
135
+ # TODO: implement outlier removal
136
+ # if cfg.isosurface_remove_outliers:
137
+ # mesh = mesh.remove_outlier(cfg.isosurface_outlier_n_faces_threshold)
138
+
139
+ mesh_list.append(mesh)
140
+
141
+ return mesh_list
142
+
143
+ def colorize_mesh(
144
+ space_cache: Any,
145
+ export_fn: Callable,
146
+ mesh_list: List[Mesh],
147
+ activation: Callable,
148
+ ) -> List[Mesh]:
149
+ """Colorize the mesh using the geometry's export function and space cache.
150
+
151
+ Args:
152
+ space_cache: The space cache containing feature information
153
+ export_fn: The export function from geometry that generates features
154
+ mesh_list: List of meshes to colorize
155
+
156
+ Returns:
157
+ List[Mesh]: List of colorized meshes
158
+ """
159
+ # Process each mesh in the batch
160
+ for i, mesh in enumerate(mesh_list):
161
+ # Get vertex positions
162
+ points = mesh.v_pos[None, ...] # Add batch dimension [1, N, 3]
163
+
164
+ # Get the corresponding space cache slice for this mesh
165
+ if torch.is_tensor(space_cache):
166
+ space_cache_slice = space_cache[i:i+1]
167
+ elif isinstance(space_cache, dict):
168
+ space_cache_slice = {}
169
+ for key in space_cache.keys():
170
+ space_cache_slice[key] = [
171
+ weight[i:i+1] for weight in space_cache[key]
172
+ ]
173
+
174
+ # Export features for the vertices
175
+ out = export_fn(points, space_cache_slice)
176
+
177
+ # Update vertex colors if features exist
178
+ if "features" in out:
179
+ features = out["features"].squeeze(0) # Remove batch dim [N, C]
180
+ # Convert features to RGB colors
181
+ mesh._v_rgb = activation(features) # Access private attribute directly
182
+
183
+ return mesh_list
184
+
185
+ class MeshExporter(SaverMixin):
186
+ def __init__(self, save_dir="outputs"):
187
+ self.save_dir = save_dir
188
+ os.makedirs(save_dir, exist_ok=True)
189
+
190
+ def get_save_dir(self):
191
+ return self.save_dir
192
+
193
+ def get_save_path(self, filename):
194
+ return os.path.join(self.save_dir, filename)
195
+
196
+ def convert_data(self, x):
197
+ if isinstance(x, torch.Tensor):
198
+ return x.detach().cpu().numpy()
199
+ return x
200
+
201
+ def export_obj(
202
+ mesh: Mesh,
203
+ save_path: str,
204
+ save_normal: bool = False,
205
+ ) -> List[str]:
206
+ """
207
+ Export mesh data to OBJ file format.
208
+
209
+ Args:
210
+ mesh_data: Dictionary containing mesh data (vertices, faces, etc.)
211
+ save_path: Path to save the OBJ file
212
+
213
+ Returns:
214
+ List of saved file paths
215
+ """
216
+
217
+ # Create exporter
218
+ exporter = MeshExporter(os.path.dirname(save_path))
219
+
220
+ # Export mesh
221
+ save_paths = exporter.save_obj(
222
+ os.path.basename(save_path),
223
+ mesh,
224
+ save_mat=None,
225
+ save_normal=save_normal and mesh.v_nrm is not None,
226
+ save_uv=False,
227
+ save_vertex_color=mesh.v_rgb is not None,
228
+ )
229
+
230
+ return save_paths
231
+
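
A hedged end-to-end sketch of the exporter utilities above: sample a sphere SDF on the helper's grid, run marching cubes, and write an OBJ. This assumes the diso package is installed and a CUDA device is available for DiffMC; the resolution and output path are arbitrary examples.

# Sketch: extract and export a sphere mesh with the helpers defined above.
import torch

helper = DiffMarchingCubeHelper(resolution=64).cuda()
points = scale_tensor(helper.grid_vertices, helper.points_range, [-1, 1]).cuda()
sdf = points.norm(dim=-1, keepdim=True) - 0.5                        # sphere of radius 0.5

mesh = helper(sdf)                                                   # Mesh with v_pos in grid space
mesh.v_pos = scale_tensor(mesh.v_pos, helper.points_range, [-1, 1])  # map back to world space
save_paths = export_obj(mesh, "outputs/sphere.obj")                  # writes outputs/sphere.obj
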
triplaneturbo_executable/utils/saving.py ADDED
@@ -0,0 +1,754 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import shutil
5
+
6
+ import cv2
7
+ import imageio
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ import wandb
13
+ from matplotlib import cm
14
+ from matplotlib.colors import LinearSegmentedColormap
15
+ from PIL import Image, ImageDraw
16
+ from pytorch_lightning.loggers import WandbLogger
17
+
18
+ from ..utils.mesh import Mesh
19
+
20
+ from typing import Dict, List, Optional, Union, Any
21
+ from omegaconf import DictConfig
22
+ from jaxtyping import Float
23
+ from torch import Tensor
24
+
25
+ import threading
26
+
27
+ class SaverMixin:
28
+ _save_dir: Optional[str] = None
29
+ _wandb_logger: Optional[WandbLogger] = None
30
+
31
+ def set_save_dir(self, save_dir: str):
32
+ self._save_dir = save_dir
33
+
34
+ def get_save_dir(self):
35
+ if self._save_dir is None:
36
+ raise ValueError("Save dir is not set")
37
+ return self._save_dir
38
+
39
+ def convert_data(self, data):
40
+ if data is None:
41
+ return None
42
+ elif isinstance(data, np.ndarray):
43
+ return data
44
+ elif isinstance(data, torch.Tensor):
45
+ return data.detach().cpu().numpy()
46
+ elif isinstance(data, list):
47
+ return [self.convert_data(d) for d in data]
48
+ elif isinstance(data, dict):
49
+ return {k: self.convert_data(v) for k, v in data.items()}
50
+ else:
51
+ raise TypeError(
52
+ "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
53
+ type(data),
54
+ )
55
+
56
+ def get_save_path(self, filename):
57
+ save_path = os.path.join(self.get_save_dir(), filename)
58
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
59
+ return save_path
60
+
61
+ def create_loggers(self, cfg_loggers: DictConfig) -> None:
62
+ if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable:
63
+ self._wandb_logger = WandbLogger(
64
+ project=cfg_loggers.wandb.project, name=cfg_loggers.wandb.name
65
+ )
66
+
67
+ def get_loggers(self) -> List:
68
+ if self._wandb_logger:
69
+ return [self._wandb_logger]
70
+ else:
71
+ return []
72
+
73
+ DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
74
+ DEFAULT_UV_KWARGS = {
75
+ "data_format": "HWC",
76
+ "data_range": (0, 1),
77
+ "cmap": "checkerboard",
78
+ }
79
+ DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
80
+ DEFAULT_GRID_KWARGS = {"align": "max"}
81
+
82
+ def get_rgb_image_(self, img, data_format, data_range, rgba=False):
83
+ img = self.convert_data(img)
84
+ assert data_format in ["CHW", "HWC"]
85
+ if data_format == "CHW":
86
+ img = img.transpose(1, 2, 0)
87
+ if img.dtype != np.uint8:
88
+ img = img.clip(min=data_range[0], max=data_range[1])
89
+ img = (
90
+ (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0
91
+ ).astype(np.uint8)
92
+ nc = 4 if rgba else 3
93
+ imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)]
94
+ imgs = [
95
+ img_
96
+ if img_.shape[-1] == nc
97
+ else np.concatenate(
98
+ [
99
+ img_,
100
+ np.zeros(
101
+ (img_.shape[0], img_.shape[1], nc - img_.shape[2]),
102
+ dtype=img_.dtype,
103
+ ),
104
+ ],
105
+ axis=-1,
106
+ )
107
+ for img_ in imgs
108
+ ]
109
+ img = np.concatenate(imgs, axis=1)
110
+ if rgba:
111
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
112
+ else:
113
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
114
+ return img
115
+
116
+ def _save_rgb_image(
117
+ self,
118
+ filename,
119
+ img,
120
+ data_format,
121
+ data_range,
122
+ name: Optional[str] = None,
123
+ step: Optional[int] = None,
124
+ ):
125
+ img = self.get_rgb_image_(img, data_format, data_range)
126
+ cv2.imwrite(filename, img)
127
+ if name and self._wandb_logger:
128
+ wandb.log(
129
+ {
130
+ name: wandb.Image(self.get_save_path(filename)),
131
+ "trainer/global_step": step,
132
+ }
133
+ )
134
+
135
+ def save_rgb_image(
136
+ self,
137
+ filename,
138
+ img,
139
+ data_format=DEFAULT_RGB_KWARGS["data_format"],
140
+ data_range=DEFAULT_RGB_KWARGS["data_range"],
141
+ name: Optional[str] = None,
142
+ step: Optional[int] = None,
143
+ ) -> str:
144
+ save_path = self.get_save_path(filename)
145
+ self._save_rgb_image(save_path, img, data_format, data_range, name, step)
146
+ return save_path
147
+
148
+ def get_uv_image_(self, img, data_format, data_range, cmap):
149
+ img = self.convert_data(img)
150
+ assert data_format in ["CHW", "HWC"]
151
+ if data_format == "CHW":
152
+ img = img.transpose(1, 2, 0)
153
+ img = img.clip(min=data_range[0], max=data_range[1])
154
+ img = (img - data_range[0]) / (data_range[1] - data_range[0])
155
+ assert cmap in ["checkerboard", "color"]
156
+ if cmap == "checkerboard":
157
+ n_grid = 64
158
+ mask = (img * n_grid).astype(int)
159
+ mask = (mask[..., 0] + mask[..., 1]) % 2 == 0
160
+ img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255
161
+ img[mask] = np.array([255, 0, 255], dtype=np.uint8)
162
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
163
+ elif cmap == "color":
164
+ img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
165
+ img_[..., 0] = (img[..., 0] * 255).astype(np.uint8)
166
+ img_[..., 1] = (img[..., 1] * 255).astype(np.uint8)
167
+ img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR)
168
+ img = img_
169
+ return img
170
+
171
+ def save_uv_image(
172
+ self,
173
+ filename,
174
+ img,
175
+ data_format=DEFAULT_UV_KWARGS["data_format"],
176
+ data_range=DEFAULT_UV_KWARGS["data_range"],
177
+ cmap=DEFAULT_UV_KWARGS["cmap"],
178
+ ) -> str:
179
+ save_path = self.get_save_path(filename)
180
+ img = self.get_uv_image_(img, data_format, data_range, cmap)
181
+ cv2.imwrite(save_path, img)
182
+ return save_path
183
+
184
+ def get_grayscale_image_(self, img, data_range, cmap):
185
+ img = self.convert_data(img)
186
+ img = np.nan_to_num(img)
187
+ if data_range is None:
188
+ img = (img - img.min()) / (img.max() - img.min())
189
+ else:
190
+ img = img.clip(data_range[0], data_range[1])
191
+ img = (img - data_range[0]) / (data_range[1] - data_range[0])
192
+ assert cmap in [None, "jet", "magma", "spectral"]
193
+         if cmap is None:
194
+ img = (img * 255.0).astype(np.uint8)
195
+ img = np.repeat(img[..., None], 3, axis=2)
196
+ elif cmap == "jet":
197
+ img = (img * 255.0).astype(np.uint8)
198
+ img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
199
+ elif cmap == "magma":
200
+ img = 1.0 - img
201
+ base = cm.get_cmap("magma")
202
+ num_bins = 256
203
+ colormap = LinearSegmentedColormap.from_list(
204
+ f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
205
+ )(np.linspace(0, 1, num_bins))[:, :3]
206
+ a = np.floor(img * 255.0)
207
+ b = (a + 1).clip(max=255.0)
208
+ f = img * 255.0 - a
209
+ a = a.astype(np.uint16).clip(0, 255)
210
+ b = b.astype(np.uint16).clip(0, 255)
211
+ img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
212
+ img = (img * 255.0).astype(np.uint8)
213
+ elif cmap == "spectral":
214
+ colormap = plt.get_cmap("Spectral")
215
+
216
+ def blend_rgba(image):
217
+ image = image[..., :3] * image[..., -1:] + (
218
+ 1.0 - image[..., -1:]
219
+ ) # blend A to RGB
220
+ return image
221
+
222
+ img = colormap(img)
223
+ img = blend_rgba(img)
224
+ img = (img * 255).astype(np.uint8)
225
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
226
+ return img
227
+
228
+ def _save_grayscale_image(
229
+ self,
230
+ filename,
231
+ img,
232
+ data_range,
233
+ cmap,
234
+ name: Optional[str] = None,
235
+ step: Optional[int] = None,
236
+ ):
237
+ img = self.get_grayscale_image_(img, data_range, cmap)
238
+ cv2.imwrite(filename, img)
239
+ if name and self._wandb_logger:
240
+ wandb.log(
241
+ {
242
+ name: wandb.Image(self.get_save_path(filename)),
243
+ "trainer/global_step": step,
244
+ }
245
+ )
246
+
247
+ def save_grayscale_image(
248
+ self,
249
+ filename,
250
+ img,
251
+ data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
252
+ cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
253
+ name: Optional[str] = None,
254
+ step: Optional[int] = None,
255
+ ) -> str:
256
+ save_path = self.get_save_path(filename)
257
+ self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
258
+ return save_path
259
+
260
+ def get_image_grid_(self, imgs, align):
261
+ if isinstance(imgs[0], list):
262
+ return np.concatenate(
263
+ [self.get_image_grid_(row, align) for row in imgs], axis=0
264
+ )
265
+ cols = []
266
+ for col in imgs:
267
+ assert col["type"] in ["rgb", "uv", "grayscale"]
268
+ if col["type"] == "rgb":
269
+ rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
270
+ rgb_kwargs.update(col["kwargs"])
271
+ cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs))
272
+ elif col["type"] == "uv":
273
+ uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
274
+ uv_kwargs.update(col["kwargs"])
275
+ cols.append(self.get_uv_image_(col["img"], **uv_kwargs))
276
+ elif col["type"] == "grayscale":
277
+ grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
278
+ grayscale_kwargs.update(col["kwargs"])
279
+ cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs))
280
+
281
+ if align == "max":
282
+ h = max([col.shape[0] for col in cols])
283
+ w = max([col.shape[1] for col in cols])
284
+ elif align == "min":
285
+ h = min([col.shape[0] for col in cols])
286
+ w = min([col.shape[1] for col in cols])
287
+ elif isinstance(align, int):
288
+ h = align
289
+ w = align
290
+ elif (
291
+ isinstance(align, tuple)
292
+ and isinstance(align[0], int)
293
+ and isinstance(align[1], int)
294
+ ):
295
+ h, w = align
296
+ else:
297
+ raise ValueError(
298
+ f"Unsupported image grid align: {align}, should be min, max, int or (int, int)"
299
+ )
300
+
301
+ for i in range(len(cols)):
302
+ if cols[i].shape[0] != h or cols[i].shape[1] != w:
303
+ cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR)
304
+ return np.concatenate(cols, axis=1)
305
+
306
+ def save_image_grid(
307
+ self,
308
+ filename,
309
+ imgs,
310
+ align=DEFAULT_GRID_KWARGS["align"],
311
+ name: Optional[str] = None,
312
+ step: Optional[int] = None,
313
+ texts: Optional[List[float]] = None,
314
+ ):
315
+ save_path = self.get_save_path(filename)
316
+ img = self.get_image_grid_(imgs, align=align)
317
+
318
+ if texts is not None:
319
+ img = Image.fromarray(img)
320
+ draw = ImageDraw.Draw(img)
321
+ black, white = (0, 0, 0), (255, 255, 255)
322
+ for i, text in enumerate(texts):
323
+ draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
324
+ draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
325
+ draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
326
+ draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
327
+ draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black)
328
+ img = np.asarray(img)
329
+
330
+ cv2.imwrite(save_path, img)
331
+ if name and self._wandb_logger:
332
+ wandb.log({name: wandb.Image(save_path), "trainer/global_step": step})
333
+ return save_path
334
+
335
+ def save_image(self, filename, img) -> str:
336
+ save_path = self.get_save_path(filename)
337
+ img = self.convert_data(img)
338
+ assert img.dtype == np.uint8 or img.dtype == np.uint16
339
+ if img.ndim == 3 and img.shape[-1] == 3:
340
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
341
+ elif img.ndim == 3 and img.shape[-1] == 4:
342
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
343
+ cv2.imwrite(save_path, img)
344
+ return save_path
345
+
346
+ def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
347
+ save_path = self.get_save_path(filename)
348
+ img = self.convert_data(img)
349
+ assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]
350
+
351
+ imgs_full = []
352
+ for start in range(0, img.shape[-1], 3):
353
+ img_ = img[..., start : start + 3]
354
+ img_ = np.stack(
355
+ [
356
+ self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba)
357
+ for i in range(img_.shape[0])
358
+ ],
359
+ axis=0,
360
+ )
361
+ size = img_.shape[1]
362
+             placeholder = np.zeros((size, size, 3), dtype=np.uint8)  # keep uint8 so the concatenated grid stays a valid 8-bit image for cv2.imwrite
363
+ img_full = np.concatenate(
364
+ [
365
+ np.concatenate(
366
+ [placeholder, img_[2], placeholder, placeholder], axis=1
367
+ ),
368
+ np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1),
369
+ np.concatenate(
370
+ [placeholder, img_[3], placeholder, placeholder], axis=1
371
+ ),
372
+ ],
373
+ axis=0,
374
+ )
375
+ imgs_full.append(img_full)
376
+
377
+ imgs_full = np.concatenate(imgs_full, axis=1)
378
+ cv2.imwrite(save_path, imgs_full)
379
+ return save_path
380
+
381
+ def save_data(self, filename, data) -> str:
382
+ data = self.convert_data(data)
383
+ if isinstance(data, dict):
384
+ if not filename.endswith(".npz"):
385
+ filename += ".npz"
386
+ save_path = self.get_save_path(filename)
387
+ np.savez(save_path, **data)
388
+ else:
389
+ if not filename.endswith(".npy"):
390
+ filename += ".npy"
391
+ save_path = self.get_save_path(filename)
392
+ np.save(save_path, data)
393
+ return save_path
394
+
395
+ def save_state_dict(self, filename, data) -> str:
396
+ save_path = self.get_save_path(filename)
397
+ torch.save(data, save_path)
398
+ return save_path
399
+
400
+ # def save_img_sequence(
401
+ # self,
402
+ # filename,
403
+ # img_dir,
404
+ # matcher,
405
+ # save_format="mp4",
406
+ # fps=30,
407
+ # name: Optional[str] = None,
408
+ # step: Optional[int] = None,
409
+ # ) -> str:
410
+ # assert save_format in ["gif", "mp4"]
411
+ # if not filename.endswith(save_format):
412
+ # filename += f".{save_format}"
413
+ # save_path = self.get_save_path(filename)
414
+ # matcher = re.compile(matcher)
415
+ # img_dir = os.path.join(self.get_save_dir(), img_dir)
416
+ # imgs = []
417
+ # for f in os.listdir(img_dir):
418
+ # if matcher.search(f):
419
+ # imgs.append(f)
420
+ # imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
421
+ # imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
422
+
423
+ # if save_format == "gif":
424
+ # imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
425
+ # imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
426
+ # elif save_format == "mp4":
427
+ # imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
428
+ # imageio.mimsave(save_path, imgs, fps=fps)
429
+ # if name and self._wandb_logger:
430
+ # wandb.log(
431
+ # {
432
+ # name: wandb.Video(save_path, format="mp4"),
433
+ # "trainer/global_step": step,
434
+ # }
435
+ # )
436
+ # return save_path
437
+
438
+ def save_img_sequence(
439
+ self,
440
+ filename,
441
+ img_dir,
442
+ matcher,
443
+ save_format="mp4",
444
+ fps=30,
445
+ name: Optional[str] = None,
446
+ step: Optional[int] = None,
447
+ multithreaded: bool = False
448
+ ) -> str:
449
+ assert save_format in ["gif", "mp4"]
450
+ if not filename.endswith(save_format):
451
+ filename += f".{save_format}"
452
+ save_path = self.get_save_path(filename)
453
+ matcher = re.compile(matcher)
454
+ img_dir = os.path.join(self.get_save_dir(), img_dir)
455
+ imgs = []
456
+ for f in os.listdir(img_dir):
457
+ if matcher.search(f):
458
+ imgs.append(f)
459
+ imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
460
+ imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
461
+
462
+ if save_format == "gif":
463
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
464
+ if multithreaded:
465
+ # threestudio.info("Multithreaded gif saving: {}".format(save_path))
466
+ thread = threading.Thread(target=imageio.mimsave, args=(save_path, imgs), kwargs={"fps": fps})
467
+ thread.start()
468
+ else:
469
+ imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
470
+ elif save_format == "mp4":
471
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
472
+ if multithreaded:
473
+ # threestudio.info("Multithreaded mp4 saving: {}".format(save_path))
474
+ thread = threading.Thread(target=imageio.mimsave, args=(save_path, imgs), kwargs={"fps": fps})
475
+ thread.start()
476
+ else:
477
+ imageio.mimsave(save_path, imgs, fps=fps)
478
+ if name and self._wandb_logger:
479
+ wandb.log(
480
+ {
481
+ name: wandb.Video(save_path, format="mp4"),
482
+ "trainer/global_step": step,
483
+ }
484
+ )
485
+ return save_path
486
+
487
+ def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str:
488
+ save_path = self.get_save_path(filename)
489
+ v_pos = self.convert_data(v_pos)
490
+ t_pos_idx = self.convert_data(t_pos_idx)
491
+ mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx)
492
+ mesh.export(save_path)
493
+ return save_path
494
+
495
+ def save_obj(
496
+ self,
497
+ filename: str,
498
+ mesh: Mesh,
499
+ save_mat: bool = False,
500
+ save_normal: bool = False,
501
+ save_uv: bool = False,
502
+ save_vertex_color: bool = False,
503
+ map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
504
+ map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
505
+ map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
506
+ map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
507
+ map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
508
+ map_format: str = "jpg",
509
+ ) -> List[str]:
510
+
511
+ if not filename.endswith(".obj"):
512
+ filename += ".obj"
513
+ save_path = self.get_save_path(filename)
514
+ v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
515
+ mesh.t_pos_idx
516
+ )
517
+ v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
518
+ if save_normal:
519
+ v_nrm = self.convert_data(mesh.v_nrm)
520
+ if save_uv:
521
+ v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
522
+ mesh.t_tex_idx
523
+ )
524
+ if save_vertex_color:
525
+ v_rgb = self.convert_data(mesh.v_rgb)
526
+
527
+         # use trimesh to export the obj; note that save_mat and the map_* texture arguments are not written in this path
528
+ mesh = trimesh.Trimesh(
529
+ vertices=v_pos,
530
+ faces=t_pos_idx,
531
+ vertex_normals=v_nrm,
532
+ vertex_colors=v_rgb,
533
+ visual=trimesh.visual.TextureVisuals(
534
+ uv=v_tex,
535
+ face_uv=t_tex_idx
536
+ ) if save_uv else None
537
+ )
538
+
539
+ # save the mesh to obj
540
+ mesh.export(save_path)
541
+ return [save_path]
542
+
543
+ # def save_obj(
544
+ # self,
545
+ # filename: str,
546
+ # mesh: Mesh,
547
+ # save_mat: bool = False,
548
+ # save_normal: bool = False,
549
+ # save_uv: bool = False,
550
+ # save_vertex_color: bool = False,
551
+ # map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
552
+ # map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
553
+ # map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
554
+ # map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
555
+ # map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
556
+ # map_format: str = "jpg",
557
+ # ) -> List[str]:
558
+ # save_paths: List[str] = []
559
+ # if not filename.endswith(".obj"):
560
+ # filename += ".obj"
561
+ # v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
562
+ # mesh.t_pos_idx
563
+ # )
564
+ # v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
565
+ # if save_normal:
566
+ # v_nrm = self.convert_data(mesh.v_nrm)
567
+ # if save_uv:
568
+ # v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
569
+ # mesh.t_tex_idx
570
+ # )
571
+ # if save_vertex_color:
572
+ # v_rgb = self.convert_data(mesh.v_rgb)
573
+ # matname, mtllib = None, None
574
+ # if save_mat:
575
+ # matname = "default"
576
+ # mtl_filename = filename.replace(".obj", ".mtl")
577
+ # mtllib = os.path.basename(mtl_filename)
578
+ # mtl_save_paths = self._save_mtl(
579
+ # mtl_filename,
580
+ # matname,
581
+ # map_Kd=self.convert_data(map_Kd),
582
+ # map_Ks=self.convert_data(map_Ks),
583
+ # map_Bump=self.convert_data(map_Bump),
584
+ # map_Pm=self.convert_data(map_Pm),
585
+ # map_Pr=self.convert_data(map_Pr),
586
+ # map_format=map_format,
587
+ # )
588
+ # save_paths += mtl_save_paths
589
+ # obj_save_path = self._save_obj(
590
+ # filename,
591
+ # v_pos,
592
+ # t_pos_idx,
593
+ # v_nrm=v_nrm,
594
+ # v_tex=v_tex,
595
+ # t_tex_idx=t_tex_idx,
596
+ # v_rgb=v_rgb,
597
+ # matname=matname,
598
+ # mtllib=mtllib,
599
+ # )
600
+ # save_paths.append(obj_save_path)
601
+ # return save_paths
602
+
603
+ # def _save_obj(
604
+ # self,
605
+ # filename,
606
+ # v_pos,
607
+ # t_pos_idx,
608
+ # v_nrm=None,
609
+ # v_tex=None,
610
+ # t_tex_idx=None,
611
+ # v_rgb=None,
612
+ # matname=None,
613
+ # mtllib=None,
614
+ # ) -> str:
615
+ # obj_str = ""
616
+ # if matname is not None:
617
+ # obj_str += f"mtllib {mtllib}\n"
618
+ # obj_str += f"g object\n"
619
+ # obj_str += f"usemtl {matname}\n"
620
+ # for i in range(len(v_pos)):
621
+ # obj_str += f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}"
622
+ # if v_rgb is not None:
623
+ # obj_str += f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}"
624
+ # obj_str += "\n"
625
+ # if v_nrm is not None:
626
+ # for v in v_nrm:
627
+ # obj_str += f"vn {v[0]} {v[1]} {v[2]}\n"
628
+ # if v_tex is not None:
629
+ # for v in v_tex:
630
+ # obj_str += f"vt {v[0]} {1.0 - v[1]}\n"
631
+
632
+ # for i in range(len(t_pos_idx)):
633
+ # obj_str += "f"
634
+ # for j in range(3):
635
+ # obj_str += f" {t_pos_idx[i][j] + 1}/"
636
+ # if v_tex is not None:
637
+ # obj_str += f"{t_tex_idx[i][j] + 1}"
638
+ # obj_str += "/"
639
+ # if v_nrm is not None:
640
+ # obj_str += f"{t_pos_idx[i][j] + 1}"
641
+ # obj_str += "\n"
642
+
643
+ # save_path = self.get_save_path(filename)
644
+ # with open(save_path, "w") as f:
645
+ # f.write(obj_str)
646
+ # return save_path
647
+
648
+ def _save_mtl(
649
+ self,
650
+ filename,
651
+ matname,
652
+ Ka=(0.0, 0.0, 0.0),
653
+ Kd=(1.0, 1.0, 1.0),
654
+ Ks=(0.0, 0.0, 0.0),
655
+ map_Kd=None,
656
+ map_Ks=None,
657
+ map_Bump=None,
658
+ map_Pm=None,
659
+ map_Pr=None,
660
+ map_format="jpg",
661
+ step: Optional[int] = None,
662
+ ) -> List[str]:
663
+ mtl_save_path = self.get_save_path(filename)
664
+ save_paths = [mtl_save_path]
665
+ mtl_str = f"newmtl {matname}\n"
666
+ mtl_str += f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n"
667
+ if map_Kd is not None:
668
+ map_Kd_save_path = os.path.join(
669
+ os.path.dirname(mtl_save_path), f"texture_kd.{map_format}"
670
+ )
671
+ mtl_str += f"map_Kd texture_kd.{map_format}\n"
672
+ self._save_rgb_image(
673
+ map_Kd_save_path,
674
+ map_Kd,
675
+ data_format="HWC",
676
+ data_range=(0, 1),
677
+ name=f"{matname}_Kd",
678
+ step=step,
679
+ )
680
+ save_paths.append(map_Kd_save_path)
681
+ else:
682
+ mtl_str += f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n"
683
+ if map_Ks is not None:
684
+ map_Ks_save_path = os.path.join(
685
+ os.path.dirname(mtl_save_path), f"texture_ks.{map_format}"
686
+ )
687
+ mtl_str += f"map_Ks texture_ks.{map_format}\n"
688
+ self._save_rgb_image(
689
+ map_Ks_save_path,
690
+ map_Ks,
691
+ data_format="HWC",
692
+ data_range=(0, 1),
693
+ name=f"{matname}_Ks",
694
+ step=step,
695
+ )
696
+ save_paths.append(map_Ks_save_path)
697
+ else:
698
+ mtl_str += f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n"
699
+ if map_Bump is not None:
700
+ map_Bump_save_path = os.path.join(
701
+ os.path.dirname(mtl_save_path), f"texture_nrm.{map_format}"
702
+ )
703
+ mtl_str += f"map_Bump texture_nrm.{map_format}\n"
704
+ self._save_rgb_image(
705
+ map_Bump_save_path,
706
+ map_Bump,
707
+ data_format="HWC",
708
+ data_range=(0, 1),
709
+ name=f"{matname}_Bump",
710
+ step=step,
711
+ )
712
+ save_paths.append(map_Bump_save_path)
713
+ if map_Pm is not None:
714
+ map_Pm_save_path = os.path.join(
715
+ os.path.dirname(mtl_save_path), f"texture_metallic.{map_format}"
716
+ )
717
+ mtl_str += f"map_Pm texture_metallic.{map_format}\n"
718
+ self._save_grayscale_image(
719
+ map_Pm_save_path,
720
+ map_Pm,
721
+ data_range=(0, 1),
722
+ cmap=None,
723
+ name=f"{matname}_refl",
724
+ step=step,
725
+ )
726
+ save_paths.append(map_Pm_save_path)
727
+ if map_Pr is not None:
728
+ map_Pr_save_path = os.path.join(
729
+ os.path.dirname(mtl_save_path), f"texture_roughness.{map_format}"
730
+ )
731
+ mtl_str += f"map_Pr texture_roughness.{map_format}\n"
732
+ self._save_grayscale_image(
733
+ map_Pr_save_path,
734
+ map_Pr,
735
+ data_range=(0, 1),
736
+ cmap=None,
737
+ name=f"{matname}_Ns",
738
+ step=step,
739
+ )
740
+ save_paths.append(map_Pr_save_path)
741
+ with open(self.get_save_path(filename), "w") as f:
742
+ f.write(mtl_str)
743
+ return save_paths
744
+
745
+ def save_file(self, filename, src_path) -> str:
746
+ save_path = self.get_save_path(filename)
747
+ shutil.copyfile(src_path, save_path)
748
+ return save_path
749
+
750
+ def save_json(self, filename, payload) -> str:
751
+ save_path = self.get_save_path(filename)
752
+ with open(save_path, "w") as f:
753
+ f.write(json.dumps(payload))
754
+ return save_path
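
For context, a minimal usage sketch of the SaverMixin defined in this file. The module path triplane_turbo.utils.saving, the DummyExporter class, and the output file names are assumptions for illustration only, not part of this commit:

    # Hypothetical usage sketch of SaverMixin (import path is an assumption).
    import numpy as np
    from triplane_turbo.utils.saving import SaverMixin  # assumed module path

    class DummyExporter(SaverMixin):
        """Any class mixing in SaverMixin only needs a save dir to start saving."""
        pass

    exporter = DummyExporter()
    exporter.set_save_dir("outputs/debug")  # every save_* call joins its filename onto this dir

    # Save a random HWC float image in [0, 1]; returns the full save path.
    rgb = np.random.rand(64, 64, 3).astype(np.float32)
    print(exporter.save_rgb_image("it0-rgb.png", rgb, data_format="HWC", data_range=(0, 1)))

    # Save a one-row grid: an RGB panel next to a jet-colormapped grayscale panel.
    depth = np.random.rand(64, 64).astype(np.float32)
    exporter.save_image_grid(
        "it0-grid.png",
        [[
            {"type": "rgb", "img": rgb, "kwargs": {}},
            {"type": "grayscale", "img": depth, "kwargs": {"cmap": "jet"}},
        ]],
    )

Since no wandb logger is created here, the optional name/step logging branches are simply skipped.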