Vibu46vk committed on
Commit 709e7c1 · verified · 1 Parent(s): f4a0891

Upload 88 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +8 -0
  2. pix2pix3D-main/pix2pix3D-main/.gitignore +146 -0
  3. pix2pix3D-main/pix2pix3D-main/LICENSE +21 -0
  4. pix2pix3D-main/pix2pix3D-main/README.md +170 -0
  5. pix2pix3D-main/pix2pix3D-main/applications/demo/qt_demo_seg2cat.py +504 -0
  6. pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/__init__.py +0 -0
  7. pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/mouse_event.py +100 -0
  8. pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui.py +988 -0
  9. pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui_clean.py +797 -0
  10. pix2pix3D-main/pix2pix3D-main/applications/edge2cat.ipynb +0 -0
  11. pix2pix3D-main/pix2pix3D-main/applications/extract_mesh.py +267 -0
  12. pix2pix3D-main/pix2pix3D-main/applications/generate_samples.py +128 -0
  13. pix2pix3D-main/pix2pix3D-main/applications/generate_video.py +220 -0
  14. pix2pix3D-main/pix2pix3D-main/assets/demo.mp4 +3 -0
  15. pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif +3 -0
  16. pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif +3 -0
  17. pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png +3 -0
  18. pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_label.png +0 -0
  19. pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_input.png +0 -0
  20. pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif +3 -0
  21. pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif +3 -0
  22. pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg +3 -0
  23. pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png +3 -0
  24. pix2pix3D-main/pix2pix3D-main/camera_utils.py +149 -0
  25. pix2pix3D-main/pix2pix3D-main/checkpoints/download_models.sh +5 -0
  26. pix2pix3D-main/pix2pix3D-main/dnnlib/__init__.py +11 -0
  27. pix2pix3D-main/pix2pix3D-main/dnnlib/util.py +493 -0
  28. pix2pix3D-main/pix2pix3D-main/environment.yml +39 -0
  29. pix2pix3D-main/pix2pix3D-main/examples/example_input.png +0 -0
  30. pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2car.png +0 -0
  31. pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2cat.png +0 -0
  32. pix2pix3D-main/pix2pix3D-main/legacy.py +325 -0
  33. pix2pix3D-main/pix2pix3D-main/metrics/__init__.py +11 -0
  34. pix2pix3D-main/pix2pix3D-main/metrics/equivariance.py +269 -0
  35. pix2pix3D-main/pix2pix3D-main/metrics/frechet_inception_distance.py +43 -0
  36. pix2pix3D-main/pix2pix3D-main/metrics/inception_score.py +40 -0
  37. pix2pix3D-main/pix2pix3D-main/metrics/kernel_inception_distance.py +48 -0
  38. pix2pix3D-main/pix2pix3D-main/metrics/metric_main.py +155 -0
  39. pix2pix3D-main/pix2pix3D-main/metrics/metric_utils.py +281 -0
  40. pix2pix3D-main/pix2pix3D-main/metrics/perceptual_path_length.py +127 -0
  41. pix2pix3D-main/pix2pix3D-main/metrics/precision_recall.py +64 -0
  42. pix2pix3D-main/pix2pix3D-main/torch_utils/__init__.py +11 -0
  43. pix2pix3D-main/pix2pix3D-main/torch_utils/custom_ops.py +159 -0
  44. pix2pix3D-main/pix2pix3D-main/torch_utils/misc.py +280 -0
  45. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/__init__.py +11 -0
  46. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cpp +103 -0
  47. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cu +177 -0
  48. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.h +42 -0
  49. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.py +211 -0
  50. pix2pix3D-main/pix2pix3D-main/torch_utils/ops/conv2d_gradfix.py +199 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg filter=lfs diff=lfs merge=lfs -text
+ pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png filter=lfs diff=lfs merge=lfs -text
pix2pix3D-main/pix2pix3D-main/.gitignore ADDED
@@ -0,0 +1,146 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ logs
+ data
+ wandb
+ scripts
+ web
+ *.pyc
+ .vscode
pix2pix3D-main/pix2pix3D-main/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Kangle Deng
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
pix2pix3D-main/pix2pix3D-main/README.md ADDED
@@ -0,0 +1,170 @@
+ # 3D-aware Conditional Image Synthesis (pix2pix3D)
+
+ [**Project**](https://www.cs.cmu.edu/~pix2pix3D/) | [**Paper**](https://arxiv.org/abs/2302.08509)
+
+ This is the official PyTorch implementation of "3D-aware Conditional Image Synthesis". pix2pix3D synthesizes a 3D object (a neural field) given a 2D label map, such as a segmentation or edge map. We also provide an interactive 3D editing demo.
+
+ https://user-images.githubusercontent.com/28395429/222578030-8bb2c727-397e-44b6-9ab1-9b0b09dd3b5b.mp4
+
+ [3D-aware Conditional Image Synthesis](https://arxiv.org/abs/2302.08509)
+
+ CVPR 2023
+
+ [Kangle Deng](https://dunbar12138.github.io/),
+ [Gengshan Yang](https://gengshan-y.github.io/),
+ [Deva Ramanan](https://www.cs.cmu.edu/~deva/),
+ [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)
+
+ Carnegie Mellon University
+
+ ---
+
+ We propose pix2pix3D, a 3D-aware conditional generative model for controllable photorealistic image synthesis. Given a 2D label map, such as a segmentation or edge map, our model learns to synthesize a corresponding image from different viewpoints. To enable explicit 3D user control, we extend conditional generative models with neural radiance fields. Given widely-available monocular images and label map pairs, our model learns to assign a label to every 3D point in addition to color and density, which enables it to render the image and pixel-aligned label map simultaneously. Finally, we build an interactive system that allows users to edit the label map from any viewpoint and generate outputs accordingly.
+
+ <p align="center">
+ <img src="assets/teaser_jpg.jpg" width="720" />
+ </p>
+
+ ## Getting Started
+
+ ### Dependencies
+
+ We provide a conda environment file that contains all the dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
+ ```
+ conda env create -f environment.yml
+ conda activate pix2pix3d
+ ```
+
+ ### Data
+
+ We provide our preprocessed datasets, including segmentation maps and edge maps. You can download the [CelebAMask](https://drive.google.com/drive/folders/1mC6i4YmdpazJSmXrW8WFSfsImAJ8a_CF?usp=sharing) dataset, [AFHQ-Cat-Seg](https://drive.google.com/drive/folders/1yjTTE57P9-hEe-IVcE-GXdh04WGo5lD9?usp=sharing) dataset, and [Shapenet-Car-Edge](https://drive.google.com/drive/folders/1XTPuu784DIvk0ie094qyLrcF-v-jMe3_?usp=sharing) dataset and put those zip files under ```data/```.
+
+ ### Pre-trained Models
+
+ You can download our pre-trained models using the following script:
+ ```
+ bash checkpoints/download_models.sh
+ ```
+
+ ---
+ ### Inference
+
+ We provide several scripts to generate results once you have downloaded the pre-trained models.
+
+ <p align="center">
+ <img src="assets/teaser_gif.gif" width="720" />
+ </p>
+
+ #### Generate Samples
+ You can generate results based on the samples in the dataset.
+ ```
+ python applications/generate_samples.py --network <network_pkl> --outdir <output_dir> --random_seed <random_seeds list, e.g. 0 1> --cfg <configs, e.g., seg2cat, seg2face, edge2car> --input_id <sample_id in dataset>
+ ```
+ For example:
+
+ | Input Label Map | Generated Image | Generated Label Map |
+ | ------------- | ------------- | -------- |
+ | <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/seg2cat_1666_1_color.png" width="256" /> | <img src="assets/seg2cat_1666_1_label.png" width="256" /> |
+
+ You can get the results above by running:
+ ```
+ python applications/generate_samples.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --random_seed 1 --cfg seg2cat --input_id 1666
+ ```
+
+ #### Render Videos
+ You can render a video result based on a specified input label map.
+ ```
+ python applications/generate_video.py --network <network_pkl> --outdir <output_dir> --random_seed <random_seeds list, e.g. 0 1> --cfg <configs, e.g., seg2cat, seg2face, edge2car> --input <input label map>
+ ```
+
+ For example:
+ | Input Label Map | Generated Image | Generated Label Map |
+ | ------------- | ------------- | -------- |
+ | <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/seg2cat_1.gif" width="256" /> | <img src="assets/seg2cat_1_label.gif" width="256" /> |
+
+ You can get the results above using the following command:
+ ```
+ python applications/generate_video.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --random_seed 1 --cfg seg2cat --input examples/example_input.png
+ ```
+
+ #### Extract Semantic Mesh
+ You can also extract the mesh and color it using 3D semantic labels. Some extra packages (`pyrender`, `trimesh`, and `mcubes`) are required for mesh extraction; you can install them with `pip`. The extracted mesh will be saved as `semantic_mesh.ply`.
+
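For reference, those extra packages can typically be installed in one line (PyPI package names assumed here; the `mcubes` module is distributed as `PyMCubes`):
```
pip install pyrender trimesh PyMCubes
```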
+ For example:
+ | Input Label Map | Semantic Mesh |
+ | ------------- | ------------- |
+ | <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/rendered_mesh_colored.gif" width="256" /> |
+
+ You can get the results above with the following command:
+ ```
+ python applications/extract_mesh.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --cfg seg2cat --input examples/example_input.png
+ ```
+
+ <!-- #### Interpolation -->
+
+ ### Training
+
+ We provide an example training script at `train_scripts/afhq_seg.sh`:
+ ```
+ python train.py --outdir=<log_dir> \
+     --cfg=afhq --data=data/afhq_v2_train_cat_512.zip \
+     --mask_data=data/afhqcat_seg_6c.zip \
+     --data_type=seg --semantic_channels=6 \
+     --render_mask=True --dis_mask=True \
+     --neural_rendering_resolution_initial=128 \
+     --resume=<EG3D-checkpoints>/afhqcats512-128.pkl \
+     --gpus=2 --batch=4 --mbstd-group=2 \
+     --gamma=5 --gen_pose_cond=True \
+     --random_c_prob=0.5 \
+     --lambda_d_semantic=0.1 \
+     --lambda_lpips=1 \
+     --lambda_cross_view=1e-4 \
+     --only_raw_recons=True \
+     --wandb_log=False
+ ```
+ Training parameters:
+ - `outdir`: The directory to save checkpoints and logs.
+ - `cfg`: Choose from [afhq, celeba, shapenet].
+ - `data`: RGB data file.
+ - `mask_data`: Label map data file.
+ - `data_type`: Choose from [seg, edge]. Specify `semantic_channels` if using `seg`.
+ - `render_mask`: Whether to render label maps along with RGB.
+ - `dis_mask`: Whether to use a GAN loss on rendered label maps.
+ - `neural_rendering_resolution_initial`: The resolution of NeRF outputs.
+ - `resume`: We partially initialize our network with EG3D pretrained checkpoints (download [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/eg3d)).
+ - `gpus`, `batch`, `mbstd-group`: Parameters for batch size and multi-GPU training.
+ - `gen_pose_cond`: Whether to condition the generation on camera poses.
+ - `random_c_prob`: Probability of sampling random poses for training.
+ - `lambda_d_semantic`: The weight of the GAN loss on label maps.
+ - `lambda_lpips`: The weight of the RGB LPIPS loss.
+ - `lambda_cross_view`: The weight of the cross-view consistency loss.
+ - `wandb_log`: Whether to log with wandb.
+
+ ### Prepare your own dataset
+
+ We follow the dataset format of EG3D [here](https://github.com/NVlabs/eg3d#preparing-datasets). You can obtain segmentation masks for your own dataset with [DINO clustering](https://github.com/ShirAmir/dino-vit-features/blob/main/part_cosegmentation.py), and edge maps with [pidinet](https://github.com/hellozhuo/pidinet) and [informative drawing](https://github.com/carolineec/informative-drawings).
+
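A hedged sketch of the camera labels this format implies, matching how the demo code below splits `pose[:16]` and `pose[16:]` (the `dataset.json` path and key layout follow EG3D's convention and are assumptions here):
```
import json
import numpy as np

# EG3D-style metadata: {"labels": [[filename, [25 floats]], ...]} (assumed layout)
with open('data/dataset.json') as f:
    labels = dict(json.load(f)['labels'])

pose = np.asarray(next(iter(labels.values())), dtype=np.float32)
cam2world = pose[:16].reshape(4, 4)   # flattened camera-to-world extrinsics
intrinsics = pose[16:].reshape(3, 3)  # flattened normalized intrinsics
```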
+ ---
+
+ ## Citation
+
+ If you find this repository useful for your research, please cite the following work.
+ ```
+ @inproceedings{kangle2023pix2pix3d,
+   title     = {3D-aware Conditional Image Synthesis},
+   author    = {Deng, Kangle and Yang, Gengshan and Ramanan, Deva and Zhu, Jun-Yan},
+   booktitle = {CVPR},
+   year      = {2023}
+ }
+ ```
+
+ ---
+
+ ## Acknowledgments
+ We thank Sheng-Yu Wang, Nupur Kumari, Gaurav Parmer, Ruihan Gao, Muyang Li, George Cazenavette, Andrew Song, Zhipeng Bao, Tamaki Kojima, Krishna Wadhwani, Takuya Narihira, and Tatsuo Fujiwara for their discussion and help. We are grateful for the support from Sony Corporation, Singapore DSTA, and the CMU Argo AI Center for Autonomous Vehicle Research.
+ This codebase borrows heavily from [EG3D](https://github.com/NVlabs/eg3d) and [StyleNeRF](https://github.com/facebookresearch/StyleNeRF).
pix2pix3D-main/pix2pix3D-main/applications/demo/qt_demo_seg2cat.py ADDED
@@ -0,0 +1,504 @@
+ import sys
+ sys.path.append('./')
+
+ import os
+ import cv2
+ import time
+ import numpy as np
+ from PIL import Image
+
+ import torch
+ from torchvision.utils import save_image
+
+ from ui_qt.ui_clean import Ui_Form_Seg2cat as Ui_Form
+ from ui_qt.mouse_event import GraphicsScene
+
+ from PyQt5.QtCore import *
+ from PyQt5.QtGui import *
+ from PyQt5.QtWidgets import *
+ from PyQt5.QtPrintSupport import QPrintDialog, QPrinter
+
+ import dnnlib
+ import legacy
+ from torch_utils import misc
+
+ from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+ from training.triplane_cond import TriPlaneGenerator
+ from training.training_loop import setup_snapshot_image_grid, get_image_grid, save_image_grid
+ from train import init_conditional_dataset_kwargs
+
+ from matplotlib import pyplot as plt
+
+ from pathlib import Path
+
+ from rich.progress import track
+ import json
+
+ import imageio
+
+ from torch import nn
+ import torch.nn.functional as F
+
+ import argparse
+
+ from scipy.spatial.transform import Rotation as R
+
+ from training.utils import color_mask as color_mask_np
+
+ color_list = [QColor(255, 255, 255), QColor(204, 0, 0), QColor(76, 153, 0), QColor(204, 204, 0), QColor(51, 51, 255), QColor(204, 0, 204), QColor(0, 255, 255), QColor(255, 204, 204), QColor(102, 51, 0), QColor(255, 0, 0), QColor(102, 204, 0), QColor(255, 255, 0), QColor(0, 0, 153), QColor(0, 0, 204), QColor(255, 51, 153), QColor(0, 204, 204), QColor(0, 51, 0), QColor(255, 153, 51), QColor(0, 204, 0)]
+
+ def color_mask(m):
+     # Map integer class labels to RGB colors; supports HxW or BxHxW inputs.
+     my_color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+     if len(m.shape) == 2:
+         im_base = np.zeros((m.shape[0], m.shape[1], 3))
+         for idx, color in enumerate(my_color_list):
+             im_base[m == idx] = color
+         return im_base
+     elif len(m.shape) == 3:
+         im_base = np.zeros((m.shape[0], m.shape[1], m.shape[2], 3))
+         for idx, color in enumerate(my_color_list):
+             im_base[m == idx] = color
+         return im_base
+
+ def resize_image(img, size):
+     # Minimal stand-in for the resize helper referenced by reconstruct() below,
+     # which was not included in this upload: bilinearly resizes an NCHW tensor.
+     return F.interpolate(img, size=(size, size), mode='bilinear', align_corners=False)
+
+ def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, device='cuda'):
+     gen = model.synthesis
+     range_u, range_v = gen.C.range_u, gen.C.range_v
+     u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
+     v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
+     cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
+     return cam
+
+ def get_yaw_pitch(cam2world):
+     # Recover yaw/pitch angles from the forward axis of a cam2world matrix.
+     forward = cam2world[0:3, 2]
+     yaw = np.arctan2(forward[0], forward[2]) - np.pi / 2
+     phi = np.arccos(forward[1])
+     v = (1 - np.cos(phi)) / 2
+     pitch = (1 + forward[1]) / 2 * np.pi
+     return yaw, pitch
+
+ def create_cam2world_fromeuler(euler, radius):
+     # Build a 4x4 cam2world matrix from 'zyx' Euler angles, placing the camera
+     # at the given radius so that it looks toward the origin.
+     r = R.from_euler('zyx', euler, degrees=False)
+     cam2world = r.as_matrix()
+     cam2world = np.concatenate([cam2world, np.array([[0, 0, 0]])], axis=0)
+     cam2world = np.concatenate([cam2world, np.array([0, 0, 0, 1])[..., None]], axis=1)
+     cam2world[:3, 3] = -cam2world[:3, 2] * radius
+     return cam2world
+
+ class Ex(QWidget, Ui_Form):
+     def __init__(self):
+         super(Ex, self).__init__()
+         self.size = 35
+         self.yaw = 100
+         self.pitch = 50
+         self.roll = 0
+         self.truncation = 0.75
+
+         self.setupUi(self)
+         self.show()
+         self.output_img = None
+
+         self.mat_img = None
+
+         self.mode = 0
+         self.mask = None
+         self.mask_m = np.ones((512, 512, 1), dtype=np.uint8) * 255
+         self.img = None
+         self.image = None
+
+         self.mouse_clicked = False
+         self.scene = GraphicsScene(self.mode, self.size)
+         self.graphicsView.setScene(self.scene)
+         self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+         self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+         self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+         self.result_scene = QGraphicsScene()
+         self.graphicsView_2.setScene(self.result_scene)
+         self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+         self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+         self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+         self.dlg = QColorDialog(self.graphicsView)
+         self.color = None
+         self.device = torch.device('cuda')
+
+         # Parse arguments
+         parser = argparse.ArgumentParser(description='Real-time 3D editing demo')
+         parser.add_argument('--network', help='Path to the network pickle file', required=True)
+         parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
+         args = parser.parse_args()
+
+         network_pkl = args.network
+         self.G = self.get_model(network_pkl)
+
+         # Initialize dataset.
+         data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
+         mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
+         data_type = 'seg'
+         dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(data_path, mask_data, data_type)
+         self.training_data = dnnlib.util.construct_class_by_name(**dataset_kwargs)
+
+         self.input_batch = None
+         self.ws = None          # latent codes; populated by get_ws()
+         self.ws_texture = None  # appearance layers, kept fixed across edits
+
+         self.buffer_mask = None
+
+         focal_length = 4.2647  # shapenet has higher FOV
+         self.intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=self.device)
+
+         os.makedirs('examples/ui', exist_ok=True)
+
+     def open(self):
+         fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
+                                                   QDir.currentPath())
+         if fileName:
+             image = QPixmap(fileName)
+             mat_img = Image.open(fileName)
+             self.img = mat_img.copy()
+             if image.isNull():
+                 QMessageBox.information(self, "Image Viewer",
+                                         "Cannot load %s." % fileName)
+                 return
+             image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+
+             if len(self.ref_scene.items()) > 0:
+                 self.ref_scene.removeItem(self.ref_scene.items()[-1])
+             self.ref_scene.addPixmap(image)
+             if len(self.result_scene.items()) > 0:
+                 self.result_scene.removeItem(self.result_scene.items()[-1])
+             self.result_scene.addPixmap(image)
+
+     def open_mask(self):
+         fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
+                                                   QDir.currentPath())
+         if fileName:
+             mat_img = cv2.imread(fileName)
+             self.mask = mat_img.copy()
+             self.mask_m = mat_img
+             mat_img = mat_img.copy()
+             image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+             if image.isNull():
+                 QMessageBox.information(self, "Image Viewer",
+                                         "Cannot load %s." % fileName)
+                 return
+
+             # Recolor the label map from class indices to the display palette.
+             for i in range(512):
+                 for j in range(512):
+                     r, g, b, a = image.pixelColor(i, j).getRgb()
+                     image.setPixel(i, j, color_list[r].rgb())
+
+             pixmap = QPixmap()
+             pixmap.convertFromImage(image)
+             self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+             self.scene.reset()
+             if len(self.scene.items()) > 0:
+                 self.scene.reset_items()
+             self.scene.addPixmap(self.image)
+
+     # Brush class selectors: each sets the active semantic label index.
+     def bg_mode(self):
+         self.scene.mode = 0
+
+     def skin_mode(self):
+         self.scene.mode = 1
+
+     def nose_mode(self):
+         self.scene.mode = 2
+
+     def eye_g_mode(self):
+         self.scene.mode = 3
+
+     def l_eye_mode(self):
+         self.scene.mode = 4
+
+     def r_eye_mode(self):
+         self.scene.mode = 5
+
+     def l_brow_mode(self):
+         self.scene.mode = 6
+
+     def r_brow_mode(self):
+         self.scene.mode = 7
+
+     def l_ear_mode(self):
+         self.scene.mode = 8
+
+     def r_ear_mode(self):
+         self.scene.mode = 9
+
+     def mouth_mode(self):
+         self.scene.mode = 10
+
+     def u_lip_mode(self):
+         self.scene.mode = 11
+
+     def l_lip_mode(self):
+         self.scene.mode = 12
+
+     def hair_mode(self):
+         self.scene.mode = 13
+
+     def hat_mode(self):
+         self.scene.mode = 14
+
+     def ear_r_mode(self):
+         self.scene.mode = 15
+
+     def neck_l_mode(self):
+         self.scene.mode = 16
+
+     def neck_mode(self):
+         self.scene.mode = 17
+
+     def cloth_mode(self):
+         self.scene.mode = 18
+
+     def increase(self):
+         if self.scene.size < 50:
+             self.scene.size += 1
+
+     def decrease(self):
+         if self.scene.size > 1:
+             self.scene.size -= 1
+
+     def changeBrushSize(self, s):
+         self.scene.size = s
+
+     def changeYaw(self, s):
+         self.yaw = s
+         if self.ws is not None:
+             self.generate()
+
+     def changePitch(self, s):
+         self.pitch = s
+         if self.ws is not None:
+             self.generate()
+
+     def changeRoll(self, s):
+         self.roll = s
+         if self.ws is not None:
+             self.generate()
+
+     def changeTruncation(self, s):
+         self.truncation = s / 100
+         # Re-render only once a latent code exists (matches the guards above).
+         if self.ws is not None:
+             self.reconstruct()
+             self.generate()
+
+     def inputID(self):
+         input_id = int(self.text_inputID.toPlainText())
+         self.input_batch = self.training_data[input_id]
+
+         self.mask = self.input_batch['mask'].transpose(1, 2, 0).astype(np.uint8)
+         mat_img = cv2.resize(self.mask[:, :, 0], (512, 512), interpolation=cv2.INTER_NEAREST)
+         self.mask = mat_img[:, :, np.newaxis]
+         self.mask_m = self.mask.copy()
+         image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+         for i in range(512):
+             for j in range(512):
+                 r = mat_img[j, i]
+                 image.setPixel(i, j, color_list[r].rgb())
+
+         pixmap = QPixmap()
+         pixmap.convertFromImage(image)
+         self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+         self.scene.reset()
+         if len(self.scene.items()) > 0:
+             self.scene.reset_items()
+         self.scene.addPixmap(self.image)
+
+         self.ws = None
+
+         # Convert the sample's cam2world rotation into slider-friendly Euler angles.
+         roll, yaw, pitch = R.from_matrix(self.input_batch['pose'][:16].reshape(4, 4)[:3, :3]).as_euler('zyx', degrees=False)
+         pitch_range = np.pi
+         yaw_range = np.pi / 2
+         roll_range = np.pi / 4
+
+         pitch = pitch - np.pi
+         pitch = pitch + 2 * np.pi if pitch < -np.pi else pitch
+
+         self.yaw = yaw / yaw_range * 100
+         self.pitch = pitch / pitch_range * 100
+         self.roll = roll / roll_range * 100
+         print(self.roll, self.yaw, self.pitch)
+
+         self.intrinsics = torch.tensor(self.input_batch['pose'][16:].reshape(3, 3)).float().to(self.device)
+
+         # Sliders take integer positions.
+         self.slider_yawselect.setValue(int(self.yaw))
+         self.slider_pitchselect.setValue(int(self.pitch))
+         self.slider_rollselect.setValue(int(self.roll))
+
+     def get_mask(self):  # refresh the editable label map from the last rendered output
+         mat_img = self.buffer_mask[0].astype(np.uint8)
+
+         self.mask = mat_img[:, :, np.newaxis]
+         self.mask_m = self.mask.copy()
+         image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+         for i in range(512):
+             for j in range(512):
+                 r = mat_img[j, i]
+                 image.setPixel(i, j, color_list[r].rgb())
+
+         pixmap = QPixmap()
+         pixmap.convertFromImage(image)
+         self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+         self.scene.reset()
+         if len(self.scene.items()) > 0:
+             self.scene.reset_items()
+         self.scene.addPixmap(self.image)
+
+     def generate(self):
+         ws = self.ws
+
+         pitch_range = np.pi
+         yaw_range = np.pi / 2
+         roll_range = np.pi / 4
+         pitch = self.pitch / 100 * pitch_range
+         yaw = self.yaw / 100 * yaw_range
+         roll = self.roll / 100 * roll_range
+
+         cam2world_pose = torch.tensor(create_cam2world_fromeuler([roll, yaw, pitch + np.pi], radius=2.7)).float().to(self.device)
+
+         pose = torch.cat([cam2world_pose.reshape(-1, 16), self.intrinsics.reshape(-1, 9)], 1)
+
+         out = self.G.synthesis(ws, pose.to(self.device), noise_mode='const', neural_rendering_resolution=128)
+         img = ((out['image'].permute(0, 2, 3, 1).squeeze(0).cpu().numpy().clip(-1, 1) * 0.5 + 0.5) * 255).astype(np.uint8).copy()
+         if out['image'].shape[-1] != 512:
+             img = cv2.resize(img, (512, 512))
+
+         qim = QImage(img.data, img.shape[1], img.shape[0], img.strides[0], QImage.Format_RGB888)
+         if len(self.result_scene.items()) > 0:
+             self.result_scene.removeItem(self.result_scene.items()[-1])
+         self.result_scene.addPixmap(QPixmap.fromImage(qim))
+
+         self.buffer_mask = torch.argmax(out['semantic'].detach(), dim=1).cpu().numpy()  # 1 x 512 x 512
+         self.output_img = img.copy()
+         self.get_mask()
+
+     def reconstruct(self):
+         if self.input_batch is None:
+             return
+         ws = self.ws
+
+         out = self.G.synthesis(ws, torch.tensor(self.input_batch['pose']).unsqueeze(0).to(self.device))
+         if out['img'].shape[-1] != 512:
+             print(f"Resizing {out['img'].shape[-1]} to {512}")
+             img = resize_image(out['img'].detach(), 512)
+         else:
+             img = out['img'].detach()
+         img = ((img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy().clip(-1, 1) * 0.5 + 0.5) * 255).astype(np.uint8).copy()
+         qim = QImage(img.data, img.shape[1], img.shape[0], img.strides[0], QImage.Format_RGB888)
+         if len(self.ref_scene.items()) > 0:
+             self.ref_scene.removeItem(self.ref_scene.items()[-1])
+         self.ref_scene.addPixmap(QPixmap.fromImage(qim))
+
+         if out['semantic'].shape[-1] != 512:
+             seg = resize_image(out['semantic'].detach(), 512)
+         else:
+             seg = out['semantic'].detach()
+         seg = color_mask(torch.argmax(seg, dim=1).cpu())[0].astype(np.uint8).copy()
+         qim = QImage(seg.data, seg.shape[1], seg.shape[0], seg.strides[0], QImage.Format_RGB888)
+         if len(self.ref_seg_scene.items()) > 0:
+             self.ref_seg_scene.removeItem(self.ref_seg_scene.items()[-1])
+         self.ref_seg_scene.addPixmap(QPixmap.fromImage(qim))
+
+     def get_ws(self):
+         z = torch.from_numpy(np.random.RandomState(int(self.text_seed.toPlainText())).randn(1, self.G.z_dim).astype('float32')).to(self.device)
+
+         # Rasterize the user's brush strokes (one set per semantic class) into the label map.
+         for i in range(6):
+             self.mask_m = self.make_mask(self.mask_m, self.scene.mask_points[i], self.scene.size_points[i], i)
+
+         cv2.imwrite('examples/ui/mask_input.png', self.mask_m)
+
+         forward_cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, torch.tensor(self.G.rendering_kwargs['avg_camera_pivot'], device=self.device),
+                                                           radius=self.G.rendering_kwargs['avg_camera_radius'], device=self.device)
+         focal_length = 4.2647  # shapenet has higher FOV
+         intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=self.device)
+         forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+         self.ws = self.G.mapping(z, forward_pose.to(self.device),
+                                  {'mask': torch.tensor(self.mask_m[None, ..., 0]).unsqueeze(0).to(self.device), 'pose': torch.tensor(self.input_batch['pose']).unsqueeze(0).to(self.device)})
+
+         # Keep the texture (later) layers fixed so appearance stays consistent across edits.
+         if self.ws_texture is None:
+             self.ws_texture = self.ws[:, 8:, :]
+         else:
+             self.ws[:, 8:, :] = self.ws_texture
+
+     def generateAndReconstruct(self):
+         self.get_ws()
+         self.generate()
+
+     def make_mask(self, mask, pts, sizes, color):
+         # Draw the recorded brush strokes for one class index onto the label map.
+         if len(pts) > 0:
+             for idx, pt in enumerate(pts):
+                 cv2.line(mask, pt['prev'], pt['curr'], (color, color, color), sizes[idx])
+         return mask
+
+     def save_img(self):
+         for i in range(6):
+             self.mask_m = self.make_mask(self.mask_m, self.scene.mask_points[i], self.scene.size_points[i], i)
+         mask_np = np.array(self.mask_m)[..., 0]
+         print(mask_np.shape)
+         cv2.imwrite('examples/ui/mask.png', mask_np)
+         cv2.imwrite('examples/ui/mask_color.png', color_mask_np(mask_np).astype(np.uint8)[..., ::-1])
+         cv2.imwrite('examples/ui/output.png', self.output_img[..., ::-1])
+
+     def undo(self):
+         self.scene.undo()
+
+     def clear(self):
+         self.mask_m = self.mask.copy()
+
+         self.scene.reset_items()
+         self.scene.reset()
+         if self.image is not None:
+             self.scene.addPixmap(self.image)
+
+         if len(self.result_scene.items()) > 0:
+             self.result_scene.removeItem(self.result_scene.items()[-1])
+
+         self.ws = None
+
+     def get_model(self, network_pkl):
+         device = torch.device('cuda')
+         with dnnlib.util.open_url(network_pkl) as f:
+             G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
+         return G
+
+     def clear_ws(self):
+         self.ws = None
+
+ if __name__ == '__main__':
+     app = QApplication(sys.argv)
+     ex = Ex()
+     sys.exit(app.exec_())
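Based on the `argparse` block above, a typical launch of this demo (checkpoint name taken from the README's download script) would look like:
```
python applications/demo/qt_demo_seg2cat.py --network checkpoints/pix2pix3d_seg2cat.pkl --data_dir data/
```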
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/__init__.py ADDED
File without changes
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/mouse_event.py ADDED
@@ -0,0 +1,100 @@
+ # -*- coding: utf-8 -*-
+
+ from PyQt5.QtCore import *
+ from PyQt5.QtGui import *
+ from PyQt5.QtWidgets import *
+ import numpy as np
+
+ color_list = [QColor(255, 255, 255), QColor(204, 0, 0), QColor(76, 153, 0), QColor(204, 204, 0), QColor(51, 51, 255), QColor(204, 0, 204), QColor(0, 255, 255), QColor(255, 204, 204), QColor(102, 51, 0), QColor(255, 0, 0), QColor(102, 204, 0), QColor(255, 255, 0), QColor(0, 0, 153), QColor(0, 0, 204), QColor(255, 51, 153), QColor(0, 204, 204), QColor(0, 51, 0), QColor(255, 153, 51), QColor(0, 204, 0)]
+
+ class GraphicsScene(QGraphicsScene):
+     def __init__(self, mode, size, parent=None):
+         QGraphicsScene.__init__(self, parent)
+         self.mode = mode
+         self.size = size
+         self.mouse_clicked = False
+         self.prev_pt = None
+
+         # save the points
+         self.mask_points = []
+         for i in range(len(color_list)):
+             self.mask_points.append([])
+
+         # save the size of points
+         self.size_points = []
+         for i in range(len(color_list)):
+             self.size_points.append([])
+
+         # save the history of edit
+         self.history = []
+
+     def reset(self):
+         # save the points
+         self.mask_points = []
+         for i in range(len(color_list)):
+             self.mask_points.append([])
+         # save the size of points
+         self.size_points = []
+         for i in range(len(color_list)):
+             self.size_points.append([])
+         # save the history of edit
+         self.history = []
+
+         self.mode = 0
+         self.prev_pt = None
+
+     def mousePressEvent(self, event):
+         self.mouse_clicked = True
+
+     def mouseReleaseEvent(self, event):
+         self.prev_pt = None
+         self.mouse_clicked = False
+
+     def mouseMoveEvent(self, event):  # drawing
+         if self.mouse_clicked:
+             if self.prev_pt:
+                 self.drawMask(self.prev_pt, event.scenePos(), color_list[self.mode], self.size)
+                 pts = {}
+                 pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
+                 pts['curr'] = (int(event.scenePos().x()), int(event.scenePos().y()))
+
+                 self.size_points[self.mode].append(self.size)
+                 self.mask_points[self.mode].append(pts)
+                 self.history.append(self.mode)
+                 self.prev_pt = event.scenePos()
+             else:
+                 self.prev_pt = event.scenePos()
+
+     def drawMask(self, prev_pt, curr_pt, color, size):
+         lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
+         lineItem.setPen(QPen(color, size, Qt.SolidLine))  # rect
+         self.addItem(lineItem)
+
+     def erase_prev_pt(self):
+         self.prev_pt = None
+
+     def reset_items(self):
+         for i in range(len(self.items())):
+             item = self.items()[0]
+             self.removeItem(item)
+
+     def undo(self):
+         if len(self.items()) > 1:
+             if len(self.items()) >= 9:
+                 for i in range(8):
+                     item = self.items()[0]
+                     self.removeItem(item)
+                 if self.history[-1] == self.mode:
+                     self.mask_points[self.mode].pop()
+                     self.size_points[self.mode].pop()
+                     self.history.pop()
+             else:
+                 for i in range(len(self.items()) - 1):
+                     item = self.items()[0]
+                     self.removeItem(item)
+                 if self.history[-1] == self.mode:
+                     self.mask_points[self.mode].pop()
+                     self.size_points[self.mode].pop()
+                     self.history.pop()
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PyQt5 import QtCore, QtGui, QtWidgets
2
+ from PyQt5.QtCore import Qt
3
+
4
+ class Ui_Form(object):
5
+ def setupUi(self, Form):
6
+ Form.setObjectName("Form")
7
+ Form.resize(1800, 660)
8
+ self.pushButton = QtWidgets.QPushButton(Form)
9
+ # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
10
+ self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
11
+ self.pushButton.setObjectName("pushButton")
12
+ self.pushButton_2 = QtWidgets.QPushButton(Form)
13
+ self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
14
+ self.pushButton_2.setObjectName("pushButton_2")
15
+ self.pushButton_3 = QtWidgets.QPushButton(Form)
16
+ self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
17
+ self.pushButton_3.setObjectName("pushButton_3")
18
+ self.pushButton_4 = QtWidgets.QPushButton(Form)
19
+ self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
20
+ self.pushButton_4.setObjectName("pushButton_4")
21
+ self.pushButton_5 = QtWidgets.QPushButton(Form)
22
+ self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
23
+ self.pushButton_5.setObjectName("pushButton_5")
24
+ self.pushButton_6 = QtWidgets.QPushButton(Form)
25
+ self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
26
+ self.pushButton_6.setObjectName("pushButton_6")
27
+ self.pushButton_7 = QtWidgets.QPushButton(Form)
28
+ self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
29
+ self.pushButton_7.setObjectName("pushButton_7")
30
+ self.pushButton_8 = QtWidgets.QPushButton(Form)
31
+ self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
32
+ self.pushButton_8.setObjectName("pushButton_8")
33
+ self.pushButton_9 = QtWidgets.QPushButton(Form)
34
+ self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
35
+ self.pushButton_9.setObjectName("pushButton_9")
36
+ self.pushButton_10 = QtWidgets.QPushButton(Form)
37
+ self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
38
+ self.pushButton_10.setObjectName("pushButton_10")
39
+ self.pushButton_11 = QtWidgets.QPushButton(Form)
40
+ self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
41
+ self.pushButton_11.setObjectName("pushButton_11")
42
+ self.pushButton_12 = QtWidgets.QPushButton(Form)
43
+ self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
44
+ self.pushButton_12.setObjectName("pushButton_12")
45
+ self.pushButton_13 = QtWidgets.QPushButton(Form)
46
+ self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
47
+ self.pushButton_13.setObjectName("pushButton_13")
48
+ self.pushButton_14 = QtWidgets.QPushButton(Form)
49
+ self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
50
+ self.pushButton_14.setObjectName("pushButton_14")
51
+ self.pushButton_15 = QtWidgets.QPushButton(Form)
52
+ self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
53
+ self.pushButton_15.setObjectName("pushButton_15")
54
+ self.pushButton_16 = QtWidgets.QPushButton(Form)
55
+ self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
56
+ self.pushButton_16.setObjectName("pushButton_16")
57
+ self.pushButton_17 = QtWidgets.QPushButton(Form)
58
+ self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
59
+ self.pushButton_17.setObjectName("pushButton_17")
60
+ self.pushButton_18 = QtWidgets.QPushButton(Form)
61
+ self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
62
+ self.pushButton_18.setObjectName("pushButton_18")
63
+ self.pushButton_19 = QtWidgets.QPushButton(Form)
64
+ self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
65
+ self.pushButton_19.setObjectName("pushButton_19")
66
+ self.pushButton_20 = QtWidgets.QPushButton(Form)
67
+ self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
68
+ self.pushButton_20.setObjectName("pushButton_20")
69
+ self.pushButton_21 = QtWidgets.QPushButton(Form)
70
+ self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
71
+ self.pushButton_21.setObjectName("pushButton_21")
72
+ self.pushButton_22 = QtWidgets.QPushButton(Form)
73
+ self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
74
+ self.pushButton_22.setObjectName("pushButton_22")
75
+ self.pushButton_23 = QtWidgets.QPushButton(Form)
76
+ self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
77
+ self.pushButton_23.setObjectName("pushButton_23")
78
+ self.pushButton_24 = QtWidgets.QPushButton(Form)
79
+ self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
80
+ self.pushButton_24.setObjectName("pushButton_24")
81
+ self.pushButton_25 = QtWidgets.QPushButton(Form)
82
+ self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
83
+ self.pushButton_25.setObjectName("pushButton_25")
84
+ # self.pushButton_26 = QtWidgets.QPushButton(Form)
85
+ # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
86
+ # self.pushButton_26.setObjectName("pushButton_26")
87
+ # self.pushButton_27 = QtWidgets.QPushButton(Form)
88
+ # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
89
+ # self.pushButton_27.setObjectName("pushButton_27")
90
+
91
+
92
+ self.slider_sizeselect = QtWidgets.QSlider(Form)
93
+ self.slider_sizeselect.setRange(10,70)
94
+ self.slider_sizeselect.setOrientation(Qt.Horizontal)
95
+ self.slider_sizeselect.setValue(Form.size)
96
+ self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
97
+
98
+ self.label_sizeselect = QtWidgets.QLabel(Form)
99
+ self.label_sizeselect.setText("Brush Size")
100
+ self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
101
+
102
+ self.slider_yawselect = QtWidgets.QSlider(Form)
103
+ self.slider_yawselect.setRange(-100,100)
104
+ self.slider_yawselect.setOrientation(Qt.Horizontal)
105
+ self.slider_yawselect.setValue(Form.yaw)
106
+ self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
107
+
108
+ self.label_yawselect = QtWidgets.QLabel(Form)
109
+ self.label_yawselect.setText("Yaw")
110
+ self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
111
+
112
+ self.slider_pitchselect = QtWidgets.QSlider(Form)
113
+ self.slider_pitchselect.setRange(-100,100)
114
+ self.slider_pitchselect.setOrientation(Qt.Horizontal)
115
+ self.slider_pitchselect.setValue(Form.pitch)
116
+ self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
117
+
118
+ self.label_pitchselect = QtWidgets.QLabel(Form)
119
+ self.label_pitchselect.setText("Pitch")
120
+ self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
121
+
122
+ self.text_inputID = QtWidgets.QTextEdit(Form)
123
+ self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
124
+ self.text_inputID.setObjectName("text_inputID")
125
+
126
+ self.pushButton_inputID = QtWidgets.QPushButton(Form)
127
+ self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
128
+ self.pushButton_inputID.setObjectName("pushButton_inputID")
129
+
130
+ self.text_seed = QtWidgets.QTextEdit(Form)
131
+ self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
132
+ self.text_seed.setObjectName("text_seed")
133
+ self.text_seed.setPlainText("0")
134
+
135
+ self.label_seed = QtWidgets.QLabel(Form)
136
+ self.label_seed.setText("Seed")
137
+ self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
138
+
139
+
140
+
141
+
142
+ self.graphicsView = QtWidgets.QGraphicsView(Form)
143
+ self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
144
+ self.graphicsView.setObjectName("graphicsView")
145
+ self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
146
+ self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
147
+ self.graphicsView_2.setObjectName("graphicsView_2")
148
+ self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
149
+ self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
150
+ self.graphicsView_3.setObjectName("graphicsView_3")
151
+
152
+
153
+ self.retranslateUi(Form)
154
+ self.pushButton.clicked.connect(Form.generateAndReconstruct)
155
+ self.pushButton_2.clicked.connect(Form.open)
156
+ self.pushButton_3.clicked.connect(Form.open_mask)
157
+ self.pushButton_4.clicked.connect(Form.clear)
158
+ self.pushButton_5.clicked.connect(Form.undo)
159
+ self.pushButton_6.clicked.connect(Form.save_img)
160
+ self.pushButton_7.clicked.connect(Form.bg_mode)
161
+ self.pushButton_8.clicked.connect(Form.skin_mode)
162
+ self.pushButton_9.clicked.connect(Form.nose_mode)
163
+ self.pushButton_10.clicked.connect(Form.eye_g_mode)
164
+ self.pushButton_11.clicked.connect(Form.l_eye_mode)
165
+ self.pushButton_12.clicked.connect(Form.r_eye_mode)
166
+ self.pushButton_13.clicked.connect(Form.l_brow_mode)
167
+ self.pushButton_14.clicked.connect(Form.r_brow_mode)
168
+ self.pushButton_15.clicked.connect(Form.l_ear_mode)
169
+ self.pushButton_16.clicked.connect(Form.r_ear_mode)
170
+ self.pushButton_17.clicked.connect(Form.mouth_mode)
171
+ self.pushButton_18.clicked.connect(Form.u_lip_mode)
172
+ self.pushButton_19.clicked.connect(Form.l_lip_mode)
173
+ self.pushButton_20.clicked.connect(Form.hair_mode)
174
+ self.pushButton_21.clicked.connect(Form.hat_mode)
175
+ self.pushButton_22.clicked.connect(Form.ear_r_mode)
176
+ self.pushButton_23.clicked.connect(Form.neck_l_mode)
177
+ self.pushButton_24.clicked.connect(Form.neck_mode)
178
+ self.pushButton_25.clicked.connect(Form.cloth_mode)
179
+ # self.pushButton_26.clicked.connect(Form.increase)
180
+ # self.pushButton_27.clicked.connect(Form.decrease)
181
+
182
+
183
+ self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
184
+ self.slider_yawselect.valueChanged.connect(Form.changeYaw)
185
+ self.slider_pitchselect.valueChanged.connect(Form.changePitch)
186
+
187
+ self.pushButton_inputID.clicked.connect(Form.inputID)
188
+
189
+ QtCore.QMetaObject.connectSlotsByName(Form)
190
+
191
+ def retranslateUi(self, Form):
192
+ _translate = QtCore.QCoreApplication.translate
193
+ Form.setWindowTitle(_translate("Form", "3D-GauGAN"))
194
+ self.pushButton.setText(_translate("Form", "Generate"))
195
+ self.pushButton_2.setText(_translate("Form", "Open Image"))
196
+ self.pushButton_3.setText(_translate("Form", "Open Mask"))
197
+ self.pushButton_4.setText(_translate("Form", "Clear"))
198
+ self.pushButton_5.setText(_translate("Form", "Undo"))
199
+ self.pushButton_6.setText(_translate("Form", "Save Image"))
200
+ self.pushButton_7.setText(_translate("Form", "BackGround"))
201
+ self.pushButton_8.setText(_translate("Form", "Skin"))
202
+ self.pushButton_9.setText(_translate("Form", "Nose"))
203
+ self.pushButton_10.setText(_translate("Form", "Eyeglass"))
204
+ self.pushButton_11.setText(_translate("Form", "Left Eye"))
205
+ self.pushButton_12.setText(_translate("Form", "Right Eye"))
206
+ self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
207
+ self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
208
+ self.pushButton_15.setText(_translate("Form", "Left ear"))
209
+ self.pushButton_16.setText(_translate("Form", "Right ear"))
210
+ self.pushButton_17.setText(_translate("Form", "Mouth"))
211
+ self.pushButton_18.setText(_translate("Form", "Upper Lip"))
212
+ self.pushButton_19.setText(_translate("Form", "Lower Lip"))
213
+ self.pushButton_20.setText(_translate("Form", "Hair"))
214
+ self.pushButton_21.setText(_translate("Form", "Hat"))
215
+ self.pushButton_22.setText(_translate("Form", "Earring"))
216
+ self.pushButton_23.setText(_translate("Form", "Necklace"))
217
+ self.pushButton_24.setText(_translate("Form", "Neck"))
218
+ self.pushButton_25.setText(_translate("Form", "Cloth"))
219
+ # self.pushButton_26.setText(_translate("Form", "+"))
220
+ # self.pushButton_27.setText(_translate("Form", "-"))
221
+ self.pushButton_inputID.setText(_translate("Form", "Input ID"))
222
+
223
+
224
+ class Ui_Form_Seg(object):
+     def setupUi(self, Form):
+         Form.setObjectName("Form")
+         Form.resize(1800, 1260)
+         self.pushButton = QtWidgets.QPushButton(Form)
+         # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+         self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
+         self.pushButton.setObjectName("pushButton")
+         self.pushButton_2 = QtWidgets.QPushButton(Form)
+         self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+         self.pushButton_2.setObjectName("pushButton_2")
+         self.pushButton_3 = QtWidgets.QPushButton(Form)
+         self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+         self.pushButton_3.setObjectName("pushButton_3")
+         self.pushButton_4 = QtWidgets.QPushButton(Form)
+         self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+         self.pushButton_4.setObjectName("pushButton_4")
+         self.pushButton_5 = QtWidgets.QPushButton(Form)
+         self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+         self.pushButton_5.setObjectName("pushButton_5")
+         self.pushButton_6 = QtWidgets.QPushButton(Form)
+         self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+         self.pushButton_6.setObjectName("pushButton_6")
+         self.pushButton_7 = QtWidgets.QPushButton(Form)
+         self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+         self.pushButton_7.setObjectName("pushButton_7")
+         self.pushButton_8 = QtWidgets.QPushButton(Form)
+         self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+         self.pushButton_8.setObjectName("pushButton_8")
+         self.pushButton_9 = QtWidgets.QPushButton(Form)
+         self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+         self.pushButton_9.setObjectName("pushButton_9")
+         self.pushButton_10 = QtWidgets.QPushButton(Form)
+         self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+         self.pushButton_10.setObjectName("pushButton_10")
+         self.pushButton_11 = QtWidgets.QPushButton(Form)
+         self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+         self.pushButton_11.setObjectName("pushButton_11")
+         self.pushButton_12 = QtWidgets.QPushButton(Form)
+         self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+         self.pushButton_12.setObjectName("pushButton_12")
+         self.pushButton_13 = QtWidgets.QPushButton(Form)
+         self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+         self.pushButton_13.setObjectName("pushButton_13")
+         self.pushButton_14 = QtWidgets.QPushButton(Form)
+         self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+         self.pushButton_14.setObjectName("pushButton_14")
+         self.pushButton_15 = QtWidgets.QPushButton(Form)
+         self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+         self.pushButton_15.setObjectName("pushButton_15")
+         self.pushButton_16 = QtWidgets.QPushButton(Form)
+         self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+         self.pushButton_16.setObjectName("pushButton_16")
+         self.pushButton_17 = QtWidgets.QPushButton(Form)
+         self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+         self.pushButton_17.setObjectName("pushButton_17")
+         self.pushButton_18 = QtWidgets.QPushButton(Form)
+         self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+         self.pushButton_18.setObjectName("pushButton_18")
+         self.pushButton_19 = QtWidgets.QPushButton(Form)
+         self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+         self.pushButton_19.setObjectName("pushButton_19")
+         self.pushButton_20 = QtWidgets.QPushButton(Form)
+         self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+         self.pushButton_20.setObjectName("pushButton_20")
+         self.pushButton_21 = QtWidgets.QPushButton(Form)
+         self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+         self.pushButton_21.setObjectName("pushButton_21")
+         self.pushButton_22 = QtWidgets.QPushButton(Form)
+         self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+         self.pushButton_22.setObjectName("pushButton_22")
+         self.pushButton_23 = QtWidgets.QPushButton(Form)
+         self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+         self.pushButton_23.setObjectName("pushButton_23")
+         self.pushButton_24 = QtWidgets.QPushButton(Form)
+         self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+         self.pushButton_24.setObjectName("pushButton_24")
+         self.pushButton_25 = QtWidgets.QPushButton(Form)
+         self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+         self.pushButton_25.setObjectName("pushButton_25")
+         # self.pushButton_26 = QtWidgets.QPushButton(Form)
+         # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+         # self.pushButton_26.setObjectName("pushButton_26")
+         # self.pushButton_27 = QtWidgets.QPushButton(Form)
+         # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+         # self.pushButton_27.setObjectName("pushButton_27")
+
+         self.slider_sizeselect = QtWidgets.QSlider(Form)
+         self.slider_sizeselect.setRange(10,70)
+         self.slider_sizeselect.setOrientation(Qt.Horizontal)
+         self.slider_sizeselect.setValue(Form.size)
+         self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+         self.label_sizeselect = QtWidgets.QLabel(Form)
+         self.label_sizeselect.setText("Brush Size")
+         self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+         self.slider_yawselect = QtWidgets.QSlider(Form)
+         self.slider_yawselect.setRange(-100,100)
+         self.slider_yawselect.setOrientation(Qt.Horizontal)
+         self.slider_yawselect.setValue(Form.yaw)
+         self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+         self.label_yawselect = QtWidgets.QLabel(Form)
+         self.label_yawselect.setText("Yaw")
+         self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+         self.slider_pitchselect = QtWidgets.QSlider(Form)
+         self.slider_pitchselect.setRange(-100,100)
+         self.slider_pitchselect.setOrientation(Qt.Horizontal)
+         self.slider_pitchselect.setValue(Form.pitch)
+         self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+         self.label_pitchselect = QtWidgets.QLabel(Form)
+         self.label_pitchselect.setText("Pitch")
+         self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+         self.slider_truncation = QtWidgets.QSlider(Form)
+         self.slider_truncation.setRange(0,100)
+         self.slider_truncation.setOrientation(Qt.Horizontal)
+         self.slider_truncation.setValue(Form.truncation)
+         self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))
+
+         self.label_truncation = QtWidgets.QLabel(Form)
+         self.label_truncation.setText("Truncation")
+         self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))
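+         # Editor's note (assumption): all sliders are integer-valued (yaw/pitch in
+         # [-100, 100], truncation in [0, 100]); the connected Form.change* handlers
+         # presumably rescale these integers to float camera angles and a truncation
+         # psi in [0, 1].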
+
+         self.text_inputID = QtWidgets.QTextEdit(Form)
+         self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
+         self.text_inputID.setObjectName("text_inputID")
+
+         self.pushButton_inputID = QtWidgets.QPushButton(Form)
+         self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
+         self.pushButton_inputID.setObjectName("pushButton_inputID")
+
+         self.text_seed = QtWidgets.QTextEdit(Form)
+         self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
+         self.text_seed.setObjectName("text_seed")
+         self.text_seed.setPlainText("0")
+
+         self.label_seed = QtWidgets.QLabel(Form)
+         self.label_seed.setText("Seed")
+         self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
+
+         self.pushButton_inverse = QtWidgets.QPushButton(Form)
+         self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
+         self.pushButton_inverse.setObjectName("pushButton_inverse")
+
+         self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
+         self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
+         self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
+
+
+
+
+         self.graphicsView = QtWidgets.QGraphicsView(Form)
+         self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
+         self.graphicsView.setObjectName("graphicsView")
+         self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
+         self.graphicsView_2.setObjectName("graphicsView_2")
+         self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
+         self.graphicsView_3.setObjectName("graphicsView_3")
+
+         self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
+         self.graphicsView_5.setObjectName("graphicsView_5")
+         self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
+         self.graphicsView_6.setObjectName("graphicsView_6")
+
+
+         self.retranslateUi(Form)
+         self.pushButton.clicked.connect(Form.generateAndReconstruct)
+         self.pushButton_2.clicked.connect(Form.open)
+         self.pushButton_3.clicked.connect(Form.open_mask)
+         self.pushButton_4.clicked.connect(Form.clear)
+         self.pushButton_5.clicked.connect(Form.undo)
+         self.pushButton_6.clicked.connect(Form.save_img)
+         self.pushButton_7.clicked.connect(Form.bg_mode)
+         self.pushButton_8.clicked.connect(Form.skin_mode)
+         self.pushButton_9.clicked.connect(Form.nose_mode)
+         self.pushButton_10.clicked.connect(Form.eye_g_mode)
+         self.pushButton_11.clicked.connect(Form.l_eye_mode)
+         self.pushButton_12.clicked.connect(Form.r_eye_mode)
+         self.pushButton_13.clicked.connect(Form.l_brow_mode)
+         self.pushButton_14.clicked.connect(Form.r_brow_mode)
+         self.pushButton_15.clicked.connect(Form.l_ear_mode)
+         self.pushButton_16.clicked.connect(Form.r_ear_mode)
+         self.pushButton_17.clicked.connect(Form.mouth_mode)
+         self.pushButton_18.clicked.connect(Form.u_lip_mode)
+         self.pushButton_19.clicked.connect(Form.l_lip_mode)
+         self.pushButton_20.clicked.connect(Form.hair_mode)
+         self.pushButton_21.clicked.connect(Form.hat_mode)
+         self.pushButton_22.clicked.connect(Form.ear_r_mode)
+         self.pushButton_23.clicked.connect(Form.neck_l_mode)
+         self.pushButton_24.clicked.connect(Form.neck_mode)
+         self.pushButton_25.clicked.connect(Form.cloth_mode)
+         # self.pushButton_26.clicked.connect(Form.increase)
+         # self.pushButton_27.clicked.connect(Form.decrease)
+
+         self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
+         self.slider_yawselect.valueChanged.connect(Form.changeYaw)
+         self.slider_pitchselect.valueChanged.connect(Form.changePitch)
+         self.slider_truncation.valueChanged.connect(Form.changeTruncation)
+
+         self.pushButton_inputID.clicked.connect(Form.inputID)
+
+         self.pushButton_inverse.clicked.connect(Form.inverse)
+         self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
+
+         QtCore.QMetaObject.connectSlotsByName(Form)
+
+     def retranslateUi(self, Form):
+         _translate = QtCore.QCoreApplication.translate
+         Form.setWindowTitle(_translate("Form", "3D-GauGAN"))
+         self.pushButton.setText(_translate("Form", "Generate"))
+         self.pushButton_2.setText(_translate("Form", "Open Image"))
+         self.pushButton_3.setText(_translate("Form", "Open Mask"))
+         self.pushButton_4.setText(_translate("Form", "Clear"))
+         self.pushButton_5.setText(_translate("Form", "Undo"))
+         self.pushButton_6.setText(_translate("Form", "Save Image"))
+         self.pushButton_7.setText(_translate("Form", "Background"))
+         self.pushButton_8.setText(_translate("Form", "Skin"))
+         self.pushButton_9.setText(_translate("Form", "Nose"))
+         self.pushButton_10.setText(_translate("Form", "Eyeglass"))
+         self.pushButton_11.setText(_translate("Form", "Left Eye"))
+         self.pushButton_12.setText(_translate("Form", "Right Eye"))
+         self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
+         self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
+         self.pushButton_15.setText(_translate("Form", "Left Ear"))
+         self.pushButton_16.setText(_translate("Form", "Right Ear"))
+         self.pushButton_17.setText(_translate("Form", "Mouth"))
+         self.pushButton_18.setText(_translate("Form", "Upper Lip"))
+         self.pushButton_19.setText(_translate("Form", "Lower Lip"))
+         self.pushButton_20.setText(_translate("Form", "Hair"))
+         self.pushButton_21.setText(_translate("Form", "Hat"))
+         self.pushButton_22.setText(_translate("Form", "Earring"))
+         self.pushButton_23.setText(_translate("Form", "Necklace"))
+         self.pushButton_24.setText(_translate("Form", "Neck"))
+         self.pushButton_25.setText(_translate("Form", "Cloth"))
+         # self.pushButton_26.setText(_translate("Form", "+"))
+         # self.pushButton_27.setText(_translate("Form", "-"))
+         self.pushButton_inputID.setText(_translate("Form", "Input ID"))
+         self.pushButton_inverse.setText(_translate("Form", "Inverse"))
+         self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
+
+
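+ # Editor's note (assumption): Ui_Form_Video duplicates Ui_Form_Seg almost verbatim
+ # but replaces the inversion buttons with a single "Get" button wired to
+ # Form.get_mask, presumably to capture the current mask for video rendering. The
+ # wholesale duplication suggests these classes were generated (Qt Designer style)
+ # rather than factored into a shared base class.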
+ class Ui_Form_Video(object):
+     def setupUi(self, Form):
+         Form.setObjectName("Form")
+         Form.resize(1800, 1260)
+         self.pushButton = QtWidgets.QPushButton(Form)
+         # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+         self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
+         self.pushButton.setObjectName("pushButton")
+         self.pushButton_2 = QtWidgets.QPushButton(Form)
+         self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+         self.pushButton_2.setObjectName("pushButton_2")
+         self.pushButton_3 = QtWidgets.QPushButton(Form)
+         self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+         self.pushButton_3.setObjectName("pushButton_3")
+         self.pushButton_4 = QtWidgets.QPushButton(Form)
+         self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+         self.pushButton_4.setObjectName("pushButton_4")
+         self.pushButton_5 = QtWidgets.QPushButton(Form)
+         self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+         self.pushButton_5.setObjectName("pushButton_5")
+         self.pushButton_6 = QtWidgets.QPushButton(Form)
+         self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+         self.pushButton_6.setObjectName("pushButton_6")
+         self.pushButton_7 = QtWidgets.QPushButton(Form)
+         self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+         self.pushButton_7.setObjectName("pushButton_7")
+         self.pushButton_8 = QtWidgets.QPushButton(Form)
+         self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+         self.pushButton_8.setObjectName("pushButton_8")
+         self.pushButton_9 = QtWidgets.QPushButton(Form)
+         self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+         self.pushButton_9.setObjectName("pushButton_9")
+         self.pushButton_10 = QtWidgets.QPushButton(Form)
+         self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+         self.pushButton_10.setObjectName("pushButton_10")
+         self.pushButton_11 = QtWidgets.QPushButton(Form)
+         self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+         self.pushButton_11.setObjectName("pushButton_11")
+         self.pushButton_12 = QtWidgets.QPushButton(Form)
+         self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+         self.pushButton_12.setObjectName("pushButton_12")
+         self.pushButton_13 = QtWidgets.QPushButton(Form)
+         self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+         self.pushButton_13.setObjectName("pushButton_13")
+         self.pushButton_14 = QtWidgets.QPushButton(Form)
+         self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+         self.pushButton_14.setObjectName("pushButton_14")
+         self.pushButton_15 = QtWidgets.QPushButton(Form)
+         self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+         self.pushButton_15.setObjectName("pushButton_15")
+         self.pushButton_16 = QtWidgets.QPushButton(Form)
+         self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+         self.pushButton_16.setObjectName("pushButton_16")
+         self.pushButton_17 = QtWidgets.QPushButton(Form)
+         self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+         self.pushButton_17.setObjectName("pushButton_17")
+         self.pushButton_18 = QtWidgets.QPushButton(Form)
+         self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+         self.pushButton_18.setObjectName("pushButton_18")
+         self.pushButton_19 = QtWidgets.QPushButton(Form)
+         self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+         self.pushButton_19.setObjectName("pushButton_19")
+         self.pushButton_20 = QtWidgets.QPushButton(Form)
+         self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+         self.pushButton_20.setObjectName("pushButton_20")
+         self.pushButton_21 = QtWidgets.QPushButton(Form)
+         self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+         self.pushButton_21.setObjectName("pushButton_21")
+         self.pushButton_22 = QtWidgets.QPushButton(Form)
+         self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+         self.pushButton_22.setObjectName("pushButton_22")
+         self.pushButton_23 = QtWidgets.QPushButton(Form)
+         self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+         self.pushButton_23.setObjectName("pushButton_23")
+         self.pushButton_24 = QtWidgets.QPushButton(Form)
+         self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+         self.pushButton_24.setObjectName("pushButton_24")
+         self.pushButton_25 = QtWidgets.QPushButton(Form)
+         self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+         self.pushButton_25.setObjectName("pushButton_25")
+         # self.pushButton_26 = QtWidgets.QPushButton(Form)
+         # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+         # self.pushButton_26.setObjectName("pushButton_26")
+         # self.pushButton_27 = QtWidgets.QPushButton(Form)
+         # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+         # self.pushButton_27.setObjectName("pushButton_27")
+
+         self.slider_sizeselect = QtWidgets.QSlider(Form)
+         self.slider_sizeselect.setRange(10,70)
+         self.slider_sizeselect.setOrientation(Qt.Horizontal)
+         self.slider_sizeselect.setValue(Form.size)
+         self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+         self.label_sizeselect = QtWidgets.QLabel(Form)
+         self.label_sizeselect.setText("Brush Size")
+         self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+         self.slider_yawselect = QtWidgets.QSlider(Form)
+         self.slider_yawselect.setRange(-100,100)
+         self.slider_yawselect.setOrientation(Qt.Horizontal)
+         self.slider_yawselect.setValue(Form.yaw)
+         self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+         self.label_yawselect = QtWidgets.QLabel(Form)
+         self.label_yawselect.setText("Yaw")
+         self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+         self.slider_pitchselect = QtWidgets.QSlider(Form)
+         self.slider_pitchselect.setRange(-100,100)
+         self.slider_pitchselect.setOrientation(Qt.Horizontal)
+         self.slider_pitchselect.setValue(Form.pitch)
+         self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+         self.label_pitchselect = QtWidgets.QLabel(Form)
+         self.label_pitchselect.setText("Pitch")
+         self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+         self.slider_truncation = QtWidgets.QSlider(Form)
+         self.slider_truncation.setRange(0,100)
+         self.slider_truncation.setOrientation(Qt.Horizontal)
+         self.slider_truncation.setValue(Form.truncation)
+         self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))
+
+         self.label_truncation = QtWidgets.QLabel(Form)
+         self.label_truncation.setText("Truncation")
+         self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))
+
+         self.text_inputID = QtWidgets.QTextEdit(Form)
+         self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
+         self.text_inputID.setObjectName("text_inputID")
+
+         self.pushButton_inputID = QtWidgets.QPushButton(Form)
+         self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
+         self.pushButton_inputID.setObjectName("pushButton_inputID")
+
+         self.text_seed = QtWidgets.QTextEdit(Form)
+         self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
+         self.text_seed.setObjectName("text_seed")
+         self.text_seed.setPlainText("0")
+
+         self.label_seed = QtWidgets.QLabel(Form)
+         self.label_seed.setText("Seed")
+         self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
+
+         # self.pushButton_inverse = QtWidgets.QPushButton(Form)
+         # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
+         # self.pushButton_inverse.setObjectName("pushButton_inverse")
+
+         # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
+         # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
+         # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
+         self.pushButton_get = QtWidgets.QPushButton(Form)
+         self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
+         self.pushButton_get.setObjectName("pushButton_get")
+
+
+
+
+         self.graphicsView = QtWidgets.QGraphicsView(Form)
+         self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
+         self.graphicsView.setObjectName("graphicsView")
+         self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
+         self.graphicsView_2.setObjectName("graphicsView_2")
+         self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
+         self.graphicsView_3.setObjectName("graphicsView_3")
+
+         self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
+         self.graphicsView_5.setObjectName("graphicsView_5")
+         self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
+         self.graphicsView_6.setObjectName("graphicsView_6")
+
+
+         self.retranslateUi(Form)
+         self.pushButton.clicked.connect(Form.generateAndReconstruct)
+         self.pushButton_2.clicked.connect(Form.open)
+         self.pushButton_3.clicked.connect(Form.open_mask)
+         self.pushButton_4.clicked.connect(Form.clear)
+         self.pushButton_5.clicked.connect(Form.undo)
+         self.pushButton_6.clicked.connect(Form.save_img)
+         self.pushButton_7.clicked.connect(Form.bg_mode)
+         self.pushButton_8.clicked.connect(Form.skin_mode)
+         self.pushButton_9.clicked.connect(Form.nose_mode)
+         self.pushButton_10.clicked.connect(Form.eye_g_mode)
+         self.pushButton_11.clicked.connect(Form.l_eye_mode)
+         self.pushButton_12.clicked.connect(Form.r_eye_mode)
+         self.pushButton_13.clicked.connect(Form.l_brow_mode)
+         self.pushButton_14.clicked.connect(Form.r_brow_mode)
+         self.pushButton_15.clicked.connect(Form.l_ear_mode)
+         self.pushButton_16.clicked.connect(Form.r_ear_mode)
+         self.pushButton_17.clicked.connect(Form.mouth_mode)
+         self.pushButton_18.clicked.connect(Form.u_lip_mode)
+         self.pushButton_19.clicked.connect(Form.l_lip_mode)
+         self.pushButton_20.clicked.connect(Form.hair_mode)
+         self.pushButton_21.clicked.connect(Form.hat_mode)
+         self.pushButton_22.clicked.connect(Form.ear_r_mode)
+         self.pushButton_23.clicked.connect(Form.neck_l_mode)
+         self.pushButton_24.clicked.connect(Form.neck_mode)
+         self.pushButton_25.clicked.connect(Form.cloth_mode)
+         # self.pushButton_26.clicked.connect(Form.increase)
+         # self.pushButton_27.clicked.connect(Form.decrease)
+
+         self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
+         self.slider_yawselect.valueChanged.connect(Form.changeYaw)
+         self.slider_pitchselect.valueChanged.connect(Form.changePitch)
+         self.slider_truncation.valueChanged.connect(Form.changeTruncation)
+
+         self.pushButton_inputID.clicked.connect(Form.inputID)
+
+         # self.pushButton_inverse.clicked.connect(Form.inverse)
+         # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
+         self.pushButton_get.clicked.connect(Form.get_mask)
+
+         QtCore.QMetaObject.connectSlotsByName(Form)
+
+     def retranslateUi(self, Form):
+         _translate = QtCore.QCoreApplication.translate
+         Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis"))
+         self.pushButton.setText(_translate("Form", "Generate"))
+         self.pushButton_2.setText(_translate("Form", "Open Image"))
+         self.pushButton_3.setText(_translate("Form", "Open Mask"))
+         self.pushButton_4.setText(_translate("Form", "Clear"))
+         self.pushButton_5.setText(_translate("Form", "Undo"))
+         self.pushButton_6.setText(_translate("Form", "Save Image"))
+         self.pushButton_7.setText(_translate("Form", "Background"))
+         self.pushButton_8.setText(_translate("Form", "Skin"))
+         self.pushButton_9.setText(_translate("Form", "Nose"))
+         self.pushButton_10.setText(_translate("Form", "Eyeglass"))
+         self.pushButton_11.setText(_translate("Form", "Left Eye"))
+         self.pushButton_12.setText(_translate("Form", "Right Eye"))
+         self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
+         self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
+         self.pushButton_15.setText(_translate("Form", "Left Ear"))
+         self.pushButton_16.setText(_translate("Form", "Right Ear"))
+         self.pushButton_17.setText(_translate("Form", "Mouth"))
+         self.pushButton_18.setText(_translate("Form", "Upper Lip"))
+         self.pushButton_19.setText(_translate("Form", "Lower Lip"))
+         self.pushButton_20.setText(_translate("Form", "Hair"))
+         self.pushButton_21.setText(_translate("Form", "Hat"))
+         self.pushButton_22.setText(_translate("Form", "Earring"))
+         self.pushButton_23.setText(_translate("Form", "Necklace"))
+         self.pushButton_24.setText(_translate("Form", "Neck"))
+         self.pushButton_25.setText(_translate("Form", "Cloth"))
+         # self.pushButton_26.setText(_translate("Form", "+"))
+         # self.pushButton_27.setText(_translate("Form", "-"))
+         self.pushButton_inputID.setText(_translate("Form", "Input ID"))
+         # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
+         # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
+         self.pushButton_get.setText(_translate("Form", "Get"))
+
+
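+ # Editor's note (assumption): despite the "(Edge2car)" window title set below, this
+ # class keeps the face-part buttons and slots of the seg2face layout; only the
+ # title and the "Get" button distinguish it.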
+ class Ui_Form_Edge2car(object):
+     def setupUi(self, Form):
+         Form.setObjectName("Form")
+         Form.resize(1800, 1260)
+         self.pushButton = QtWidgets.QPushButton(Form)
+         # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+         self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
+         self.pushButton.setObjectName("pushButton")
+         self.pushButton_2 = QtWidgets.QPushButton(Form)
+         self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+         self.pushButton_2.setObjectName("pushButton_2")
+         self.pushButton_3 = QtWidgets.QPushButton(Form)
+         self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+         self.pushButton_3.setObjectName("pushButton_3")
+         self.pushButton_4 = QtWidgets.QPushButton(Form)
+         self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+         self.pushButton_4.setObjectName("pushButton_4")
+         self.pushButton_5 = QtWidgets.QPushButton(Form)
+         self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+         self.pushButton_5.setObjectName("pushButton_5")
+         self.pushButton_6 = QtWidgets.QPushButton(Form)
+         self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+         self.pushButton_6.setObjectName("pushButton_6")
+         self.pushButton_7 = QtWidgets.QPushButton(Form)
+         self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+         self.pushButton_7.setObjectName("pushButton_7")
+         self.pushButton_8 = QtWidgets.QPushButton(Form)
+         self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+         self.pushButton_8.setObjectName("pushButton_8")
+         self.pushButton_9 = QtWidgets.QPushButton(Form)
+         self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+         self.pushButton_9.setObjectName("pushButton_9")
+         self.pushButton_10 = QtWidgets.QPushButton(Form)
+         self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+         self.pushButton_10.setObjectName("pushButton_10")
+         self.pushButton_11 = QtWidgets.QPushButton(Form)
+         self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+         self.pushButton_11.setObjectName("pushButton_11")
+         self.pushButton_12 = QtWidgets.QPushButton(Form)
+         self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+         self.pushButton_12.setObjectName("pushButton_12")
+         self.pushButton_13 = QtWidgets.QPushButton(Form)
+         self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+         self.pushButton_13.setObjectName("pushButton_13")
+         self.pushButton_14 = QtWidgets.QPushButton(Form)
+         self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+         self.pushButton_14.setObjectName("pushButton_14")
+         self.pushButton_15 = QtWidgets.QPushButton(Form)
+         self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+         self.pushButton_15.setObjectName("pushButton_15")
+         self.pushButton_16 = QtWidgets.QPushButton(Form)
+         self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+         self.pushButton_16.setObjectName("pushButton_16")
+         self.pushButton_17 = QtWidgets.QPushButton(Form)
+         self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+         self.pushButton_17.setObjectName("pushButton_17")
+         self.pushButton_18 = QtWidgets.QPushButton(Form)
+         self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+         self.pushButton_18.setObjectName("pushButton_18")
+         self.pushButton_19 = QtWidgets.QPushButton(Form)
+         self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+         self.pushButton_19.setObjectName("pushButton_19")
+         self.pushButton_20 = QtWidgets.QPushButton(Form)
+         self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+         self.pushButton_20.setObjectName("pushButton_20")
+         self.pushButton_21 = QtWidgets.QPushButton(Form)
+         self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+         self.pushButton_21.setObjectName("pushButton_21")
+         self.pushButton_22 = QtWidgets.QPushButton(Form)
+         self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+         self.pushButton_22.setObjectName("pushButton_22")
+         self.pushButton_23 = QtWidgets.QPushButton(Form)
+         self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+         self.pushButton_23.setObjectName("pushButton_23")
+         self.pushButton_24 = QtWidgets.QPushButton(Form)
+         self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+         self.pushButton_24.setObjectName("pushButton_24")
+         self.pushButton_25 = QtWidgets.QPushButton(Form)
+         self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+         self.pushButton_25.setObjectName("pushButton_25")
+         # self.pushButton_26 = QtWidgets.QPushButton(Form)
+         # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+         # self.pushButton_26.setObjectName("pushButton_26")
+         # self.pushButton_27 = QtWidgets.QPushButton(Form)
+         # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+         # self.pushButton_27.setObjectName("pushButton_27")
+
+         self.slider_sizeselect = QtWidgets.QSlider(Form)
+         self.slider_sizeselect.setRange(10,70)
+         self.slider_sizeselect.setOrientation(Qt.Horizontal)
+         self.slider_sizeselect.setValue(Form.size)
+         self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+         self.label_sizeselect = QtWidgets.QLabel(Form)
+         self.label_sizeselect.setText("Brush Size")
+         self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+         self.slider_yawselect = QtWidgets.QSlider(Form)
+         self.slider_yawselect.setRange(-100,100)
+         self.slider_yawselect.setOrientation(Qt.Horizontal)
+         self.slider_yawselect.setValue(Form.yaw)
+         self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+         self.label_yawselect = QtWidgets.QLabel(Form)
+         self.label_yawselect.setText("Yaw")
+         self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+         self.slider_pitchselect = QtWidgets.QSlider(Form)
+         self.slider_pitchselect.setRange(-100,100)
+         self.slider_pitchselect.setOrientation(Qt.Horizontal)
+         self.slider_pitchselect.setValue(Form.pitch)
+         self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+         self.label_pitchselect = QtWidgets.QLabel(Form)
+         self.label_pitchselect.setText("Pitch")
+         self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+         self.slider_truncation = QtWidgets.QSlider(Form)
+         self.slider_truncation.setRange(0,100)
+         self.slider_truncation.setOrientation(Qt.Horizontal)
+         self.slider_truncation.setValue(Form.truncation)
+         self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))
+
+         self.label_truncation = QtWidgets.QLabel(Form)
+         self.label_truncation.setText("Truncation")
+         self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))
+
+         self.text_inputID = QtWidgets.QTextEdit(Form)
+         self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
+         self.text_inputID.setObjectName("text_inputID")
+
+         self.pushButton_inputID = QtWidgets.QPushButton(Form)
+         self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
+         self.pushButton_inputID.setObjectName("pushButton_inputID")
+
+         self.text_seed = QtWidgets.QTextEdit(Form)
+         self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
+         self.text_seed.setObjectName("text_seed")
+         self.text_seed.setPlainText("0")
+
+         self.label_seed = QtWidgets.QLabel(Form)
+         self.label_seed.setText("Seed")
+         self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
+
+         # self.pushButton_inverse = QtWidgets.QPushButton(Form)
+         # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
+         # self.pushButton_inverse.setObjectName("pushButton_inverse")
+
+         # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
+         # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
+         # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
+         self.pushButton_get = QtWidgets.QPushButton(Form)
+         self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
+         self.pushButton_get.setObjectName("pushButton_get")
+
+
+
+
+         self.graphicsView = QtWidgets.QGraphicsView(Form)
+         self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
+         self.graphicsView.setObjectName("graphicsView")
+         self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
+         self.graphicsView_2.setObjectName("graphicsView_2")
+         self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
+         self.graphicsView_3.setObjectName("graphicsView_3")
+
+         self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
+         self.graphicsView_5.setObjectName("graphicsView_5")
+         self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
+         self.graphicsView_6.setObjectName("graphicsView_6")
+
+
+         self.retranslateUi(Form)
+         self.pushButton.clicked.connect(Form.generateAndReconstruct)
+         self.pushButton_2.clicked.connect(Form.open)
+         self.pushButton_3.clicked.connect(Form.open_mask)
+         self.pushButton_4.clicked.connect(Form.clear)
+         self.pushButton_5.clicked.connect(Form.undo)
+         self.pushButton_6.clicked.connect(Form.save_img)
+         self.pushButton_7.clicked.connect(Form.bg_mode)
+         self.pushButton_8.clicked.connect(Form.skin_mode)
+         self.pushButton_9.clicked.connect(Form.nose_mode)
+         self.pushButton_10.clicked.connect(Form.eye_g_mode)
+         self.pushButton_11.clicked.connect(Form.l_eye_mode)
+         self.pushButton_12.clicked.connect(Form.r_eye_mode)
+         self.pushButton_13.clicked.connect(Form.l_brow_mode)
+         self.pushButton_14.clicked.connect(Form.r_brow_mode)
+         self.pushButton_15.clicked.connect(Form.l_ear_mode)
+         self.pushButton_16.clicked.connect(Form.r_ear_mode)
+         self.pushButton_17.clicked.connect(Form.mouth_mode)
+         self.pushButton_18.clicked.connect(Form.u_lip_mode)
+         self.pushButton_19.clicked.connect(Form.l_lip_mode)
+         self.pushButton_20.clicked.connect(Form.hair_mode)
+         self.pushButton_21.clicked.connect(Form.hat_mode)
+         self.pushButton_22.clicked.connect(Form.ear_r_mode)
+         self.pushButton_23.clicked.connect(Form.neck_l_mode)
+         self.pushButton_24.clicked.connect(Form.neck_mode)
+         self.pushButton_25.clicked.connect(Form.cloth_mode)
+         # self.pushButton_26.clicked.connect(Form.increase)
+         # self.pushButton_27.clicked.connect(Form.decrease)
+
+         self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
+         self.slider_yawselect.valueChanged.connect(Form.changeYaw)
+         self.slider_pitchselect.valueChanged.connect(Form.changePitch)
+         self.slider_truncation.valueChanged.connect(Form.changeTruncation)
+
+         self.pushButton_inputID.clicked.connect(Form.inputID)
+
+         # self.pushButton_inverse.clicked.connect(Form.inverse)
+         # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
+         self.pushButton_get.clicked.connect(Form.get_mask)
+
+         QtCore.QMetaObject.connectSlotsByName(Form)
+
+     def retranslateUi(self, Form):
+         _translate = QtCore.QCoreApplication.translate
+         Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Edge2car)"))
+         self.pushButton.setText(_translate("Form", "Generate"))
+         self.pushButton_2.setText(_translate("Form", "Open Image"))
+         self.pushButton_3.setText(_translate("Form", "Open Mask"))
+         self.pushButton_4.setText(_translate("Form", "Clear"))
+         self.pushButton_5.setText(_translate("Form", "Undo"))
+         self.pushButton_6.setText(_translate("Form", "Save Image"))
+         self.pushButton_7.setText(_translate("Form", "Background"))
+         self.pushButton_8.setText(_translate("Form", "Skin"))
+         self.pushButton_9.setText(_translate("Form", "Nose"))
+         self.pushButton_10.setText(_translate("Form", "Eyeglass"))
+         self.pushButton_11.setText(_translate("Form", "Left Eye"))
+         self.pushButton_12.setText(_translate("Form", "Right Eye"))
+         self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
+         self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
+         self.pushButton_15.setText(_translate("Form", "Left Ear"))
+         self.pushButton_16.setText(_translate("Form", "Right Ear"))
+         self.pushButton_17.setText(_translate("Form", "Mouth"))
+         self.pushButton_18.setText(_translate("Form", "Upper Lip"))
+         self.pushButton_19.setText(_translate("Form", "Lower Lip"))
+         self.pushButton_20.setText(_translate("Form", "Hair"))
+         self.pushButton_21.setText(_translate("Form", "Hat"))
+         self.pushButton_22.setText(_translate("Form", "Earring"))
+         self.pushButton_23.setText(_translate("Form", "Necklace"))
+         self.pushButton_24.setText(_translate("Form", "Neck"))
+         self.pushButton_25.setText(_translate("Form", "Cloth"))
+         # self.pushButton_26.setText(_translate("Form", "+"))
+         # self.pushButton_27.setText(_translate("Form", "-"))
+         self.pushButton_inputID.setText(_translate("Form", "Input ID"))
+         # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
+         # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
+         self.pushButton_get.setText(_translate("Form", "Get"))
+
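+ # Editor's note: this manual entry point drives Ui_Form (defined earlier in this
+ # file) with a bare QWidget. The Form attributes and slots it expects (size, yaw,
+ # pitch, truncation, generateAndReconstruct, ...) are presumably provided by the
+ # demo widget in qt_demo_seg2cat.py, so running ui.py directly is expected to fail.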
+ if __name__ == "__main__":
+     import sys
+     app = QtWidgets.QApplication(sys.argv)
+     Form = QtWidgets.QWidget()
+     ui = Ui_Form()
+     ui.setupUi(Form)
+     Form.show()
+     sys.exit(app.exec_())
+
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui_clean.py ADDED
@@ -0,0 +1,797 @@
+ from PyQt5 import QtCore, QtGui, QtWidgets
+ from PyQt5.QtCore import Qt
+
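+ # Editor's note (assumption): ui_clean.py is a trimmed copy of ui.py. Ui_Form_Clean
+ # keeps only the input-mask and rendered-output panes (graphicsView, graphicsView_2)
+ # in a shorter 1800x660 window; note that pushButton_get keeps its tall-layout
+ # y-coordinate (680 + 512 + 10 = 1202), which falls outside the 660-px-high window.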
+ class Ui_Form_Clean(object):
+     def setupUi(self, Form):
+         Form.setObjectName("Form")
+         Form.resize(1800, 660)
+         self.pushButton = QtWidgets.QPushButton(Form)
+         # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+         self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
+         self.pushButton.setObjectName("pushButton")
+         self.pushButton_2 = QtWidgets.QPushButton(Form)
+         self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+         self.pushButton_2.setObjectName("pushButton_2")
+         self.pushButton_3 = QtWidgets.QPushButton(Form)
+         self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+         self.pushButton_3.setObjectName("pushButton_3")
+         self.pushButton_4 = QtWidgets.QPushButton(Form)
+         self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+         self.pushButton_4.setObjectName("pushButton_4")
+         self.pushButton_5 = QtWidgets.QPushButton(Form)
+         self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+         self.pushButton_5.setObjectName("pushButton_5")
+         self.pushButton_6 = QtWidgets.QPushButton(Form)
+         self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+         self.pushButton_6.setObjectName("pushButton_6")
+         self.pushButton_7 = QtWidgets.QPushButton(Form)
+         self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+         self.pushButton_7.setObjectName("pushButton_7")
+         self.pushButton_8 = QtWidgets.QPushButton(Form)
+         self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+         self.pushButton_8.setObjectName("pushButton_8")
+         self.pushButton_9 = QtWidgets.QPushButton(Form)
+         self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+         self.pushButton_9.setObjectName("pushButton_9")
+         self.pushButton_10 = QtWidgets.QPushButton(Form)
+         self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+         self.pushButton_10.setObjectName("pushButton_10")
+         self.pushButton_11 = QtWidgets.QPushButton(Form)
+         self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+         self.pushButton_11.setObjectName("pushButton_11")
+         self.pushButton_12 = QtWidgets.QPushButton(Form)
+         self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+         self.pushButton_12.setObjectName("pushButton_12")
+         self.pushButton_13 = QtWidgets.QPushButton(Form)
+         self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+         self.pushButton_13.setObjectName("pushButton_13")
+         self.pushButton_14 = QtWidgets.QPushButton(Form)
+         self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+         self.pushButton_14.setObjectName("pushButton_14")
+         self.pushButton_15 = QtWidgets.QPushButton(Form)
+         self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+         self.pushButton_15.setObjectName("pushButton_15")
+         self.pushButton_16 = QtWidgets.QPushButton(Form)
+         self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+         self.pushButton_16.setObjectName("pushButton_16")
+         self.pushButton_17 = QtWidgets.QPushButton(Form)
+         self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+         self.pushButton_17.setObjectName("pushButton_17")
+         self.pushButton_18 = QtWidgets.QPushButton(Form)
+         self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+         self.pushButton_18.setObjectName("pushButton_18")
+         self.pushButton_19 = QtWidgets.QPushButton(Form)
+         self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+         self.pushButton_19.setObjectName("pushButton_19")
+         self.pushButton_20 = QtWidgets.QPushButton(Form)
+         self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+         self.pushButton_20.setObjectName("pushButton_20")
+         self.pushButton_21 = QtWidgets.QPushButton(Form)
+         self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+         self.pushButton_21.setObjectName("pushButton_21")
+         self.pushButton_22 = QtWidgets.QPushButton(Form)
+         self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+         self.pushButton_22.setObjectName("pushButton_22")
+         self.pushButton_23 = QtWidgets.QPushButton(Form)
+         self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+         self.pushButton_23.setObjectName("pushButton_23")
+         self.pushButton_24 = QtWidgets.QPushButton(Form)
+         self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+         self.pushButton_24.setObjectName("pushButton_24")
+         self.pushButton_25 = QtWidgets.QPushButton(Form)
+         self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+         self.pushButton_25.setObjectName("pushButton_25")
+         # self.pushButton_26 = QtWidgets.QPushButton(Form)
+         # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+         # self.pushButton_26.setObjectName("pushButton_26")
+         # self.pushButton_27 = QtWidgets.QPushButton(Form)
+         # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+         # self.pushButton_27.setObjectName("pushButton_27")
+
+         self.slider_sizeselect = QtWidgets.QSlider(Form)
+         self.slider_sizeselect.setRange(10,70)
+         self.slider_sizeselect.setOrientation(Qt.Horizontal)
+         self.slider_sizeselect.setValue(Form.size)
+         self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+         self.label_sizeselect = QtWidgets.QLabel(Form)
+         self.label_sizeselect.setText("Brush Size")
+         self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+         self.slider_yawselect = QtWidgets.QSlider(Form)
+         self.slider_yawselect.setRange(-100,100)
+         self.slider_yawselect.setOrientation(Qt.Horizontal)
+         self.slider_yawselect.setValue(Form.yaw)
+         self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+         self.label_yawselect = QtWidgets.QLabel(Form)
+         self.label_yawselect.setText("Yaw")
+         self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+         self.slider_pitchselect = QtWidgets.QSlider(Form)
+         self.slider_pitchselect.setRange(-100,100)
+         self.slider_pitchselect.setOrientation(Qt.Horizontal)
+         self.slider_pitchselect.setValue(Form.pitch)
+         self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+         self.label_pitchselect = QtWidgets.QLabel(Form)
+         self.label_pitchselect.setText("Pitch")
+         self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+         self.slider_truncation = QtWidgets.QSlider(Form)
+         self.slider_truncation.setRange(0,100)
+         self.slider_truncation.setOrientation(Qt.Horizontal)
+         self.slider_truncation.setValue(Form.truncation)
+         self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))
+
+         self.label_truncation = QtWidgets.QLabel(Form)
+         self.label_truncation.setText("Truncation")
+         self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))
+
+         self.text_inputID = QtWidgets.QTextEdit(Form)
+         self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
+         self.text_inputID.setObjectName("text_inputID")
+
+         self.pushButton_inputID = QtWidgets.QPushButton(Form)
+         self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
+         self.pushButton_inputID.setObjectName("pushButton_inputID")
+
+         self.text_seed = QtWidgets.QTextEdit(Form)
+         self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
+         self.text_seed.setObjectName("text_seed")
+         self.text_seed.setPlainText("0")
+
+         self.label_seed = QtWidgets.QLabel(Form)
+         self.label_seed.setText("Seed")
+         self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
+
+         # self.pushButton_inverse = QtWidgets.QPushButton(Form)
+         # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
+         # self.pushButton_inverse.setObjectName("pushButton_inverse")
+
+         # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
+         # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
+         # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
+         self.pushButton_get = QtWidgets.QPushButton(Form)
+         self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
+         self.pushButton_get.setObjectName("pushButton_get")
+
+
+
+
+         self.graphicsView = QtWidgets.QGraphicsView(Form)
+         self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
+         self.graphicsView.setObjectName("graphicsView")
+         self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+         self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
+         self.graphicsView_2.setObjectName("graphicsView_2")
+         # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+         # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
+         # self.graphicsView_3.setObjectName("graphicsView_3")
+
+         # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
+         # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
+         # self.graphicsView_5.setObjectName("graphicsView_5")
+         # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
+         # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
+         # self.graphicsView_6.setObjectName("graphicsView_6")
+
+
+         self.retranslateUi(Form)
+         self.pushButton.clicked.connect(Form.generateAndReconstruct)
+         self.pushButton_2.clicked.connect(Form.open)
+         self.pushButton_3.clicked.connect(Form.open_mask)
+         self.pushButton_4.clicked.connect(Form.clear)
+         self.pushButton_5.clicked.connect(Form.undo)
+         self.pushButton_6.clicked.connect(Form.save_img)
+         self.pushButton_7.clicked.connect(Form.bg_mode)
+         self.pushButton_8.clicked.connect(Form.skin_mode)
+         self.pushButton_9.clicked.connect(Form.nose_mode)
+         self.pushButton_10.clicked.connect(Form.eye_g_mode)
+         self.pushButton_11.clicked.connect(Form.l_eye_mode)
+         self.pushButton_12.clicked.connect(Form.r_eye_mode)
+         self.pushButton_13.clicked.connect(Form.l_brow_mode)
+         self.pushButton_14.clicked.connect(Form.r_brow_mode)
+         self.pushButton_15.clicked.connect(Form.l_ear_mode)
+         self.pushButton_16.clicked.connect(Form.r_ear_mode)
+         self.pushButton_17.clicked.connect(Form.mouth_mode)
+         self.pushButton_18.clicked.connect(Form.u_lip_mode)
+         self.pushButton_19.clicked.connect(Form.l_lip_mode)
+         self.pushButton_20.clicked.connect(Form.hair_mode)
+         self.pushButton_21.clicked.connect(Form.hat_mode)
+         self.pushButton_22.clicked.connect(Form.ear_r_mode)
+         self.pushButton_23.clicked.connect(Form.neck_l_mode)
+         self.pushButton_24.clicked.connect(Form.neck_mode)
+         self.pushButton_25.clicked.connect(Form.cloth_mode)
+         # self.pushButton_26.clicked.connect(Form.increase)
+         # self.pushButton_27.clicked.connect(Form.decrease)
+
+         self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
+         self.slider_yawselect.valueChanged.connect(Form.changeYaw)
+         self.slider_pitchselect.valueChanged.connect(Form.changePitch)
+         self.slider_truncation.valueChanged.connect(Form.changeTruncation)
+
+         self.pushButton_inputID.clicked.connect(Form.inputID)
+
+         # self.pushButton_inverse.clicked.connect(Form.inverse)
+         # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
+         self.pushButton_get.clicked.connect(Form.get_mask)
+
+         QtCore.QMetaObject.connectSlotsByName(Form)
+
+     def retranslateUi(self, Form):
+         _translate = QtCore.QCoreApplication.translate
+         Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis"))
+         self.pushButton.setText(_translate("Form", "Generate"))
+         self.pushButton_2.setText(_translate("Form", "Open Image"))
+         self.pushButton_3.setText(_translate("Form", "Open Mask"))
+         self.pushButton_4.setText(_translate("Form", "Clear"))
+         self.pushButton_5.setText(_translate("Form", "Undo"))
+         self.pushButton_6.setText(_translate("Form", "Save Image"))
+         self.pushButton_7.setText(_translate("Form", "Background"))
+         self.pushButton_8.setText(_translate("Form", "Skin"))
+         self.pushButton_9.setText(_translate("Form", "Nose"))
+         self.pushButton_10.setText(_translate("Form", "Eyeglass"))
+         self.pushButton_11.setText(_translate("Form", "Left Eye"))
+         self.pushButton_12.setText(_translate("Form", "Right Eye"))
+         self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
+         self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
+         self.pushButton_15.setText(_translate("Form", "Left Ear"))
+         self.pushButton_16.setText(_translate("Form", "Right Ear"))
+         self.pushButton_17.setText(_translate("Form", "Mouth"))
+         self.pushButton_18.setText(_translate("Form", "Upper Lip"))
+         self.pushButton_19.setText(_translate("Form", "Lower Lip"))
+         self.pushButton_20.setText(_translate("Form", "Hair"))
+         self.pushButton_21.setText(_translate("Form", "Hat"))
+         self.pushButton_22.setText(_translate("Form", "Earring"))
+         self.pushButton_23.setText(_translate("Form", "Necklace"))
+         self.pushButton_24.setText(_translate("Form", "Neck"))
+         self.pushButton_25.setText(_translate("Form", "Cloth"))
+         # self.pushButton_26.setText(_translate("Form", "+"))
+         # self.pushButton_27.setText(_translate("Form", "-"))
+         self.pushButton_inputID.setText(_translate("Form", "Input ID"))
+         # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
+         # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
+         self.pushButton_get.setText(_translate("Form", "Get"))
+
+
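+ # Editor's note (assumption): the edge2car variant keeps a single drawing label
+ # (buttons 9-25 are commented out, since an edge map has no semantic part classes)
+ # and adds a roll slider alongside the yaw/pitch/truncation camera controls.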
+ class Ui_Form_Edge2car(object):
+     def setupUi(self, Form):
+         Form.setObjectName("Form")
+         Form.resize(1800, 660)
+         self.pushButton = QtWidgets.QPushButton(Form)
+         # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+         self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
+         self.pushButton.setObjectName("pushButton")
+         self.pushButton_2 = QtWidgets.QPushButton(Form)
+         self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+         self.pushButton_2.setObjectName("pushButton_2")
+         self.pushButton_3 = QtWidgets.QPushButton(Form)
+         self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+         self.pushButton_3.setObjectName("pushButton_3")
+         self.pushButton_4 = QtWidgets.QPushButton(Form)
+         self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+         self.pushButton_4.setObjectName("pushButton_4")
+         self.pushButton_5 = QtWidgets.QPushButton(Form)
+         self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+         self.pushButton_5.setObjectName("pushButton_5")
+         self.pushButton_6 = QtWidgets.QPushButton(Form)
+         self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+         self.pushButton_6.setObjectName("pushButton_6")
+         self.pushButton_7 = QtWidgets.QPushButton(Form)
+         self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+         self.pushButton_7.setObjectName("pushButton_7")
+         self.pushButton_8 = QtWidgets.QPushButton(Form)
+         self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+         self.pushButton_8.setObjectName("pushButton_8")
+         # self.pushButton_9 = QtWidgets.QPushButton(Form)
+         # self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+         # self.pushButton_9.setObjectName("pushButton_9")
+         # self.pushButton_10 = QtWidgets.QPushButton(Form)
+         # self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+         # self.pushButton_10.setObjectName("pushButton_10")
+         # self.pushButton_11 = QtWidgets.QPushButton(Form)
+         # self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+         # self.pushButton_11.setObjectName("pushButton_11")
+         # self.pushButton_12 = QtWidgets.QPushButton(Form)
+         # self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+         # self.pushButton_12.setObjectName("pushButton_12")
+         # self.pushButton_13 = QtWidgets.QPushButton(Form)
+         # self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+         # self.pushButton_13.setObjectName("pushButton_13")
+         # self.pushButton_14 = QtWidgets.QPushButton(Form)
+         # self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+         # self.pushButton_14.setObjectName("pushButton_14")
+         # self.pushButton_15 = QtWidgets.QPushButton(Form)
+         # self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+         # self.pushButton_15.setObjectName("pushButton_15")
+         # self.pushButton_16 = QtWidgets.QPushButton(Form)
+         # self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+         # self.pushButton_16.setObjectName("pushButton_16")
+         # self.pushButton_17 = QtWidgets.QPushButton(Form)
+         # self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+         # self.pushButton_17.setObjectName("pushButton_17")
+         # self.pushButton_18 = QtWidgets.QPushButton(Form)
+         # self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+         # self.pushButton_18.setObjectName("pushButton_18")
+         # self.pushButton_19 = QtWidgets.QPushButton(Form)
+         # self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+         # self.pushButton_19.setObjectName("pushButton_19")
+         # self.pushButton_20 = QtWidgets.QPushButton(Form)
+         # self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+         # self.pushButton_20.setObjectName("pushButton_20")
+         # self.pushButton_21 = QtWidgets.QPushButton(Form)
+         # self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+         # self.pushButton_21.setObjectName("pushButton_21")
+         # self.pushButton_22 = QtWidgets.QPushButton(Form)
+         # self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+         # self.pushButton_22.setObjectName("pushButton_22")
+         # self.pushButton_23 = QtWidgets.QPushButton(Form)
+         # self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+         # self.pushButton_23.setObjectName("pushButton_23")
+         # self.pushButton_24 = QtWidgets.QPushButton(Form)
+         # self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+         # self.pushButton_24.setObjectName("pushButton_24")
+         # self.pushButton_25 = QtWidgets.QPushButton(Form)
+         # self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+         # self.pushButton_25.setObjectName("pushButton_25")
+         # self.pushButton_26 = QtWidgets.QPushButton(Form)
+         # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+         # self.pushButton_26.setObjectName("pushButton_26")
+         # self.pushButton_27 = QtWidgets.QPushButton(Form)
+         # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+         # self.pushButton_27.setObjectName("pushButton_27")
+
+         self.slider_sizeselect = QtWidgets.QSlider(Form)
+         self.slider_sizeselect.setRange(10,70)
+         self.slider_sizeselect.setOrientation(Qt.Horizontal)
+         self.slider_sizeselect.setValue(Form.size)
+         self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 100, 97, 27))
+
+         self.label_sizeselect = QtWidgets.QLabel(Form)
+         self.label_sizeselect.setText("Brush Size")
+         self.label_sizeselect.setGeometry(QtCore.QRect(1630, 100, 97, 27))
+
+         self.slider_yawselect = QtWidgets.QSlider(Form)
+         self.slider_yawselect.setRange(-100,100)
+         self.slider_yawselect.setOrientation(Qt.Horizontal)
+         self.slider_yawselect.setValue(Form.yaw)
+         self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+         self.label_yawselect = QtWidgets.QLabel(Form)
+         self.label_yawselect.setText("Yaw")
+         self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+         self.slider_pitchselect = QtWidgets.QSlider(Form)
+         self.slider_pitchselect.setRange(-100,100)
+         self.slider_pitchselect.setOrientation(Qt.Horizontal)
+         self.slider_pitchselect.setValue(Form.pitch)
+         self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+         self.label_pitchselect = QtWidgets.QLabel(Form)
+         self.label_pitchselect.setText("Pitch")
+         self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+         self.slider_rollselect = QtWidgets.QSlider(Form)
+         self.slider_rollselect.setRange(0,100)
+         self.slider_rollselect.setOrientation(Qt.Horizontal)
+         self.slider_rollselect.setValue(Form.roll)
+         self.slider_rollselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+         self.label_rollselect = QtWidgets.QLabel(Form)
+         self.label_rollselect.setText("Roll")
+         self.label_rollselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+
+         self.slider_truncation = QtWidgets.QSlider(Form)
+         self.slider_truncation.setRange(0,100)
388
+ self.slider_truncation.setOrientation(Qt.Horizontal)
389
+ self.slider_truncation.setValue(Form.truncation)
390
+ self.slider_truncation.setGeometry(QtCore.QRect(1530, 130, 97, 27))
391
+
392
+ self.label_truncation = QtWidgets.QLabel(Form)
393
+ self.label_truncation.setText("Truncation")
394
+ self.label_truncation.setGeometry(QtCore.QRect(1630, 130, 97, 27))
395
+
396
+ self.text_inputID = QtWidgets.QTextEdit(Form)
397
+ self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
398
+ self.text_inputID.setObjectName("text_inputID")
399
+
400
+ self.pushButton_inputID = QtWidgets.QPushButton(Form)
401
+ self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
402
+ self.pushButton_inputID.setObjectName("pushButton_inputID")
403
+
404
+ self.text_seed = QtWidgets.QTextEdit(Form)
405
+ self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
406
+ self.text_seed.setObjectName("text_seed")
407
+ self.text_seed.setPlainText("0")
408
+
409
+ self.label_seed = QtWidgets.QLabel(Form)
410
+ self.label_seed.setText("Seed")
411
+ self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
412
+
413
+ # self.pushButton_inverse = QtWidgets.QPushButton(Form)
414
+ # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
415
+ # self.pushButton_inverse.setObjectName("pushButton_inverse")
416
+
417
+ # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
418
+ # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
419
+ # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
420
+ self.pushButton_get = QtWidgets.QPushButton(Form)
421
+ self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
422
+ self.pushButton_get.setObjectName("pushButton_get")
423
+
424
+
425
+
426
+
427
+ self.graphicsView = QtWidgets.QGraphicsView(Form)
428
+ self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
429
+ self.graphicsView.setObjectName("graphicsView")
430
+ self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
431
+ self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
432
+ self.graphicsView_2.setObjectName("graphicsView_2")
433
+ # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
434
+ # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
435
+ # self.graphicsView_3.setObjectName("graphicsView_3")
436
+
437
+ # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
438
+ # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
439
+ # self.graphicsView_5.setObjectName("graphicsView_5")
440
+ # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
441
+ # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
442
+ # self.graphicsView_6.setObjectName("graphicsView_6")
443
+
444
+
445
+ self.retranslateUi(Form)
446
+ self.pushButton.clicked.connect(Form.generateAndReconstruct)
447
+ self.pushButton_2.clicked.connect(Form.open)
448
+ self.pushButton_3.clicked.connect(Form.open_mask)
449
+ self.pushButton_4.clicked.connect(Form.clear)
450
+ self.pushButton_5.clicked.connect(Form.undo)
451
+ self.pushButton_6.clicked.connect(Form.save_img)
452
+ self.pushButton_7.clicked.connect(Form.bg_mode)
453
+ self.pushButton_8.clicked.connect(Form.skin_mode)
454
+ # self.pushButton_9.clicked.connect(Form.nose_mode)
455
+ # self.pushButton_10.clicked.connect(Form.eye_g_mode)
456
+ # self.pushButton_11.clicked.connect(Form.l_eye_mode)
457
+ # self.pushButton_12.clicked.connect(Form.r_eye_mode)
458
+ # self.pushButton_13.clicked.connect(Form.l_brow_mode)
459
+ # self.pushButton_14.clicked.connect(Form.r_brow_mode)
460
+ # self.pushButton_15.clicked.connect(Form.l_ear_mode)
461
+ # self.pushButton_16.clicked.connect(Form.r_ear_mode)
462
+ # self.pushButton_17.clicked.connect(Form.mouth_mode)
463
+ # self.pushButton_18.clicked.connect(Form.u_lip_mode)
464
+ # self.pushButton_19.clicked.connect(Form.l_lip_mode)
465
+ # self.pushButton_20.clicked.connect(Form.hair_mode)
466
+ # self.pushButton_21.clicked.connect(Form.hat_mode)
467
+ # self.pushButton_22.clicked.connect(Form.ear_r_mode)
468
+ # self.pushButton_23.clicked.connect(Form.neck_l_mode)
469
+ # self.pushButton_24.clicked.connect(Form.neck_mode)
470
+ # self.pushButton_25.clicked.connect(Form.cloth_mode)
471
+ # self.pushButton_26.clicked.connect(Form.increase)
472
+ # self.pushButton_27.clicked.connect(Form.decrease)
473
+
474
+ self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
475
+ self.slider_yawselect.valueChanged.connect(Form.changeYaw)
476
+ self.slider_pitchselect.valueChanged.connect(Form.changePitch)
477
+ self.slider_rollselect.valueChanged.connect(Form.changeRoll)
478
+ self.slider_truncation.valueChanged.connect(Form.changeTruncation)
479
+
480
+ self.pushButton_inputID.clicked.connect(Form.inputID)
481
+
482
+ # self.pushButton_inverse.clicked.connect(Form.inverse)
483
+ # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
484
+ self.pushButton_get.clicked.connect(Form.get_mask)
485
+
486
+ QtCore.QMetaObject.connectSlotsByName(Form)
487
+
488
+ def retranslateUi(self, Form):
489
+ _translate = QtCore.QCoreApplication.translate
490
+ Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Edge2car)"))
491
+ self.pushButton.setText(_translate("Form", "Generate"))
492
+ self.pushButton_2.setText(_translate("Form", "Open Image"))
493
+ self.pushButton_3.setText(_translate("Form", "Open Mask"))
494
+ self.pushButton_4.setText(_translate("Form", "Clear"))
495
+ self.pushButton_5.setText(_translate("Form", "Undo"))
496
+ self.pushButton_6.setText(_translate("Form", "Save Image"))
497
+ self.pushButton_7.setText(_translate("Form", "BackGround"))
498
+ self.pushButton_8.setText(_translate("Form", "Edge"))
499
+ # self.pushButton_9.setText(_translate("Form", "Nose"))
500
+ # self.pushButton_10.setText(_translate("Form", "Eyeglass"))
501
+ # self.pushButton_11.setText(_translate("Form", "Left Eye"))
502
+ # self.pushButton_12.setText(_translate("Form", "Right Eye"))
503
+ # self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
504
+ # self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
505
+ # self.pushButton_15.setText(_translate("Form", "Left ear"))
506
+ # self.pushButton_16.setText(_translate("Form", "Right ear"))
507
+ # self.pushButton_17.setText(_translate("Form", "Mouth"))
508
+ # self.pushButton_18.setText(_translate("Form", "Upper Lip"))
509
+ # self.pushButton_19.setText(_translate("Form", "Lower Lip"))
510
+ # self.pushButton_20.setText(_translate("Form", "Hair"))
511
+ # self.pushButton_21.setText(_translate("Form", "Hat"))
512
+ # self.pushButton_22.setText(_translate("Form", "Earring"))
513
+ # self.pushButton_23.setText(_translate("Form", "Necklace"))
514
+ # self.pushButton_24.setText(_translate("Form", "Neck"))
515
+ # self.pushButton_25.setText(_translate("Form", "Cloth"))
516
+ # self.pushButton_26.setText(_translate("Form", "+"))
517
+ # self.pushButton_27.setText(_translate("Form", "-"))
518
+ self.pushButton_inputID.setText(_translate("Form", "Input ID"))
519
+ # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
520
+ # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
521
+ self.pushButton_get.setText(_translate("Form", "Get"))
522
+
523
+
524
+ class Ui_Form_Seg2cat(object):
525
+ def setupUi(self, Form):
526
+ Form.setObjectName("Form")
527
+ Form.resize(1800, 660)
528
+ self.pushButton = QtWidgets.QPushButton(Form)
529
+ # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
530
+ self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
531
+ self.pushButton.setObjectName("pushButton")
532
+ self.pushButton_2 = QtWidgets.QPushButton(Form)
533
+ self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
534
+ self.pushButton_2.setObjectName("pushButton_2")
535
+ self.pushButton_3 = QtWidgets.QPushButton(Form)
536
+ self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
537
+ self.pushButton_3.setObjectName("pushButton_3")
538
+ self.pushButton_4 = QtWidgets.QPushButton(Form)
539
+ self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
540
+ self.pushButton_4.setObjectName("pushButton_4")
541
+ self.pushButton_5 = QtWidgets.QPushButton(Form)
542
+ self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
543
+ self.pushButton_5.setObjectName("pushButton_5")
544
+ self.pushButton_6 = QtWidgets.QPushButton(Form)
545
+ self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
546
+ self.pushButton_6.setObjectName("pushButton_6")
547
+ self.pushButton_7 = QtWidgets.QPushButton(Form)
548
+ self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
549
+ self.pushButton_7.setObjectName("pushButton_7")
550
+ self.pushButton_8 = QtWidgets.QPushButton(Form)
551
+ self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
552
+ self.pushButton_8.setObjectName("pushButton_8")
553
+ self.pushButton_9 = QtWidgets.QPushButton(Form)
554
+ self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
555
+ self.pushButton_9.setObjectName("pushButton_9")
556
+ self.pushButton_10 = QtWidgets.QPushButton(Form)
557
+ self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
558
+ self.pushButton_10.setObjectName("pushButton_10")
559
+ self.pushButton_11 = QtWidgets.QPushButton(Form)
560
+ self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
561
+ self.pushButton_11.setObjectName("pushButton_11")
562
+ self.pushButton_12 = QtWidgets.QPushButton(Form)
563
+ self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
564
+ self.pushButton_12.setObjectName("pushButton_12")
565
+ # self.pushButton_13 = QtWidgets.QPushButton(Form)
566
+ # self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
567
+ # self.pushButton_13.setObjectName("pushButton_13")
568
+ # self.pushButton_14 = QtWidgets.QPushButton(Form)
569
+ # self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
570
+ # self.pushButton_14.setObjectName("pushButton_14")
571
+ # self.pushButton_15 = QtWidgets.QPushButton(Form)
572
+ # self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
573
+ # self.pushButton_15.setObjectName("pushButton_15")
574
+ # self.pushButton_16 = QtWidgets.QPushButton(Form)
575
+ # self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
576
+ # self.pushButton_16.setObjectName("pushButton_16")
577
+ # self.pushButton_17 = QtWidgets.QPushButton(Form)
578
+ # self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
579
+ # self.pushButton_17.setObjectName("pushButton_17")
580
+ # self.pushButton_18 = QtWidgets.QPushButton(Form)
581
+ # self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
582
+ # self.pushButton_18.setObjectName("pushButton_18")
583
+ # self.pushButton_19 = QtWidgets.QPushButton(Form)
584
+ # self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
585
+ # self.pushButton_19.setObjectName("pushButton_19")
586
+ # self.pushButton_20 = QtWidgets.QPushButton(Form)
587
+ # self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
588
+ # self.pushButton_20.setObjectName("pushButton_20")
589
+ # self.pushButton_21 = QtWidgets.QPushButton(Form)
590
+ # self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
591
+ # self.pushButton_21.setObjectName("pushButton_21")
592
+ # self.pushButton_22 = QtWidgets.QPushButton(Form)
593
+ # self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
594
+ # self.pushButton_22.setObjectName("pushButton_22")
595
+ # self.pushButton_23 = QtWidgets.QPushButton(Form)
596
+ # self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
597
+ # self.pushButton_23.setObjectName("pushButton_23")
598
+ # self.pushButton_24 = QtWidgets.QPushButton(Form)
599
+ # self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
600
+ # self.pushButton_24.setObjectName("pushButton_24")
601
+ # self.pushButton_25 = QtWidgets.QPushButton(Form)
602
+ # self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
603
+ # self.pushButton_25.setObjectName("pushButton_25")
604
+ # self.pushButton_26 = QtWidgets.QPushButton(Form)
605
+ # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
606
+ # self.pushButton_26.setObjectName("pushButton_26")
607
+ # self.pushButton_27 = QtWidgets.QPushButton(Form)
608
+ # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
609
+ # self.pushButton_27.setObjectName("pushButton_27")
610
+
611
+ self.slider_sizeselect = QtWidgets.QSlider(Form)
612
+ self.slider_sizeselect.setRange(10,70)
613
+ self.slider_sizeselect.setOrientation(Qt.Horizontal)
614
+ self.slider_sizeselect.setValue(Form.size)
615
+ self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 100, 97, 27))
616
+
617
+ self.label_sizeselect = QtWidgets.QLabel(Form)
618
+ self.label_sizeselect.setText("Brush Size")
619
+ self.label_sizeselect.setGeometry(QtCore.QRect(1630, 100, 97, 27))
620
+
621
+ self.slider_yawselect = QtWidgets.QSlider(Form)
622
+ self.slider_yawselect.setRange(-100,100)
623
+ self.slider_yawselect.setOrientation(Qt.Horizontal)
624
+ self.slider_yawselect.setValue(Form.yaw)
625
+ self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
626
+
627
+ self.label_yawselect = QtWidgets.QLabel(Form)
628
+ self.label_yawselect.setText("Yaw")
629
+ self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
630
+
631
+ self.slider_pitchselect = QtWidgets.QSlider(Form)
632
+ self.slider_pitchselect.setRange(-100,100)
633
+ self.slider_pitchselect.setOrientation(Qt.Horizontal)
634
+ self.slider_pitchselect.setValue(Form.pitch)
635
+ self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
636
+
637
+ self.label_pitchselect = QtWidgets.QLabel(Form)
638
+ self.label_pitchselect.setText("Pitch")
639
+ self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
640
+
641
+ self.slider_rollselect = QtWidgets.QSlider(Form)
642
+ self.slider_rollselect.setRange(-100,100)
643
+ self.slider_rollselect.setOrientation(Qt.Horizontal)
644
+ self.slider_rollselect.setValue(Form.roll)
645
+ self.slider_rollselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
646
+
647
+ self.label_rollselect = QtWidgets.QLabel(Form)
648
+ self.label_rollselect.setText("Roll")
649
+ self.label_rollselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
650
+
651
+
652
+ self.slider_truncation = QtWidgets.QSlider(Form)
653
+ self.slider_truncation.setRange(0,100)
654
+ self.slider_truncation.setOrientation(Qt.Horizontal)
655
+ self.slider_truncation.setValue(Form.truncation)
656
+ self.slider_truncation.setGeometry(QtCore.QRect(1530, 130, 97, 27))
657
+
658
+ self.label_truncation = QtWidgets.QLabel(Form)
659
+ self.label_truncation.setText("Truncation")
660
+ self.label_truncation.setGeometry(QtCore.QRect(1630, 130, 97, 27))
661
+
662
+ self.text_inputID = QtWidgets.QTextEdit(Form)
663
+ self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
664
+ self.text_inputID.setObjectName("text_inputID")
665
+
666
+ self.pushButton_inputID = QtWidgets.QPushButton(Form)
667
+ self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
668
+ self.pushButton_inputID.setObjectName("pushButton_inputID")
669
+
670
+ self.text_seed = QtWidgets.QTextEdit(Form)
671
+ self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
672
+ self.text_seed.setObjectName("text_seed")
673
+ self.text_seed.setPlainText("0")
674
+
675
+ self.label_seed = QtWidgets.QLabel(Form)
676
+ self.label_seed.setText("Seed")
677
+ self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
678
+
679
+ # self.pushButton_inverse = QtWidgets.QPushButton(Form)
680
+ # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
681
+ # self.pushButton_inverse.setObjectName("pushButton_inverse")
682
+
683
+ # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
684
+ # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
685
+ # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
686
+ self.pushButton_get = QtWidgets.QPushButton(Form)
687
+ self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
688
+ self.pushButton_get.setObjectName("pushButton_get")
689
+
690
+
691
+
692
+
693
+ self.graphicsView = QtWidgets.QGraphicsView(Form)
694
+ self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
695
+ self.graphicsView.setObjectName("graphicsView")
696
+ self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
697
+ self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
698
+ self.graphicsView_2.setObjectName("graphicsView_2")
699
+ # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
700
+ # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
701
+ # self.graphicsView_3.setObjectName("graphicsView_3")
702
+
703
+ # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
704
+ # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
705
+ # self.graphicsView_5.setObjectName("graphicsView_5")
706
+ # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
707
+ # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
708
+ # self.graphicsView_6.setObjectName("graphicsView_6")
709
+
710
+
711
+ self.retranslateUi(Form)
712
+ self.pushButton.clicked.connect(Form.generateAndReconstruct)
713
+ self.pushButton_2.clicked.connect(Form.open)
714
+ self.pushButton_3.clicked.connect(Form.open_mask)
715
+ self.pushButton_4.clicked.connect(Form.clear)
716
+ self.pushButton_5.clicked.connect(Form.undo)
717
+ self.pushButton_6.clicked.connect(Form.save_img)
718
+ self.pushButton_7.clicked.connect(Form.bg_mode)
719
+ self.pushButton_8.clicked.connect(Form.skin_mode)
720
+ self.pushButton_9.clicked.connect(Form.nose_mode)
721
+ self.pushButton_10.clicked.connect(Form.eye_g_mode)
722
+ self.pushButton_11.clicked.connect(Form.l_eye_mode)
723
+ self.pushButton_12.clicked.connect(Form.r_eye_mode)
724
+ # self.pushButton_13.clicked.connect(Form.l_brow_mode)
725
+ # self.pushButton_14.clicked.connect(Form.r_brow_mode)
726
+ # self.pushButton_15.clicked.connect(Form.l_ear_mode)
727
+ # self.pushButton_16.clicked.connect(Form.r_ear_mode)
728
+ # self.pushButton_17.clicked.connect(Form.mouth_mode)
729
+ # self.pushButton_18.clicked.connect(Form.u_lip_mode)
730
+ # self.pushButton_19.clicked.connect(Form.l_lip_mode)
731
+ # self.pushButton_20.clicked.connect(Form.hair_mode)
732
+ # self.pushButton_21.clicked.connect(Form.hat_mode)
733
+ # self.pushButton_22.clicked.connect(Form.ear_r_mode)
734
+ # self.pushButton_23.clicked.connect(Form.neck_l_mode)
735
+ # self.pushButton_24.clicked.connect(Form.neck_mode)
736
+ # self.pushButton_25.clicked.connect(Form.cloth_mode)
737
+ # self.pushButton_26.clicked.connect(Form.increase)
738
+ # self.pushButton_27.clicked.connect(Form.decrease)
739
+
740
+ self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
741
+ self.slider_yawselect.valueChanged.connect(Form.changeYaw)
742
+ self.slider_pitchselect.valueChanged.connect(Form.changePitch)
743
+ self.slider_rollselect.valueChanged.connect(Form.changeRoll)
744
+ self.slider_truncation.valueChanged.connect(Form.changeTruncation)
745
+
746
+ self.pushButton_inputID.clicked.connect(Form.inputID)
747
+
748
+ # self.pushButton_inverse.clicked.connect(Form.inverse)
749
+ # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
750
+ self.pushButton_get.clicked.connect(Form.get_mask)
751
+
752
+ QtCore.QMetaObject.connectSlotsByName(Form)
753
+
754
+ def retranslateUi(self, Form):
755
+ _translate = QtCore.QCoreApplication.translate
756
+ Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Seg2cat)"))
757
+ self.pushButton.setText(_translate("Form", "Generate"))
758
+ self.pushButton_2.setText(_translate("Form", "Open Image"))
759
+ self.pushButton_3.setText(_translate("Form", "Open Mask"))
760
+ self.pushButton_4.setText(_translate("Form", "Clear"))
761
+ self.pushButton_5.setText(_translate("Form", "Undo"))
762
+ self.pushButton_6.setText(_translate("Form", "Save Image"))
763
+ self.pushButton_7.setText(_translate("Form", "BackGround"))
764
+ self.pushButton_8.setText(_translate("Form", "Face"))
765
+ self.pushButton_9.setText(_translate("Form", "Ear"))
766
+ self.pushButton_10.setText(_translate("Form", "Mouth"))
767
+ self.pushButton_11.setText(_translate("Form", "Eyes"))
768
+ self.pushButton_12.setText(_translate("Form", "Whiskers"))
769
+ # self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
770
+ # self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
771
+ # self.pushButton_15.setText(_translate("Form", "Left ear"))
772
+ # self.pushButton_16.setText(_translate("Form", "Right ear"))
773
+ # self.pushButton_17.setText(_translate("Form", "Mouth"))
774
+ # self.pushButton_18.setText(_translate("Form", "Upper Lip"))
775
+ # self.pushButton_19.setText(_translate("Form", "Lower Lip"))
776
+ # self.pushButton_20.setText(_translate("Form", "Hair"))
777
+ # self.pushButton_21.setText(_translate("Form", "Hat"))
778
+ # self.pushButton_22.setText(_translate("Form", "Earring"))
779
+ # self.pushButton_23.setText(_translate("Form", "Necklace"))
780
+ # self.pushButton_24.setText(_translate("Form", "Neck"))
781
+ # self.pushButton_25.setText(_translate("Form", "Cloth"))
782
+ # self.pushButton_26.setText(_translate("Form", "+"))
783
+ # self.pushButton_27.setText(_translate("Form", "-"))
784
+ self.pushButton_inputID.setText(_translate("Form", "Input ID"))
785
+ # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
786
+ # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
787
+ self.pushButton_get.setText(_translate("Form", "Get"))
788
+
789
+ if __name__ == "__main__":
790
+ import sys
791
+ app = QtWidgets.QApplication(sys.argv)
792
+ Form = QtWidgets.QWidget()
793
+ ui = Ui_Form()
794
+ ui.setupUi(Form)
795
+ Form.show()
796
+ sys.exit(app.exec_())
797
+
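The camera-control slots referenced above (Form.changeYaw, Form.changePitch, Form.changeTruncation, and so on) are implemented by the demo widget in qt_demo_seg2cat.py, not in this UI file. As a minimal sketch of how an integer slider value in [-100, 100] could be turned into the 25-dimensional pose vector (16 cam2world entries + 9 intrinsics) that the generator consumes, assuming the LookAtPoseSampler conventions used by applications/generate_video.py; slider_to_pose and its scaling factors are hypothetical, not part of the repo:

import torch
from camera_utils import LookAtPoseSampler

def slider_to_pose(G, yaw_slider, pitch_slider, device='cuda'):
    # Hypothetical mapping: rescale the integer slider range [-100, 100] to the
    # yaw/pitch offsets (in radians) that the rendering scripts sweep over.
    yaw = 3.14 / 2 + 0.35 * (yaw_slider / 100.0)
    pitch = 3.14 / 2 - 0.05 + 0.25 * (pitch_slider / 100.0)
    cam2world = LookAtPoseSampler.sample(
        yaw, pitch,
        torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
        radius=G.rendering_kwargs['avg_camera_radius'], device=device)
    focal_length = 4.2647  # the intrinsics used by the seg2cat/seg2face configs
    intrinsics = torch.tensor([[focal_length, 0, 0.5],
                               [0, focal_length, 0.5],
                               [0, 0, 1]], device=device)
    # 16 cam2world entries + 9 intrinsics = the 25-dim pose vector G expects.
    return torch.cat([cam2world.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)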
pix2pix3D-main/pix2pix3D-main/applications/edge2cat.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pix2pix3D-main/pix2pix3D-main/applications/extract_mesh.py ADDED
@@ -0,0 +1,267 @@
import sys
sys.path.append('./')

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm

import legacy
from camera_utils import LookAtPoseSampler

from matplotlib import pyplot as plt

from pathlib import Path

import json

from training.utils import color_mask, color_list

import imageio

import argparse

import trimesh
import pyrender
import mcubes

os.environ["PYOPENGL_PLATFORM"] = "egl"

def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
    try:
        if data_type == 'seg':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        elif data_type == 'edge':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        else:
            raise click.ClickException(f'Unknown data_type: {data_type}')
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    # Return a numpy array of the forwarded sigma (density) values, evaluated
    # block by block over a cube of side `box_warp` centered at the origin.
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    X = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    for xi, xs in enumerate(X):
        for yi, ys in enumerate(X):
            for zi, zs in enumerate(X):
                xx, yy, zz = torch.meshgrid(xs, ys, zs)
                pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device) # B, H, H, H, C
                block_shape = [1, len(xs), len(ys), len(zs)]
                out = nerf.sample_mixed(pts.reshape(1, -1, 3), None, ws=styles, noise_mode='const')
                feat_out, sigma_out = out['rgb'], out['sigma']
                sigma_np[xi * block_resolution: xi * block_resolution + len(xs),
                         yi * block_resolution: yi * block_resolution + len(ys),
                         zi * block_resolution: zi * block_resolution + len(zs)] = sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()

    return sigma_np, bound
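For scale: with the defaults above (resolution=512, block_resolution=64), the density field is sampled at 512^3 ≈ 1.34e8 points, split into 8^3 = 512 blocks of 64^3 = 262,144 query points each, so only one block of points and its outputs are resident on the GPU at a time.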


def extract_geometry(nerf, styles, resolution, threshold):
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([ bound,  bound,  bound])

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces
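The last line of extract_geometry maps marching-cubes vertices from voxel-grid coordinates back to world space: each component v in [0, resolution - 1] becomes v / (resolution - 1) * (b_max - b_min) + b_min, landing in the same [-bound, bound] cube the field was sampled over, so the exported mesh aligns with the NeRF volume.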


def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
    parser.add_argument('--network', help='Path to the network pickle file', required=True)
    parser.add_argument('--outdir', help='Directory to save the output', required=True)

    parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=False)
    parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
    parser.add_argument('--input', help='input label map', required=False)
    parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
    args = parser.parse_args()
    device = 'cuda'

    # Load the network
    with dnnlib.util.open_url(args.network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        neural_rendering_resolution = 128
        pitch_range, yaw_range = 0.25, 0.35
        data_type = 'seg'
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    elif args.cfg == 'edge2car':
        neural_rendering_resolution = 64
        pitch_range, yaw_range = np.pi / 2, np.pi
        data_type = 'edge'

        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 1.7074  # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    else:
        print('Invalid cfg')
        return

    save_dir = Path(args.outdir)

    # Load the input label map
    if args.input is not None:
        input_label = PIL.Image.open(args.input)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = np.array(input_label).astype(np.uint8)
            input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)

            # Save the visualized input label map
            PIL.Image.fromarray(color_mask(input_label[0,0].cpu().numpy()).astype(np.uint8)).save(save_dir / f'{args.cfg}_input.png')
        elif args.cfg == 'edge2car':
            input_label = np.array(input_label).astype(np.float32)[..., 0]
            input_label = -(torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
        input_pose = forward_pose.to(device)

    elif args.input_id is not None:
        if args.cfg == 'seg2cat':
            data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
            mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
        elif args.cfg == 'edge2car':
            data_path = Path(args.data_dir) / 'cars_128.zip'
            mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
        elif args.cfg == 'seg2face':
            data_path = Path(args.data_dir) / 'celebamask_test.zip'
            mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'

        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
        dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
        batch = dataset[args.input_id]

        save_dir = Path(args.outdir)

        # Save the input label map
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
        elif args.cfg == 'edge2car':
            PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')

        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car':
            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)

    # Map the input to a style code, then extract geometry.
    z = torch.from_numpy(np.random.RandomState(int(0)).randn(1, G.z_dim).astype('float32')).to(device)

    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})

    mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=512, threshold=50.))

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        # Color each vertex with the dominant semantic class predicted at its location.
        verts_np = np.array(mesh_trimesh.vertices)
        colors = torch.zeros((verts_np.shape[0], 3), device=device)
        semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
        samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()

        head = 0
        max_batch = 10000000
        with tqdm(total=verts_np.shape[0]) as pbar:
            with torch.no_grad():
                while head < verts_np.shape[0]:
                    torch.manual_seed(0)
                    out = G.sample_mixed(samples_color[:, head:head+max_batch], None, ws, truncation_psi=1, noise_mode='const')
                    colors[head:head+max_batch, :] = out['rgb'][0, :, :3]
                    seg = out['rgb'][0, :, 32:32+6]
                    semantic_colors[head:head+max_batch, :] = seg
                    head += max_batch
                    pbar.update(max_batch)

        semantic_colors = torch.tensor(color_list)[torch.argmax(semantic_colors, dim=-1)]

        mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

        # Save mesh.
        mesh_trimesh.export(os.path.join(save_dir, 'semantic_mesh.ply'))
    elif args.cfg == 'edge2car':
        # Save mesh.
        mesh_trimesh.export(os.path.join(save_dir, f'{args.cfg}_mesh.ply'))

    mesh = pyrender.Mesh.from_trimesh(mesh_trimesh)
    light = pyrender.SpotLight(color=np.ones(3), intensity=3.0,
                               innerConeAngle=np.pi/4)
    r = pyrender.OffscreenRenderer(512, 512)
    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        camera = pyrender.OrthographicCamera(xmag=0.3, ymag=0.3)
    elif args.cfg == 'edge2car':
        camera = pyrender.OrthographicCamera(xmag=0.6, ymag=0.6)

    frames_mesh = []
    num_frames = 120

    for frame_idx in tqdm(range(num_frames)):
        scene = pyrender.Scene()
        scene.add(mesh)

        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            camera_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                   3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                   torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=1, device=device)
        elif args.cfg == 'edge2car':
            camera_pose = LookAtPoseSampler.sample(-3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                   3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                   torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=1.2, device=device)
        # Negate the camera's y and z axes to convert from the generator's
        # camera convention to pyrender's OpenGL convention.
        camera_pose = camera_pose.reshape(4, 4).cpu().numpy().copy()
        camera_pose[:, 1] = -camera_pose[:, 1]
        camera_pose[:, 2] = -camera_pose[:, 2]

        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)
        color, depth = r.render(scene)
        frames_mesh.append(color)

    imageio.mimsave(os.path.join(save_dir, 'rendered_mesh.gif'), frames_mesh, fps=60)
    r.delete()


if __name__ == '__main__':
    main()
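For reference, a typical invocation of this script; the checkpoint file name checkpoints/pix2pix3d_seg2cat.pkl is an assumption (the actual name depends on checkpoints/download_models.sh), and the data zips are expected under data/ as in the branches above:

python applications/extract_mesh.py \
    --network checkpoints/pix2pix3d_seg2cat.pkl \
    --outdir results/ \
    --cfg seg2cat \
    --input_id 1666

This writes the per-vertex-colored semantic_mesh.ply and a rendered_mesh.gif turntable (cf. assets/rendered_mesh_colored.gif) into results/.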
pix2pix3D-main/pix2pix3D-main/applications/generate_samples.py ADDED
@@ -0,0 +1,128 @@
import sys
sys.path.append('./')

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm

import legacy

from matplotlib import pyplot as plt

from pathlib import Path

import json

from training.utils import color_mask, color_list

import argparse

def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
    try:
        if data_type == 'seg':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        elif data_type == 'edge':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        else:
            raise click.ClickException(f'Unknown data_type: {data_type}')
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
    parser.add_argument('--network', help='Path to the network pickle file', required=True)
    parser.add_argument('--outdir', help='Directory to save the output', required=True)
    # A list of random seeds; one sample is generated per seed.
    parser.add_argument('--random_seed', help='Random seed', nargs="+", type=int)

    parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=True)
    parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
    parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
    args = parser.parse_args()
    device = 'cuda'

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        neural_rendering_resolution = 128
        data_type = 'seg'
    elif args.cfg == 'edge2car':
        neural_rendering_resolution = 64
        data_type = 'edge'
    else:
        print('Invalid cfg')
        return

    # Load the network
    with dnnlib.util.open_url(args.network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    # Initialize the dataset that provides the input label map and pose.
    if args.cfg == 'seg2cat':
        data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
        mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
    elif args.cfg == 'edge2car':
        data_path = Path(args.data_dir) / 'cars_128.zip'
        mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
    elif args.cfg == 'seg2face':
        data_path = Path(args.data_dir) / 'celebamask_test.zip'
        mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'

    dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
    dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
    batch = dataset[args.input_id]

    save_dir = Path(args.outdir)

    # Save the input label map
    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
    elif args.cfg == 'edge2car':
        PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')

    # Generate samples
    for seed in args.random_seed:
        z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, G.z_dim).astype('float32')).to(device)
        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car':
            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)

        with torch.no_grad():
            ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
            out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)

        image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            image_label = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)
        elif args.cfg == 'edge2car':
            image_label = ((out['semantic'][0].cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8)[0]

        PIL.Image.fromarray(image_color).save(save_dir / f'{args.cfg}_{args.input_id}_{seed}_color.png')
        PIL.Image.fromarray(image_label).save(save_dir / f'{args.cfg}_{args.input_id}_{seed}_label.png')


if __name__ == '__main__':
    main()
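A sample run, under the same assumption about the checkpoint file name as in the extract_mesh example above; --random_seed accepts a list, producing one color/label pair per seed:

python applications/generate_samples.py \
    --network checkpoints/pix2pix3d_seg2cat.pkl \
    --outdir results/ \
    --cfg seg2cat \
    --input_id 1666 \
    --random_seed 1 2 3

Outputs follow the pattern {cfg}_{input_id}_{seed}_color.png and {cfg}_{input_id}_{seed}_label.png (compare assets/seg2cat_1666_1_color.png and assets/seg2cat_1666_1_label.png).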
pix2pix3D-main/pix2pix3D-main/applications/generate_video.py ADDED
@@ -0,0 +1,220 @@
import sys
sys.path.append('./')

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm

import legacy
from camera_utils import LookAtPoseSampler

from matplotlib import pyplot as plt

from pathlib import Path

import json

from training.utils import color_mask, color_list

import imageio

import argparse

def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
    try:
        if data_type == 'seg':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        elif data_type == 'edge':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        else:
            raise click.ClickException(f'Unknown data_type: {data_type}')
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35, neural_rendering_resolution=128, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                  3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
        image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
        frames.append(image_color)
        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))

    return frames, frames_label
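The loop above sweeps the camera along a closed elliptical path around the frontal view: for frame t of N, yaw(t) = π/2 + yaw_range·sin(2πt/N) and pitch(t) = π/2 − 0.05 + pitch_range·cos(2πt/N), with the defaults pitch_range = 0.25 and yaw_range = 0.35 matching the seg2cat/seg2face configuration in main below.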


def render_video_edge(G, ws, intrinsics, num_frames=120, pitch_range=np.pi / 2, yaw_range=np.pi, neural_rendering_resolution=64, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(-3.14/2 + yaw_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                  3.14/2 - 0.05 + pitch_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(((out['semantic'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)[0])

    return frames, frames_label

# Variant of render_video_edge that starts from the opposite yaw and swaps the
# sin/cos sweep, matching the camera path of render_video.
def render_video_edge2cat(G, ws, intrinsics, num_frames=120, pitch_range=np.pi / 2, yaw_range=np.pi, neural_rendering_resolution=64, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                  3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(((out['semantic'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)[0])

    return frames, frames_label

def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
    parser.add_argument('--network', help='Path to the network pickle file', required=True)
    parser.add_argument('--outdir', help='Directory to save the output', required=True)
    # A list of random seeds; one video is generated per seed.
    parser.add_argument('--random_seed', help='Random seed', nargs="+", type=int)

    parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=False)
    parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
    parser.add_argument('--input', help='input label map', required=False)
    parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
    args = parser.parse_args()
    device = 'cuda'

    # Load the network
    with dnnlib.util.open_url(args.network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face' or args.cfg == 'edge2cat':
        neural_rendering_resolution = 128
        pitch_range, yaw_range = 0.25, 0.35
        data_type = 'seg'
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    elif args.cfg == 'edge2car':
        neural_rendering_resolution = 64
        pitch_range, yaw_range = np.pi / 2, np.pi
        data_type = 'edge'

        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 1.7074  # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    else:
        print('Invalid cfg')
        return

    save_dir = Path(args.outdir)

    # Load the input label map
    if args.input is not None:
        input_label = PIL.Image.open(args.input)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = np.array(input_label).astype(np.uint8)
            input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)

            # Save the visualized input label map
            PIL.Image.fromarray(color_mask(input_label[0,0].cpu().numpy()).astype(np.uint8)).save(save_dir / f'{args.cfg}_input.png')
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            input_label = np.array(input_label).astype(np.float32)
            if input_label.ndim == 3:
                input_label = input_label[:, :, 0]
            print(input_label.min(), input_label.max())
            input_label = (torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
            plt.imshow(input_label.cpu().numpy()[0, 0], cmap='gray')
            plt.savefig(save_dir / f'{args.cfg}_input.png')

        input_pose = forward_pose.to(device)

    elif args.input_id is not None:
        if args.cfg == 'seg2cat':
            data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
            mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
        elif args.cfg == 'edge2car':
            data_path = Path(args.data_dir) / 'cars_128.zip'
            mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
        elif args.cfg == 'seg2face':
            # data_path = Path(args.data_dir) / 'celebamask_test.zip'
            # mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'
            # NOTE: the absolute paths below are machine-specific leftovers from
            # development; point them at your own copies (the commented lines
            # above show the intended data_dir layout).
            data_path = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test.zip'
            mask_data = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test_label.zip'
        elif args.cfg == 'edge2cat':
            data_path = '/data2/datasets/AFHQ_eg3d/afhq_v2_train_cat_512.zip'
            mask_data = '/data2/datasets/AFHQ_eg3d/afhqcat_contour_pidinet.zip'

        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
        dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
        batch = dataset[args.input_id]

        save_dir = Path(args.outdir)

        # Save the input label map
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')

        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)

    # Generate videos
    for seed in args.random_seed:
        z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, G.z_dim).astype('float32')).to(device)

        with torch.no_grad():
            ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})

        # Generate the video
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            frames, frames_label = render_video(G, ws, intrinsics, num_frames=120, pitch_range=pitch_range, yaw_range=yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            frames, frames_label = render_video_edge2cat(G, ws, intrinsics, num_frames=120, pitch_range=pitch_range, yaw_range=yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)

        # Save the video
        imageio.mimsave(save_dir / f'{args.cfg}_{seed}.gif', frames, fps=60)
        imageio.mimsave(save_dir / f'{args.cfg}_{seed}_label.gif', frames_label, fps=60)


if __name__ == '__main__':
    main()
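And a matching video invocation, again using the hypothetical checkpoint path from the earlier examples:

python applications/generate_video.py \
    --network checkpoints/pix2pix3d_seg2cat.pkl \
    --outdir results/ \
    --cfg seg2cat \
    --input_id 1666 \
    --random_seed 1

This saves {cfg}_{seed}.gif and {cfg}_{seed}_label.gif (cf. assets/seg2cat_1.gif and assets/seg2cat_1_label.gif).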
pix2pix3D-main/pix2pix3D-main/assets/demo.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:021005482c8f89c0d62b62709020c7e7e3de05152132d2c1dfa1b9256962fe59
+size 6362125
pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif ADDED

Git LFS Details

  • SHA256: 156ad91c72c8df2deed08706d7dfecdfbb018d1c1991c6a42d21edd682119146
  • Pointer size: 133 Bytes
  • Size of remote file: 18.9 MB
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif ADDED

Git LFS Details

  • SHA256: f2e96e43f1e9efa5aba6b952088595d005c106578e59a9aeb0f91118c03089c4
  • Pointer size: 133 Bytes
  • Size of remote file: 19 MB
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png ADDED

Git LFS Details

  • SHA256: f53bf27480260852a5b7f887ae6eb8eec9d3a8b1ae5368f5c2e31df4c6344b58
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_label.png ADDED
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_input.png ADDED
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif ADDED

Git LFS Details

  • SHA256: 59b169dc13e574ce866e2ec3b9a00f81e91e2b2a7e0a1a7efbcec9075c43c371
  • Pointer size: 131 Bytes
  • Size of remote file: 516 kB
pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif ADDED

Git LFS Details

  • SHA256: 8b65919b403d5f07e224ab34d6be6b54b66c5ebba13979dece5ccfbba2de9fe2
  • Pointer size: 133 Bytes
  • Size of remote file: 18.5 MB
pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg ADDED

Git LFS Details

  • SHA256: a83619219bc66b8e1bb7b5c85d6bd2afd25124a2b98909eb22b06003bddf7837
  • Pointer size: 131 Bytes
  • Size of remote file: 796 kB
pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png ADDED

Git LFS Details

  • SHA256: 0bda9f01a87bcede18cde378f52b7b5da0b8c8021bfcc220b1a87e90d13a836c
  • Pointer size: 132 Bytes
  • Size of remote file: 6.02 MB
pix2pix3D-main/pix2pix3D-main/camera_utils.py ADDED
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+
+from training.volumetric_rendering import math_utils
+
+class GaussianCameraPoseSampler:
+    """
+    Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+    Camera is specified as looking at the origin.
+    If horizontal and vertical stddev (specified in radians) are zero, gives a
+    deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+    The coordinate system is specified with y-up, z-forward, x-left.
+    Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+    vertical mean is the polar angle (angle from the y axis) in radians.
+    A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+    Example:
+    For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+    cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+    """
+    Same as GaussianCameraPoseSampler, except the
+    camera is specified as looking at 'lookat_position', a 3-vector.
+
+    Example:
+    For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+    cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+class UniformCameraPoseSampler:
+    """
+    Same as GaussianCameraPoseSampler, except the
+    pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
+
+    Example:
+    For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+    cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
+        v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+def create_cam2world_matrix(forward_vector, origin):
+    """
+    Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+    Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
+    """
+
+    forward_vector = math_utils.normalize_vecs(forward_vector)
+    up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
+
+    right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
+    up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
+
+    rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+    rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
+
+    translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+    translation_matrix[:, :3, 3] = origin
+    cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+    assert(cam2world.shape[1:] == (4, 4))
+    return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees, device='cpu'):
+    """
+    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+    Assumes principal point is at image center.
+    """
+
+    focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+    return intrinsics
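
A runnable sanity sketch of the helpers above; the radius and FOV values are illustrative assumptions in the EG3D convention, not constants defined in this file:

```python
import math
import torch

from camera_utils import LookAtPoseSampler, FOV_to_intrinsics

# Frontal camera on a sphere of radius 2.7, looking at the origin.
cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.zeros(3), radius=2.7)
origin = cam2world[0, :3, 3]
forward = cam2world[0, :3, 2]  # third rotation column = forward axis
assert torch.isclose(origin.norm(), torch.tensor(2.7), atol=1e-4)
assert torch.allclose(forward, -origin / origin.norm(), atol=1e-4)

# 25-dim camera conditioning in the EG3D-style layout: flattened 4x4 cam2world
# followed by the flattened normalized 3x3 intrinsics.
intrinsics = FOV_to_intrinsics(18.837)
pose = torch.cat([cam2world.reshape(-1, 16), intrinsics.reshape(-1, 9)], dim=1)
assert pose.shape == (1, 25)
```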
pix2pix3D-main/pix2pix3D-main/checkpoints/download_models.sh ADDED
@@ -0,0 +1,5 @@
+cd checkpoints
+wget http://cs.cmu.edu/~pix2pix3D/release.tar
+tar -xvf release.tar
+rm release.tar
+cd ..
pix2pix3D-main/pix2pix3D-main/dnnlib/__init__.py ADDED
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
pix2pix3D-main/pix2pix3D-main/dnnlib/util.py ADDED
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: Union[str, bytes]) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if isinstance(text, bytes):
+            text = text.decode()
+        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+            self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+    global _dnnlib_cache_dir
+    _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+    if _dnnlib_cache_dir is not None:
+        return os.path.join(_dnnlib_cache_dir, *paths)
+    if 'DNNLIB_CACHE_DIR' in os.environ:
+        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+    if 'HOME' in os.environ:
+        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+    if 'USERPROFILE' in os.environ:
+        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+    else:
+        return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+        return True
+    except:
+        return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name) # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+    """Finds the python class with the given name and constructs it with the given arguments."""
+    return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    module = obj.__module__
+    if module == '__main__':
+        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+    return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    if allow_file_urls and obj.startswith('file://'):
+        return True
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert num_attempts >= 1
+    assert not (return_filename and (not cache))
+
+    # Doesn't look like an URL scheme so interpret it as a local filename.
+    if not re.match('^[a-z]+://', url):
+        return url if return_filename else open(url, "rb")
+
+    # Handle file URLs. This code handles unusual file:// patterns that
+    # arise on Windows:
+    #
+    # file:///c:/foo.txt
+    #
+    # which would translate to a local '/c:/foo.txt' filename that's
+    # invalid. Drop the forward slash for such pathnames.
+    #
+    # If you touch this code path, you should test it on both Linux and
+    # Windows.
+    #
+    # Some internet resources suggest using urllib.request.url2pathname(),
+    # but that converts forward slashes to backslashes and this causes
+    # its own set of problems.
+    if url.startswith('file://'):
+        filename = urllib.parse.urlparse(url).path
+        if re.match(r'^/[a-zA-Z]:', filename):
+            filename = filename[1:]
+        return filename if return_filename else open(filename, "rb")
+
+    assert is_url(url)
+
+    # Lookup from cache.
+    if cache_dir is None:
+        cache_dir = make_cache_dir_path('downloads')
+
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            filename = cache_files[0]
+            return filename if return_filename else open(filename, "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except KeyboardInterrupt:
+                raise
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file) # atomic
+        if return_filename:
+            return cache_file
+
+    # Return data as file object.
+    assert not return_filename
+    return io.BytesIO(url_data)
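
A small usage sketch of the name-based construction implemented above: `construct_class_by_name` resolves a dotted path through `get_obj_by_name` and instantiates it, which is how this repo builds datasets from plain kwargs dicts. The class path below is an arbitrary stdlib example, not a repo class:

```python
import dnnlib

# Resolve 'collections.OrderedDict' by name and call it with no arguments.
kwargs = dnnlib.EasyDict(class_name='collections.OrderedDict')
obj = dnnlib.util.construct_class_by_name(**kwargs)  # same as collections.OrderedDict()
assert len(obj) == 0

# EasyDict mixes attribute and item access on the same underlying dict.
d = dnnlib.EasyDict(path='data/afhq_v2_train_cat_512.zip')  # placeholder path
assert d.path == d['path']
```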
pix2pix3D-main/pix2pix3D-main/environment.yml ADDED
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+name: pix2pix3d
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python >= 3.8
+  - pip
+  - numpy>=1.20
+  - click>=8.0
+  - pillow=8.3.1
+  - scipy=1.7.1
+  - pytorch=1.11.0
+  - cudatoolkit=11.1
+  - requests=2.26.0
+  - tqdm=4.62.2
+  - ninja=1.10.2
+  - matplotlib=3.4.2
+  - imageio=2.9.0
+  - pip:
+    - imgui==1.3.0
+    - glfw==2.2.0
+    - pyopengl==3.1.5
+    - imageio-ffmpeg==0.4.3
+    - pyspng
+    - psutil
+    - mrcfile
+    - tensorboard
+    - einops
+    - opencv-python
pix2pix3D-main/pix2pix3D-main/examples/example_input.png ADDED
pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2car.png ADDED
pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2cat.png ADDED
pix2pix3D-main/pix2pix3D-main/legacy.py ADDED
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Converting legacy network pickle into the new format."""
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+    data = _LegacyUnpickler(f).load()
+
+    # Legacy TensorFlow pickle => convert.
+    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+        tf_G, tf_D, tf_Gs = data
+        G = convert_tf_generator(tf_G)
+        D = convert_tf_discriminator(tf_D)
+        G_ema = convert_tf_generator(tf_Gs)
+        data = dict(G=G, D=D, G_ema=G_ema)
+
+    # Add missing fields.
+    if 'training_set_kwargs' not in data:
+        data['training_set_kwargs'] = None
+    if 'augment_pipe' not in data:
+        data['augment_pipe'] = None
+
+    # Validate contents.
+    assert isinstance(data['G'], torch.nn.Module)
+    assert isinstance(data['D'], torch.nn.Module)
+    assert isinstance(data['G_ema'], torch.nn.Module)
+    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+    # Force FP16.
+    if force_fp16:
+        for key in ['G', 'D', 'G_ema']:
+            old = data[key]
+            kwargs = copy.deepcopy(old.init_kwargs)
+            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+            fp16_kwargs.num_fp16_res = 4
+            fp16_kwargs.conv_clamp = 256
+            if kwargs != old.init_kwargs:
+                new = type(old)(**kwargs).eval().requires_grad_(False)
+                misc.copy_params_and_buffers(old, new, require_all=True)
+                data[key] = new
+    return data
+
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+    pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if module == 'dnnlib.tflib.network' and name == 'Network':
+            return _TFNetworkStub
+        return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+    # pylint: disable=protected-access
+    tf_params = dict()
+    def recurse(prefix, tf_net):
+        for name, value in tf_net.variables:
+            tf_params[prefix + name] = value
+        for name, comp in tf_net.components.items():
+            recurse(prefix + name + '/', comp)
+    recurse('', tf_net)
+    return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+    for name, tensor in misc.named_params_and_buffers(module):
+        found = False
+        value = None
+        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+            match = re.fullmatch(pattern, name)
+            if match:
+                found = True
+                if value_fn is not None:
+                    value = value_fn(*match.groups())
+                break
+        try:
+            assert found
+            if value is not None:
+                tensor.copy_(torch.from_numpy(np.array(value)))
+        except:
+            print(name, list(tensor.shape))
+            raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+    if tf_G.version < 4:
+        raise ValueError('TensorFlow pickle version too low')
+
+    # Collect kwargs.
+    tf_kwargs = tf_G.static_kwargs
+    known_kwargs = set()
+    def kwarg(tf_name, default=None, none=None):
+        known_kwargs.add(tf_name)
+        val = tf_kwargs.get(tf_name, default)
+        return val if val is not None else none
+
+    # Convert kwargs.
+    from training import networks_stylegan2
+    network_class = networks_stylegan2.Generator
+    kwargs = dnnlib.EasyDict(
+        z_dim                   = kwarg('latent_size',          512),
+        c_dim                   = kwarg('label_size',           0),
+        w_dim                   = kwarg('dlatent_size',         512),
+        img_resolution          = kwarg('resolution',           1024),
+        img_channels            = kwarg('num_channels',         3),
+        channel_base            = kwarg('fmap_base',            16384) * 2,
+        channel_max             = kwarg('fmap_max',             512),
+        num_fp16_res            = kwarg('num_fp16_res',         0),
+        conv_clamp              = kwarg('conv_clamp',           None),
+        architecture            = kwarg('architecture',         'skip'),
+        resample_filter         = kwarg('resample_kernel',      [1,3,3,1]),
+        use_noise               = kwarg('use_noise',            True),
+        activation              = kwarg('nonlinearity',         'lrelu'),
+        mapping_kwargs = dnnlib.EasyDict(
+            num_layers          = kwarg('mapping_layers',       8),
+            embed_features      = kwarg('label_fmaps',          None),
+            layer_features      = kwarg('mapping_fmaps',        None),
+            activation          = kwarg('mapping_nonlinearity', 'lrelu'),
+            lr_multiplier       = kwarg('mapping_lrmul',        0.01),
+            w_avg_beta          = kwarg('w_avg_beta',           0.995, none=1),
+        ),
+    )
+
+    # Check for unknown kwargs.
+    kwarg('truncation_psi')
+    kwarg('truncation_cutoff')
+    kwarg('style_mixing_prob')
+    kwarg('structure')
+    kwarg('conditioning')
+    kwarg('fused_modconv')
+    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+    if len(unknown_kwargs) > 0:
+        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+    # Collect params.
+    tf_params = _collect_tf_params(tf_G)
+    for name, value in list(tf_params.items()):
+        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+        if match:
+            r = kwargs.img_resolution // (2 ** int(match.group(1)))
+            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+            kwargs.synthesis.kwargs.architecture = 'orig'
+    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+    # Convert params.
+    G = network_class(**kwargs).eval().requires_grad_(False)
+    # pylint: disable=unnecessary-lambda
+    # pylint: disable=f-string-without-interpolation
+    _populate_module_params(G,
+        r'mapping\.w_avg',                                  lambda:     tf_params[f'dlatent_avg'],
+        r'mapping\.embed\.weight',                          lambda:     tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+        r'mapping\.embed\.bias',                            lambda:     tf_params[f'mapping/LabelEmbed/bias'],
+        r'mapping\.fc(\d+)\.weight',                        lambda i:   tf_params[f'mapping/Dense{i}/weight'].transpose(),
+        r'mapping\.fc(\d+)\.bias',                          lambda i:   tf_params[f'mapping/Dense{i}/bias'],
+        r'synthesis\.b4\.const',                            lambda:     tf_params[f'synthesis/4x4/Const/const'][0],
+        r'synthesis\.b4\.conv1\.weight',                    lambda:     tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+        r'synthesis\.b4\.conv1\.bias',                      lambda:     tf_params[f'synthesis/4x4/Conv/bias'],
+        r'synthesis\.b4\.conv1\.noise_const',               lambda:     tf_params[f'synthesis/noise0'][0, 0],
+        r'synthesis\.b4\.conv1\.noise_strength',            lambda:     tf_params[f'synthesis/4x4/Conv/noise_strength'],
+        r'synthesis\.b4\.conv1\.affine\.weight',            lambda:     tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+        r'synthesis\.b4\.conv1\.affine\.bias',              lambda:     tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+        r'synthesis\.b(\d+)\.conv0\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+        r'synthesis\.b(\d+)\.conv0\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+        r'synthesis\.b(\d+)\.conv0\.noise_const',           lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+        r'synthesis\.b(\d+)\.conv0\.noise_strength',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+        r'synthesis\.b(\d+)\.conv0\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+        r'synthesis\.b(\d+)\.conv0\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+        r'synthesis\.b(\d+)\.conv1\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+        r'synthesis\.b(\d+)\.conv1\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+        r'synthesis\.b(\d+)\.conv1\.noise_const',           lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+        r'synthesis\.b(\d+)\.conv1\.noise_strength',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+        r'synthesis\.b(\d+)\.conv1\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+        r'synthesis\.b(\d+)\.conv1\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+        r'synthesis\.b(\d+)\.torgb\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+        r'synthesis\.b(\d+)\.torgb\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+        r'synthesis\.b(\d+)\.torgb\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+        r'synthesis\.b(\d+)\.torgb\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+        r'synthesis\.b(\d+)\.skip\.weight',                 lambda r:   tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+        r'.*\.resample_filter',                             None,
+        r'.*\.act_filter',                                  None,
+    )
+    return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+    if tf_D.version < 4:
+        raise ValueError('TensorFlow pickle version too low')
+
+    # Collect kwargs.
+    tf_kwargs = tf_D.static_kwargs
+    known_kwargs = set()
+    def kwarg(tf_name, default=None):
+        known_kwargs.add(tf_name)
+        return tf_kwargs.get(tf_name, default)
+
+    # Convert kwargs.
+    kwargs = dnnlib.EasyDict(
+        c_dim                   = kwarg('label_size',           0),
+        img_resolution          = kwarg('resolution',           1024),
+        img_channels            = kwarg('num_channels',         3),
+        architecture            = kwarg('architecture',         'resnet'),
+        channel_base            = kwarg('fmap_base',            16384) * 2,
+        channel_max             = kwarg('fmap_max',             512),
+        num_fp16_res            = kwarg('num_fp16_res',         0),
+        conv_clamp              = kwarg('conv_clamp',           None),
+        cmap_dim                = kwarg('mapping_fmaps',        None),
+        block_kwargs = dnnlib.EasyDict(
+            activation          = kwarg('nonlinearity',         'lrelu'),
+            resample_filter     = kwarg('resample_kernel',      [1,3,3,1]),
+            freeze_layers       = kwarg('freeze_layers',        0),
+        ),
+        mapping_kwargs = dnnlib.EasyDict(
+            num_layers          = kwarg('mapping_layers',       0),
+            embed_features      = kwarg('mapping_fmaps',        None),
+            layer_features      = kwarg('mapping_fmaps',        None),
+            activation          = kwarg('nonlinearity',         'lrelu'),
+            lr_multiplier       = kwarg('mapping_lrmul',        0.1),
+        ),
+        epilogue_kwargs = dnnlib.EasyDict(
+            mbstd_group_size    = kwarg('mbstd_group_size',     None),
+            mbstd_num_channels  = kwarg('mbstd_num_features',   1),
+            activation          = kwarg('nonlinearity',         'lrelu'),
+        ),
+    )
+
+    # Check for unknown kwargs.
+    kwarg('structure')
+    kwarg('conditioning')
+    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+    if len(unknown_kwargs) > 0:
+        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+    # Collect params.
+    tf_params = _collect_tf_params(tf_D)
+    for name, value in list(tf_params.items()):
+        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+        if match:
+            r = kwargs.img_resolution // (2 ** int(match.group(1)))
+            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+            kwargs.architecture = 'orig'
+    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+    # Convert params.
+    from training import networks_stylegan2
+    D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
+    # pylint: disable=unnecessary-lambda
+    # pylint: disable=f-string-without-interpolation
+    _populate_module_params(D,
+        r'b(\d+)\.fromrgb\.weight',     lambda r:       tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+        r'b(\d+)\.fromrgb\.bias',       lambda r:       tf_params[f'{r}x{r}/FromRGB/bias'],
+        r'b(\d+)\.conv(\d+)\.weight',   lambda r, i:    tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+        r'b(\d+)\.conv(\d+)\.bias',     lambda r, i:    tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+        r'b(\d+)\.skip\.weight',        lambda r:       tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+        r'mapping\.embed\.weight',      lambda:         tf_params[f'LabelEmbed/weight'].transpose(),
+        r'mapping\.embed\.bias',        lambda:         tf_params[f'LabelEmbed/bias'],
+        r'mapping\.fc(\d+)\.weight',    lambda i:       tf_params[f'Mapping{i}/weight'].transpose(),
+        r'mapping\.fc(\d+)\.bias',      lambda i:       tf_params[f'Mapping{i}/bias'],
+        r'b4\.conv\.weight',            lambda:         tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+        r'b4\.conv\.bias',              lambda:         tf_params[f'4x4/Conv/bias'],
+        r'b4\.fc\.weight',              lambda:         tf_params[f'4x4/Dense0/weight'].transpose(),
+        r'b4\.fc\.bias',                lambda:         tf_params[f'4x4/Dense0/bias'],
+        r'b4\.out\.weight',             lambda:         tf_params[f'Output/weight'].transpose(),
+        r'b4\.out\.bias',               lambda:         tf_params[f'Output/bias'],
+        r'.*\.resample_filter',         None,
+    )
+    return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+    """Convert legacy network pickle into the native PyTorch format.
+
+    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+    Example:
+
+    \b
+    python legacy.py \\
+        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+        --dest=stylegan2-cat-config-f.pkl
+    """
+    print(f'Loading "{source}"...')
+    with dnnlib.util.open_url(source) as f:
+        data = load_network_pkl(f, force_fp16=force_fp16)
+    print(f'Saving "{dest}"...')
+    with open(dest, 'wb') as f:
+        pickle.dump(data, f)
+    print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
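
A minimal sketch of the load path the scripts in this upload use: `dnnlib.util.open_url` accepts a local path or URL, and `load_network_pkl` returns a dict holding `G`, `D`, and the EMA generator. The checkpoint filename below is an assumption; see checkpoints/download_models.sh for the actual archive:

```python
import torch

import dnnlib
import legacy

# Assumed checkpoint filename from the downloaded release archive.
with dnnlib.util.open_url('checkpoints/pix2pix3d_seg2cat.pkl') as f:
    data = legacy.load_network_pkl(f)
G = data['G_ema'].eval().requires_grad_(False).to(torch.device('cuda'))
```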
pix2pix3D-main/pix2pix3D-main/metrics/__init__.py ADDED
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
pix2pix3D-main/pix2pix3D-main/metrics/equivariance.py ADDED
@@ -0,0 +1,269 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ # property and proprietary rights in and to this material, related
6
+ # documentation and any modifications thereto. Any use, reproduction,
7
+ # disclosure or distribution of this material and related documentation
8
+ # without an express license agreement from NVIDIA CORPORATION or
9
+ # its affiliates is strictly prohibited.
10
+
11
+ """Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
12
+ "Alias-Free Generative Adversarial Networks"."""
13
+
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import torch.fft
18
+ from torch_utils.ops import upfirdn2d
19
+ from . import metric_utils
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Utilities.
23
+
24
+ def sinc(x):
25
+ y = (x * np.pi).abs()
26
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
27
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
28
+
29
+ def lanczos_window(x, a):
30
+ x = x.abs() / a
31
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
32
+
33
+ def rotation_matrix(angle):
34
+ angle = torch.as_tensor(angle).to(torch.float32)
35
+ mat = torch.eye(3, device=angle.device)
36
+ mat[0, 0] = angle.cos()
37
+ mat[0, 1] = angle.sin()
38
+ mat[1, 0] = -angle.sin()
39
+ mat[1, 1] = angle.cos()
40
+ return mat
41
+
42
+ #----------------------------------------------------------------------------
43
+ # Apply integer translation to a batch of 2D images. Corresponds to the
44
+ # operator T_x in Appendix E.1.
45
+
46
+ def apply_integer_translation(x, tx, ty):
47
+ _N, _C, H, W = x.shape
48
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
49
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
50
+ ix = tx.round().to(torch.int64)
51
+ iy = ty.round().to(torch.int64)
52
+
53
+ z = torch.zeros_like(x)
54
+ m = torch.zeros_like(x)
55
+ if abs(ix) < W and abs(iy) < H:
56
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
57
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
58
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
59
+ return z, m
60
+
61
+ #----------------------------------------------------------------------------
62
+ # Apply integer translation to a batch of 2D images. Corresponds to the
63
+ # operator T_x in Appendix E.2.
64
+
65
+ def apply_fractional_translation(x, tx, ty, a=3):
66
+ _N, _C, H, W = x.shape
67
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
68
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
69
+ ix = tx.floor().to(torch.int64)
70
+ iy = ty.floor().to(torch.int64)
71
+ fx = tx - ix
72
+ fy = ty - iy
73
+ b = a - 1
74
+
75
+ z = torch.zeros_like(x)
76
+ zx0 = max(ix - b, 0)
77
+ zy0 = max(iy - b, 0)
78
+ zx1 = min(ix + a, 0) + W
79
+ zy1 = min(iy + a, 0) + H
80
+ if zx0 < zx1 and zy0 < zy1:
81
+ taps = torch.arange(a * 2, device=x.device) - b
82
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
83
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
84
+ y = x
85
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
86
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
87
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
88
+ z[:, :, zy0:zy1, zx0:zx1] = y
89
+
90
+ m = torch.zeros_like(x)
91
+ mx0 = max(ix + a, 0)
92
+ my0 = max(iy + a, 0)
93
+ mx1 = min(ix - b, 0) + W
94
+ my1 = min(iy - b, 0) + H
95
+ if mx0 < mx1 and my0 < my1:
96
+ m[:, :, my0:my1, mx0:mx1] = 1
97
+ return z, m
98
+
99
+ #----------------------------------------------------------------------------
100
+ # Construct an oriented low-pass filter that applies the appropriate
101
+ # bandlimit with respect to the input and output of the given affine 2D
102
+ # image transformation.
103
+
104
+ def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
105
+ assert a <= amax < aflt
106
+ mat = torch.as_tensor(mat).to(torch.float32)
107
+
108
+ # Construct 2D filter taps in input & output coordinate spaces.
109
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
110
+ yi, xi = torch.meshgrid(taps, taps)
111
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
112
+
113
+ # Convolution of two oriented 2D sinc filters.
114
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
115
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
116
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
117
+
118
+ # Convolution of two oriented 2D Lanczos windows.
119
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
120
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
121
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
122
+
123
+ # Construct windowed FIR filter.
124
+ f = f * w
125
+
126
+ # Finalize.
127
+ c = (aflt - amax) * up
128
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
129
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
130
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
131
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
132
+ return f
133
+
134
+ #----------------------------------------------------------------------------
135
+ # Apply the given affine transformation to a batch of 2D images.
136
+
137
+ def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
138
+ _N, _C, H, W = x.shape
139
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
140
+
141
+ # Construct filter.
142
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
143
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
144
+ p = f.shape[0] // 2
145
+
146
+ # Construct sampling grid.
147
+ theta = mat.inverse()
148
+ theta[:2, 2] *= 2
149
+ theta[0, 2] += 1 / up / W
150
+ theta[1, 2] += 1 / up / H
151
+ theta[0, :] *= W / (W + p / up * 2)
152
+ theta[1, :] *= H / (H + p / up * 2)
153
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
154
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
155
+
156
+ # Resample image.
157
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
158
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
159
+
160
+ # Form mask.
161
+ m = torch.zeros_like(y)
162
+ c = p * 2 + 1
163
+ m[:, :, c:-c, c:-c] = 1
164
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
165
+ return z, m
166
+
167
+ #----------------------------------------------------------------------------
168
+ # Apply fractional rotation to a batch of 2D images. Corresponds to the
169
+ # operator R_\alpha in Appendix E.3.
170
+
171
+ def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
172
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
173
+ mat = rotation_matrix(angle)
174
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
175
+
176
+ #----------------------------------------------------------------------------
177
+ # Modify the frequency content of a batch of 2D images as if they had undergo
178
+ # fractional rotation -- but without actually rotating them. Corresponds to
179
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+    mat = rotation_matrix(-angle)
+    f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+    y = upfirdn2d.filter2d(x=x, f=f)
+    m = torch.zeros_like(y)
+    c = f.shape[0] // 2
+    m[:, :, c:-c, c:-c] = 1
+    return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+    assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+    # Setup generator and labels.
+    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+    I = torch.eye(3, device=opts.device)
+    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+    if M is None:
+        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+    # Sampling loop.
+    sums = None
+    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+        progress.update(batch_start)
+        s = []
+
+        # Randomize noise buffers, if any.
+        for name, buf in G.named_buffers():
+            if name.endswith('.noise_const'):
+                buf.copy_(torch.randn_like(buf))
+
+        # Run mapping network.
+        z = torch.randn([batch_size, G.z_dim], device=opts.device)
+        c = next(c_iter)
+        ws = G.mapping(z=z, c=c)
+
+        # Generate reference image.
+        M[:] = I
+        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+        # Integer translation (EQ-T).
+        if compute_eqt_int:
+            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+            t = (t * G.img_resolution).round() / G.img_resolution
+            M[:] = I
+            M[:2, 2] = -t
+            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+            ref, mask = apply_integer_translation(orig, t[0], t[1])
+            s += [(ref - img).square() * mask, mask]
+
+        # Fractional translation (EQ-T_frac).
+        if compute_eqt_frac:
+            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+            M[:] = I
+            M[:2, 2] = -t
+            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+            ref, mask = apply_fractional_translation(orig, t[0], t[1])
+            s += [(ref - img).square() * mask, mask]
+
+        # Rotation (EQ-R).
+        if compute_eqr:
+            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+            M[:] = rotation_matrix(-angle)
+            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+            ref, ref_mask = apply_fractional_rotation(orig, angle)
+            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+            mask = ref_mask * pseudo_mask
+            s += [(ref - pseudo).square() * mask, mask]
+
+        # Accumulate results.
+        s = torch.stack([x.to(torch.float64).sum() for x in s])
+        sums = sums + s if sums is not None else s
+    progress.update(num_samples)
+
+    # Compute PSNRs.
+    if opts.num_gpus > 1:
+        torch.distributed.all_reduce(sums)
+    sums = sums.cpu()
+    mses = sums[0::2] / sums[1::2]
+    psnrs = np.log10(2) * 20 - mses.log10() * 10
+    psnrs = tuple(psnrs.numpy())
+    return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
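For reference, the PSNR conversion at the end assumes images in [-1, 1], i.e. a peak-to-peak range of 2, so PSNR = 10*log10(2^2 / MSE) = 20*log10(2) - 10*log10(MSE) -- exactly the `np.log10(2) * 20 - mses.log10() * 10` line. A minimal numpy check of that identity (the MSE value is illustrative):

    import numpy as np

    mse = np.float64(1e-3)                    # masked mean squared error
    psnr_direct = 10 * np.log10(2**2 / mse)   # 10*log10(peak^2 / MSE) with peak-to-peak = 2
    psnr_as_coded = np.log10(2) * 20 - np.log10(mse) * 10
    assert np.isclose(psnr_direct, psnr_as_coded)   # both ~= 36.02 dB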
pix2pix3D-main/pix2pix3D-main/metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+    if opts.rank != 0:
+        return float('nan')
+
+    m = np.square(mu_gen - mu_real).sum()
+    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+    return float(fid)
+
+#----------------------------------------------------------------------------
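The last four lines of compute_fid are the closed-form Frechet distance between two Gaussians, d^2 = ||mu_g - mu_r||^2 + Tr(Sigma_g + Sigma_r - 2*(Sigma_g Sigma_r)^(1/2)). A self-contained toy check on random features (shapes are illustrative; the real detector emits 2048-dim features):

    import numpy as np
    import scipy.linalg

    rng = np.random.default_rng(0)
    feat_real = rng.normal(size=(500, 8))          # stand-in for real-image features
    feat_gen = rng.normal(size=(500, 8)) + 0.1     # stand-in for generated-image features

    mu_r, sigma_r = feat_real.mean(0), np.cov(feat_real, rowvar=False)
    mu_g, sigma_g = feat_gen.mean(0), np.cov(feat_gen, rowvar=False)

    m = np.square(mu_g - mu_r).sum()
    s, _ = scipy.linalg.sqrtm(sigma_g @ sigma_r, disp=False)
    fid = np.real(m + np.trace(sigma_g + sigma_r - s * 2))
    print(f'toy FID: {fid:.4f}')                   # 0 only if the two Gaussians coincide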
pix2pix3D-main/pix2pix3D-main/metrics/inception_score.py ADDED
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+    gen_probs = metric_utils.compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        capture_all=True, max_items=num_gen).get_all()
+
+    if opts.rank != 0:
+        return float('nan'), float('nan')
+
+    scores = []
+    for i in range(num_splits):
+        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+        kl = np.mean(np.sum(kl, axis=1))
+        scores.append(np.exp(kl))
+    return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
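The split loop implements IS = exp(E_x[KL(p(y|x) || p(y))]), with the marginal p(y) estimated as the mean prediction within each split. A toy version on random softmax outputs (sizes illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    probs = rng.dirichlet(np.ones(10), size=1000)    # stand-in for p(y|x) rows

    marginal = np.mean(probs, axis=0, keepdims=True) # p(y)
    kl = probs * (np.log(probs) - np.log(marginal))  # pointwise KL integrand
    score = np.exp(np.mean(np.sum(kl, axis=1)))      # exp(E_x KL(p(y|x) || p(y)))
    print(f'toy IS: {score:.3f}')                    # close to 1 for near-uniform predictions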
pix2pix3D-main/pix2pix3D-main/metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,48 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+    real_features = metric_utils.compute_feature_stats_for_dataset(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+    gen_features = metric_utils.compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+    if opts.rank != 0:
+        return float('nan')
+
+    n = real_features.shape[1]
+    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+    t = 0
+    for _subset_idx in range(num_subsets):
+        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+        b = (x @ y.T / n + 1) ** 3
+        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+    kid = t / num_subsets / m
+    return float(kid)
+
+#----------------------------------------------------------------------------
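Each subset iteration evaluates the unbiased block estimate of squared MMD with the cubic polynomial kernel k(x, y) = (x.y/d + 1)^3, where d is the feature dimension. The same arithmetic on toy features (sizes illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(100, 16))         # one subset of generated features
    y = rng.normal(size=(100, 16)) + 0.2   # one subset of real features
    m, n = x.shape[0], x.shape[1]

    a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3   # within-set kernel values
    b = (x @ y.T / n + 1) ** 3                            # cross-set kernel values
    t = (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    print(f'toy KID subset estimate: {t / m:.6f}')        # near 0 when the sets match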
pix2pix3D-main/pix2pix3D-main/metrics/metric_main.py ADDED
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+    assert callable(fn)
+    _metric_dict[fn.__name__] = fn
+    return fn
+
+def is_valid_metric(metric):
+    return metric in _metric_dict
+
+def list_valid_metrics():
+    return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+    assert is_valid_metric(metric)
+    opts = metric_utils.MetricOptions(**kwargs)
+
+    # Calculate.
+    start_time = time.time()
+    results = _metric_dict[metric](opts)
+    total_time = time.time() - start_time
+
+    # Broadcast results.
+    for key, value in list(results.items()):
+        if opts.num_gpus > 1:
+            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+            torch.distributed.broadcast(tensor=value, src=0)
+            value = float(value.cpu())
+        results[key] = value
+
+    # Decorate with metadata.
+    return dnnlib.EasyDict(
+        results         = dnnlib.EasyDict(results),
+        metric          = metric,
+        total_time      = total_time,
+        total_time_str  = dnnlib.util.format_time(total_time),
+        num_gpus        = opts.num_gpus,
+    )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+    metric = result_dict['metric']
+    assert is_valid_metric(metric)
+    if run_dir is not None and snapshot_pkl is not None:
+        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+    jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+    print(jsonl_line)
+    if run_dir is not None and os.path.isdir(run_dir):
+        with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+            f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+    opts.dataset_kwargs.update(max_size=None, xflip=False)
+    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+    return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+    opts.dataset_kwargs.update(max_size=None, xflip=False)
+    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+    return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+    opts.dataset_kwargs.update(max_size=None, xflip=False)
+    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+    return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+    opts.G_kwargs.update(force_fp32=True)
+    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+    return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+    opts.G_kwargs.update(force_fp32=True)
+    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+    return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+    opts.G_kwargs.update(force_fp32=True)
+    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+    return dict(eqr50k=psnr)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+    opts.dataset_kwargs.update(max_size=None)
+    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+    return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+    opts.dataset_kwargs.update(max_size=None)
+    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+    return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+    opts.dataset_kwargs.update(max_size=None)
+    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+    return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+    opts.dataset_kwargs.update(max_size=None, xflip=False)
+    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+    return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
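A sketch of how this registry is typically driven from a script. Everything below the imports is a placeholder (the checkpoint path, dataset class name, and kwargs depend on how the repo's training code defines them):

    import torch
    import dnnlib
    import legacy
    from metrics import metric_main

    with dnnlib.util.open_url('checkpoints/network-snapshot.pkl') as f:   # placeholder path
        G = legacy.load_network_pkl(f)['G_ema'].cuda()

    result = metric_main.calc_metric(
        metric='fid50k_full', G=G, num_gpus=1, rank=0, device=torch.device('cuda'),
        dataset_kwargs=dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',  # placeholder class
                                       path='data/afhq_cat.zip'))                         # placeholder data
    metric_main.report_metric(result, run_dir='out', snapshot_pkl='network-snapshot.pkl')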
pix2pix3D-main/pix2pix3D-main/metrics/metric_utils.py ADDED
@@ -0,0 +1,281 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+        assert 0 <= rank < num_gpus
+        self.G              = G
+        self.G_kwargs       = dnnlib.EasyDict(G_kwargs)
+        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+        self.num_gpus       = num_gpus
+        self.rank           = rank
+        self.device         = device if device is not None else torch.device('cuda', rank)
+        self.progress       = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+        self.cache          = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+    return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+    assert 0 <= rank < num_gpus
+    key = (url, device)
+    if key not in _feature_detector_cache:
+        is_leader = (rank == 0)
+        if not is_leader and num_gpus > 1:
+            torch.distributed.barrier() # leader goes first
+        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+            _feature_detector_cache[key] = pickle.load(f).to(device)
+        if is_leader and num_gpus > 1:
+            torch.distributed.barrier() # others follow
+    return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+    if opts.G.c_dim == 0:
+        c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+        while True:
+            yield c
+    else:
+        dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+        while True:
+            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+            yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+        self.capture_all = capture_all
+        self.capture_mean_cov = capture_mean_cov
+        self.max_items = max_items
+        self.num_items = 0
+        self.num_features = None
+        self.all_features = None
+        self.raw_mean = None
+        self.raw_cov = None
+
+    def set_num_features(self, num_features):
+        if self.num_features is not None:
+            assert num_features == self.num_features
+        else:
+            self.num_features = num_features
+            self.all_features = []
+            self.raw_mean = np.zeros([num_features], dtype=np.float64)
+            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+    def is_full(self):
+        return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+    def append(self, x):
+        x = np.asarray(x, dtype=np.float32)
+        assert x.ndim == 2
+        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+            if self.num_items >= self.max_items:
+                return
+            x = x[:self.max_items - self.num_items]
+
+        self.set_num_features(x.shape[1])
+        self.num_items += x.shape[0]
+        if self.capture_all:
+            self.all_features.append(x)
+        if self.capture_mean_cov:
+            x64 = x.astype(np.float64)
+            self.raw_mean += x64.sum(axis=0)
+            self.raw_cov += x64.T @ x64
+
+    def append_torch(self, x, num_gpus=1, rank=0):
+        assert isinstance(x, torch.Tensor) and x.ndim == 2
+        assert 0 <= rank < num_gpus
+        if num_gpus > 1:
+            ys = []
+            for src in range(num_gpus):
+                y = x.clone()
+                torch.distributed.broadcast(y, src=src)
+                ys.append(y)
+            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+        self.append(x.cpu().numpy())
+
+    def get_all(self):
+        assert self.capture_all
+        return np.concatenate(self.all_features, axis=0)
+
+    def get_all_torch(self):
+        return torch.from_numpy(self.get_all())
+
+    def get_mean_cov(self):
+        assert self.capture_mean_cov
+        mean = self.raw_mean / self.num_items
+        cov = self.raw_cov / self.num_items
+        cov = cov - np.outer(mean, mean)
+        return mean, cov
+
+    def save(self, pkl_file):
+        with open(pkl_file, 'wb') as f:
+            pickle.dump(self.__dict__, f)
+
+    @staticmethod
+    def load(pkl_file):
+        with open(pkl_file, 'rb') as f:
+            s = dnnlib.EasyDict(pickle.load(f))
+        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+        obj.__dict__.update(s)
+        return obj
+
+#----------------------------------------------------------------------------
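FeatureStats can stream either raw features (capture_all) or a running mean/covariance (capture_mean_cov) without holding every batch in memory. A standalone sketch using random features as stand-ins:

    import numpy as np
    # FeatureStats as defined above (metrics.metric_utils in this repo).

    stats = FeatureStats(capture_mean_cov=True, max_items=1000)
    while not stats.is_full():
        stats.append(np.random.randn(64, 8).astype(np.float32))  # batches of 8-dim features
    mean, cov = stats.get_mean_cov()   # float64 running mean and covariance
    print(mean.shape, cov.shape)       # (8,) (8, 8)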
+
+class ProgressMonitor:
+    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+        self.tag = tag
+        self.num_items = num_items
+        self.verbose = verbose
+        self.flush_interval = flush_interval
+        self.progress_fn = progress_fn
+        self.pfn_lo = pfn_lo
+        self.pfn_hi = pfn_hi
+        self.pfn_total = pfn_total
+        self.start_time = time.time()
+        self.batch_time = self.start_time
+        self.batch_items = 0
+        if self.progress_fn is not None:
+            self.progress_fn(self.pfn_lo, self.pfn_total)
+
+    def update(self, cur_items):
+        assert (self.num_items is None) or (cur_items <= self.num_items)
+        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+            return
+        cur_time = time.time()
+        total_time = cur_time - self.start_time
+        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+        if (self.verbose) and (self.tag is not None):
+            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+        self.batch_time = cur_time
+        self.batch_items = cur_items
+
+        if (self.progress_fn is not None) and (self.num_items is not None):
+            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+        return ProgressMonitor(
+            tag             = tag,
+            num_items       = num_items,
+            flush_interval  = flush_interval,
+            verbose         = self.verbose,
+            progress_fn     = self.progress_fn,
+            pfn_lo          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+            pfn_hi          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+            pfn_total       = self.pfn_total,
+        )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+    if data_loader_kwargs is None:
+        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+    # Try to lookup from cache.
+    cache_file = None
+    if opts.cache:
+        # Choose cache file name.
+        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+        # Check if the file exists (all processes must agree).
+        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+        if opts.num_gpus > 1:
+            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+            torch.distributed.broadcast(tensor=flag, src=0)
+            flag = (float(flag.cpu()) != 0)
+
+        # Load.
+        if flag:
+            return FeatureStats.load(cache_file)
+
+    # Initialize.
+    num_items = len(dataset)
+    if max_items is not None:
+        num_items = min(num_items, max_items)
+    stats = FeatureStats(max_items=num_items, **stats_kwargs)
+    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+    # Main loop.
+    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+        if images.shape[1] == 1:
+            images = images.repeat([1, 3, 1, 1])
+        features = detector(images.to(opts.device), **detector_kwargs)
+        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+        progress.update(stats.num_items)
+
+    # Save to cache.
+    if cache_file is not None and opts.rank == 0:
+        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+        temp_file = cache_file + '.' + uuid.uuid4().hex
+        stats.save(temp_file)
+        os.replace(temp_file, cache_file) # atomic
+    return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
+    if batch_gen is None:
+        batch_gen = min(batch_size, 4)
+    assert batch_size % batch_gen == 0
+
+    # Setup generator and labels.
+    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+    c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+    # Initialize.
+    stats = FeatureStats(**stats_kwargs)
+    assert stats.max_items is not None
+    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+    # Main loop.
+    while not stats.is_full():
+        images = []
+        for _i in range(batch_size // batch_gen):
+            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+            img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
+            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+            images.append(img)
+        images = torch.cat(images)
+        if images.shape[1] == 1:
+            images = images.repeat([1, 3, 1, 1])
+        features = detector(images, **detector_kwargs)
+        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+        progress.update(stats.num_items)
+    return stats
+
+#----------------------------------------------------------------------------
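Note how append_torch merges multi-GPU batches: every rank broadcasts its batch, and torch.stack(..., dim=1).flatten(0, 1) interleaves one sample per rank, so the final feature order is independent of the GPU count. A CPU-only illustration of that stacking trick:

    import torch

    rank0 = torch.tensor([[0.], [2.], [4.]])   # batch gathered from rank 0
    rank1 = torch.tensor([[1.], [3.], [5.]])   # batch gathered from rank 1
    merged = torch.stack([rank0, rank1], dim=1).flatten(0, 1)
    print(merged.squeeze(1).tolist())          # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]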
pix2pix3D-main/pix2pix3D-main/metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+    a = a / a.norm(dim=-1, keepdim=True)
+    b = b / b.norm(dim=-1, keepdim=True)
+    d = (a * b).sum(dim=-1, keepdim=True)
+    p = t * torch.acos(d)
+    c = b - d * a
+    c = c / c.norm(dim=-1, keepdim=True)
+    d = a * torch.cos(p) + c * torch.sin(p)
+    d = d / d.norm(dim=-1, keepdim=True)
+    return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+        assert space in ['z', 'w']
+        assert sampling in ['full', 'end']
+        super().__init__()
+        self.G = copy.deepcopy(G)
+        self.G_kwargs = G_kwargs
+        self.epsilon = epsilon
+        self.space = space
+        self.sampling = sampling
+        self.crop = crop
+        self.vgg16 = copy.deepcopy(vgg16)
+
+    def forward(self, c):
+        # Generate random latents and interpolation t-values.
+        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+        # Interpolate in W or Z.
+        if self.space == 'w':
+            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+        else: # space == 'z'
+            zt0 = slerp(z0, z1, t.unsqueeze(1))
+            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+        # Randomize noise buffers.
+        for name, buf in self.G.named_buffers():
+            if name.endswith('.noise_const'):
+                buf.copy_(torch.randn_like(buf))
+
+        # Generate images.
+        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+        # Center crop.
+        if self.crop:
+            assert img.shape[2] == img.shape[3]
+            c = img.shape[2] // 8
+            img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+        # Downsample to 256x256.
+        factor = self.G.img_resolution // 256
+        if factor > 1:
+            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+        # Scale dynamic range from [-1,1] to [0,255].
+        img = (img + 1) * (255 / 2)
+        if self.G.img_channels == 1:
+            img = img.repeat([1, 3, 1, 1])
+
+        # Evaluate differential LPIPS.
+        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+        return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+    # Setup sampler and labels.
+    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+    sampler.eval().requires_grad_(False).to(opts.device)
+    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+    # Sampling loop.
+    dist = []
+    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+        progress.update(batch_start)
+        x = sampler(next(c_iter))
+        for src in range(opts.num_gpus):
+            y = x.clone()
+            if opts.num_gpus > 1:
+                torch.distributed.broadcast(y, src=src)
+            dist.append(y)
+    progress.update(num_samples)
+
+    # Compute PPL.
+    if opts.rank != 0:
+        return float('nan')
+    dist = torch.cat(dist)[:num_samples].cpu().numpy()
+    lo = np.percentile(dist, 1, interpolation='lower')
+    hi = np.percentile(dist, 99, interpolation='higher')
+    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+    return float(ppl)
+
+#----------------------------------------------------------------------------
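slerp interpolates along the great circle between the normalized endpoints, so intermediate latents keep unit norm and the t=0.5 point bisects the angle between a and b (a plain lerp would cut through the sphere). A quick check, assuming the slerp defined in the file above is importable:

    import torch
    # from metrics.perceptual_path_length import slerp

    def unit(v):
        return v / v.norm(dim=-1, keepdim=True)

    a, b = torch.randn(4, 512), torch.randn(4, 512)
    mid = slerp(a, b, torch.full([4, 1], 0.5))
    print(mid.norm(dim=-1))                                  # all ~1: stays on the unit sphere
    print(torch.allclose((unit(a) * mid).sum(-1),
                         (mid * unit(b)).sum(-1), atol=1e-5))  # True: equal angles to both ends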
pix2pix3D-main/pix2pix3D-main/metrics/precision_recall.py ADDED
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+    assert 0 <= rank < num_gpus
+    num_cols = col_features.shape[0]
+    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+    col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+    dist_batches = []
+    for col_batch in col_batches[rank :: num_gpus]:
+        dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+        for src in range(num_gpus):
+            dist_broadcast = dist_batch.clone()
+            if num_gpus > 1:
+                torch.distributed.broadcast(dist_broadcast, src=src)
+            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+    detector_kwargs = dict(return_features=True)
+
+    real_features = metric_utils.compute_feature_stats_for_dataset(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+    gen_features = metric_utils.compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+    results = dict()
+    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+        kth = []
+        for manifold_batch in manifold.split(row_batch_size):
+            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+        kth = torch.cat(kth) if opts.rank == 0 else None
+        pred = []
+        for probes_batch in probes.split(row_batch_size):
+            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+    return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
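The decision rule behind both numbers: a probe counts as covered if it lies within the k-th-nearest-neighbour radius of any point of the reference manifold (precision uses real features as the manifold, recall swaps the roles). The same rule on toy 2-D clouds, single-GPU and without batching:

    import torch

    real = torch.randn(200, 2)           # reference manifold
    fake = torch.randn(100, 2) * 1.5     # probes, deliberately more spread out
    nhood_size = 3

    d_rr = torch.cdist(real, real)
    kth = d_rr.kthvalue(nhood_size + 1, dim=1).values   # k-NN radius per real point (+1 skips self)
    d_fr = torch.cdist(fake, real)
    precision = (d_fr <= kth.unsqueeze(0)).any(dim=1).float().mean()
    print(f'toy precision: {precision:.2f}')            # fraction of probes inside some k-NN ball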
pix2pix3D-main/pix2pix3D-main/torch_utils/__init__.py ADDED
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
pix2pix3D-main/pix2pix3D-main/torch_utils/custom_ops.py ADDED
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+    name = torch.cuda.get_device_name().lower()
+    out = []
+    for c in name:
+        if re.match('[a-z0-9_-]+', c):
+            out.append(c)
+        else:
+            out.append('-')
+    return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+    if headers is None:
+        headers = []
+    if source_dir is not None:
+        sources = [os.path.join(source_dir, fname) for fname in sources]
+        headers = [os.path.join(source_dir, fname) for fname in headers]
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+    verbose_build = (verbosity == 'full')
+
+    # Compile and load.
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+        # break the build or unnecessarily restrict what's available to nvcc.
+        # Unset it to let nvcc decide based on what's available on the
+        # machine.
+        os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+        # Incremental build md5sum trickery. Copies all the input source files
+        # into a cached build directory under a combined md5 digest of the input
+        # source files. Copying is done only if the combined digest has changed.
+        # This keeps input file timestamps and filenames the same as in previous
+        # extension builds, allowing for fast incremental rebuilds.
+        #
+        # This optimization is done only in case all the source files reside in
+        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+        # environment variable is set (we take this as a signal that the user
+        # actually cares about this.)
+        #
+        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
+        # around the *.cu dependency bug in ninja config.
+        #
+        all_source_files = sorted(sources + headers)
+        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+            # Compute combined hash digest for all source files.
+            hash_md5 = hashlib.md5()
+            for src in all_source_files:
+                with open(src, 'rb') as f:
+                    hash_md5.update(f.read())
+
+            # Select cached build directory name.
+            source_digest = hash_md5.hexdigest()
+            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+            if not os.path.isdir(cached_build_dir):
+                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+                os.makedirs(tmpdir)
+                for src in all_source_files:
+                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+                try:
+                    os.replace(tmpdir, cached_build_dir) # atomic
+                except OSError:
+                    # source directory already exists, delete tmpdir and its contents.
+                    shutil.rmtree(tmpdir)
+                    if not os.path.isdir(cached_build_dir): raise
+
+            # Compile.
+            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+                verbose=verbose_build, sources=cached_sources, **build_kwargs)
+        else:
+            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+        # Load.
+        module = importlib.import_module(module_name)
+
+    except:
+        if verbosity == 'brief':
+            print('Failed!')
+        raise
+
+    # Print status and add to cache dict.
+    if verbosity == 'full':
+        print(f'Done setting up PyTorch plugin "{module_name}".')
+    elif verbosity == 'brief':
+        print('Done.')
+    _cached_plugins[module_name] = module
+    return module
+
+#----------------------------------------------------------------------------
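For context, this loader is what compiles the bias_act CUDA kernel added later in this upload. A sketch of a typical call, modeled on how torch_utils/ops/bias_act.py invokes it (treat the exact kwargs as illustrative):

    import os
    from torch_utils import custom_ops

    plugin = custom_ops.get_plugin(
        module_name='bias_act_plugin',
        sources=['bias_act.cpp', 'bias_act.cu'],
        headers=['bias_act.h'],
        source_dir=os.path.dirname(__file__),   # folder holding the .cpp/.cu/.h sources
        extra_cuda_cflags=['--use_fast_math'],  # forwarded to torch.utils.cpp_extension.load()
    )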
pix2pix3D-main/pix2pix3D-main/torch_utils/misc.py ADDED
@@ -0,0 +1,280 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ # property and proprietary rights in and to this material, related
6
+ # documentation and any modifications thereto. Any use, reproduction,
7
+ # disclosure or distribution of this material and related documentation
8
+ # without an express license agreement from NVIDIA CORPORATION or
9
+ # its affiliates is strictly prohibited.
10
+
11
+ import re
12
+ import contextlib
13
+ import numpy as np
14
+ import torch
15
+ import warnings
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
20
+ # same constant is used multiple times.
21
+
22
+ _constant_cache = dict()
23
+
24
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
25
+ value = np.asarray(value)
26
+ if shape is not None:
27
+ shape = tuple(shape)
28
+ if dtype is None:
29
+ dtype = torch.get_default_dtype()
30
+ if device is None:
31
+ device = torch.device('cpu')
32
+ if memory_format is None:
33
+ memory_format = torch.contiguous_format
34
+
35
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
36
+ tensor = _constant_cache.get(key, None)
37
+ if tensor is None:
38
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
39
+ if shape is not None:
40
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
41
+ tensor = tensor.contiguous(memory_format=memory_format)
42
+ _constant_cache[key] = tensor
43
+ return tensor
44
+
45
+ #----------------------------------------------------------------------------
46
+ # Replace NaN/Inf with specified numerical values.
47
+
48
+ try:
49
+ nan_to_num = torch.nan_to_num # 1.8.0a0
50
+ except AttributeError:
51
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
52
+ assert isinstance(input, torch.Tensor)
53
+ if posinf is None:
54
+ posinf = torch.finfo(input.dtype).max
55
+ if neginf is None:
56
+ neginf = torch.finfo(input.dtype).min
57
+ assert nan == 0
58
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59
+
60
+ #----------------------------------------------------------------------------
61
+ # Symbolic assert.
62
+
63
+ try:
64
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65
+ except AttributeError:
66
+ symbolic_assert = torch.Assert # 1.7.0
67
+
68
+ #----------------------------------------------------------------------------
69
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
70
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
71
+
72
+ @contextlib.contextmanager
73
+ def suppress_tracer_warnings():
74
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
75
+ warnings.filters.insert(0, flt)
76
+ yield
77
+ warnings.filters.remove(flt)
78
+
79
+ #----------------------------------------------------------------------------
80
+ # Assert that the shape of a tensor matches the given list of integers.
81
+ # None indicates that the size of a dimension is allowed to vary.
82
+ # Performs symbolic assertion when used in torch.jit.trace().
83
+
84
+ def assert_shape(tensor, ref_shape):
85
+ if tensor.ndim != len(ref_shape):
86
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
87
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
88
+ if ref_size is None:
89
+ pass
90
+ elif isinstance(ref_size, torch.Tensor):
91
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
92
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
93
+ elif isinstance(size, torch.Tensor):
94
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
95
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
96
+ elif size != ref_size:
97
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
98
+
99
+ #----------------------------------------------------------------------------
100
+ # Function decorator that calls torch.autograd.profiler.record_function().
101
+
102
+ def profiled_function(fn):
103
+ def decorator(*args, **kwargs):
104
+ with torch.autograd.profiler.record_function(fn.__name__):
105
+ return fn(*args, **kwargs)
106
+ decorator.__name__ = fn.__name__
107
+ return decorator
108
+
109
+ #----------------------------------------------------------------------------
110
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
111
+ # indefinitely, shuffling items as it goes.
112
+
113
+ class InfiniteSampler(torch.utils.data.Sampler):
114
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
115
+ assert len(dataset) > 0
116
+ assert num_replicas > 0
117
+ assert 0 <= rank < num_replicas
118
+ assert 0 <= window_size <= 1
119
+ super().__init__(dataset)
120
+ self.dataset = dataset
121
+ self.rank = rank
122
+ self.num_replicas = num_replicas
123
+ self.shuffle = shuffle
124
+ self.seed = seed
125
+ self.window_size = window_size
126
+
127
+ def __iter__(self):
128
+ order = np.arange(len(self.dataset))
129
+ rnd = None
130
+ window = 0
131
+ if self.shuffle:
132
+ rnd = np.random.RandomState(self.seed)
133
+ rnd.shuffle(order)
134
+ window = int(np.rint(order.size * self.window_size))
135
+
136
+ idx = 0
137
+ while True:
138
+ i = idx % order.size
139
+ if idx % self.num_replicas == self.rank:
140
+ yield order[i]
141
+ if window >= 2:
142
+ j = (i - rnd.randint(window)) % order.size
143
+ order[i], order[j] = order[j], order[i]
144
+ idx += 1
145
+
146
+ #----------------------------------------------------------------------------
147
+ # Utilities for operating with torch.nn.Module parameters and buffers.
148
+
149
+ def params_and_buffers(module):
150
+ assert isinstance(module, torch.nn.Module)
151
+ return list(module.parameters()) + list(module.buffers())
152
+
153
+ def named_params_and_buffers(module):
154
+ assert isinstance(module, torch.nn.Module)
155
+ return list(module.named_parameters()) + list(module.named_buffers())
156
+
157
+ def copy_params_and_buffers(src_module, dst_module, require_all=False, allow_mismatch=False):
+     assert isinstance(src_module, torch.nn.Module)
+     assert isinstance(dst_module, torch.nn.Module)
+     src_tensors = dict(named_params_and_buffers(src_module))
+     for name, tensor in named_params_and_buffers(dst_module):
+         assert (name in src_tensors) or (not require_all)
+         if name not in src_tensors and name.replace('_semantic', '') not in src_tensors:
+             print(f'Warning: {name} not found in source module')
+             continue
+         if name not in src_tensors:
+             # print(f'Warning: {name} not found in source module, using {name.replace("_semantic", "")}')
+             name_src = name.replace('_semantic', '')
+         else:
+             name_src = name
+         if src_tensors[name_src].shape != tensor.shape and allow_mismatch:
+             print(f'Warning: {name_src} shape mismatch: {src_tensors[name_src].shape} vs {tensor.shape}')
+             continue
+         if name_src in src_tensors:
+             tensor.copy_(src_tensors[name_src].detach()).requires_grad_(tensor.requires_grad)
+             # print(f'Copied {name}')
+
+ #----------------------------------------------------------------------------
+ # Context manager for easily enabling/disabling DistributedDataParallel
+ # synchronization.
+
+ @contextlib.contextmanager
+ def ddp_sync(module, sync):
+     assert isinstance(module, torch.nn.Module)
+     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+         yield
+     else:
+         with module.no_sync():
+             yield
+
+ #----------------------------------------------------------------------------
+ # Check DistributedDataParallel consistency across processes.
+
+ def check_ddp_consistency(module, ignore_regex=None):
+     assert isinstance(module, torch.nn.Module)
+     for name, tensor in named_params_and_buffers(module):
+         fullname = type(module).__name__ + '.' + name
+         if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+             continue
+         tensor = tensor.detach()
+         if tensor.is_floating_point():
+             tensor = nan_to_num(tensor)
+         other = tensor.clone()
+         torch.distributed.broadcast(tensor=other, src=0)
+         assert (tensor == other).all(), fullname
+
+ #----------------------------------------------------------------------------
+ # Print summary table of module hierarchy.
+
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+     assert isinstance(module, torch.nn.Module)
+     assert not isinstance(module, torch.jit.ScriptModule)
+     assert isinstance(inputs, (tuple, list))
+
+     # Register hooks.
+     entries = []
+     nesting = [0]
+     def pre_hook(_mod, _inputs):
+         nesting[0] += 1
+     def post_hook(mod, _inputs, outputs):
+         nesting[0] -= 1
+         if nesting[0] <= max_nesting:
+             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+             outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+     # Run module.
+     outputs = module(*inputs)
+     for hook in hooks:
+         hook.remove()
+
+     # Identify unique outputs, parameters, and buffers.
+     tensors_seen = set()
+     for e in entries:
+         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+     # Filter out redundant entries.
+     if skip_redundant:
+         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+     # Construct table.
+     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+     rows += [['---'] * len(rows[0])]
+     param_total = 0
+     buffer_total = 0
+     submodule_names = {mod: name for name, mod in module.named_modules()}
+     for e in entries:
+         name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+         param_size = sum(t.numel() for t in e.unique_params)
+         buffer_size = sum(t.numel() for t in e.unique_buffers)
+         output_shapes = [str(list(t.shape)) for t in e.outputs]
+         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+         rows += [[
+             name + (':0' if len(e.outputs) >= 2 else ''),
+             str(param_size) if param_size else '-',
+             str(buffer_size) if buffer_size else '-',
+             (output_shapes + ['-'])[0],
+             (output_dtypes + ['-'])[0],
+         ]]
+         for idx in range(1, len(e.outputs)):
+             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+         param_total += param_size
+         buffer_total += buffer_size
+     rows += [['---'] * len(rows[0])]
+     rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+     # Print table.
+     widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+     print()
+     for row in rows:
+         print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+     print()
+     return outputs
+
+ #----------------------------------------------------------------------------
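For quick orientation, here is a minimal usage sketch of the helpers above. This is a hedged illustration, not part of the upload: the toy `torch.nn` modules are hypothetical, and it assumes `torch_utils` is importable from the repo root.

# Illustrative sketch only: exercise copy_params_and_buffers() and
# print_module_summary() on toy modules (not from the repository).
import torch
from torch_utils import misc

src = torch.nn.Linear(4, 4)
dst = torch.nn.Linear(4, 4)
misc.copy_params_and_buffers(src, dst, require_all=True)  # dst now mirrors src

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
misc.print_module_summary(model, [torch.zeros([1, 4])])   # per-module parameter/output table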
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ # empty
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,103 @@
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+  *
+  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+  * property and proprietary rights in and to this material, related
+  * documentation and any modifications thereto. Any use, reproduction,
+  * disclosure or distribution of this material and related documentation
+  * without an express license agreement from NVIDIA CORPORATION or
+  * its affiliates is strictly prohibited.
+  */
+
+ #include <torch/extension.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include "bias_act.h"
+
+ //------------------------------------------------------------------------
+
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+ {
+     if (x.dim() != y.dim())
+         return false;
+     for (int64_t i = 0; i < x.dim(); i++)
+     {
+         if (x.size(i) != y.size(i))
+             return false;
+         if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+             return false;
+     }
+     return true;
+ }
+
+ //------------------------------------------------------------------------
+
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+ {
+     // Validate arguments.
+     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+     TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+     TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+     TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+     TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+     TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+     TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+     TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+     // Validate layout.
+     TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+     TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+     TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+     TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+     TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+     // Create output tensor.
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+     torch::Tensor y = torch::empty_like(x);
+     TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+     // Initialize CUDA kernel parameters.
+     bias_act_kernel_params p;
+     p.x     = x.data_ptr();
+     p.b     = (b.numel()) ? b.data_ptr() : NULL;
+     p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+     p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+     p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+     p.y     = y.data_ptr();
+     p.grad  = grad;
+     p.act   = act;
+     p.alpha = alpha;
+     p.gain  = gain;
+     p.clamp = clamp;
+     p.sizeX = (int)x.numel();
+     p.sizeB = (int)b.numel();
+     p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+     // Choose CUDA kernel.
+     void* kernel;
+     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+     {
+         kernel = choose_bias_act_kernel<scalar_t>(p);
+     });
+     TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+     // Launch CUDA kernel.
+     p.loopX = 4;
+     int blockSize = 4 * 32;
+     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+     void* args[] = {&p};
+     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+     return y;
+ }
+
+ //------------------------------------------------------------------------
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+ {
+     m.def("bias_act", &bias_act);
+ }
+
+ //------------------------------------------------------------------------
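The pybind11 binding above is loaded from Python (see `bias_act.py` later in this upload) as `bias_act_plugin`, with all eleven arguments passed positionally. A hedged sketch of a raw forward call (grad=0), assuming a CUDA device and a successfully built extension; this mirrors what `BiasActCuda.forward` does and is illustrative only:

# Sketch only: call the compiled op directly, mirroring BiasActCuda.forward.
import torch
from torch_utils.ops import bias_act as ba

ba._init()  # compiles/loads bias_act_plugin on first use
x = torch.randn([8, 16], device='cuda')
b = torch.randn([16], device='cuda')
null = torch.empty([0], device='cuda')  # stands in for xref/yref/dy
# Arguments: x, b, xref, yref, dy, grad=0, dim=1, act=3 (lrelu), alpha, gain, clamp=-1 (off)
y = ba._plugin.bias_act(x, b, null, null, null, 0, 1, 3, 0.2, 2.0 ** 0.5, -1.0)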
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,177 @@
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+  *
+  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+  * property and proprietary rights in and to this material, related
+  * documentation and any modifications thereto. Any use, reproduction,
+  * disclosure or distribution of this material and related documentation
+  * without an express license agreement from NVIDIA CORPORATION or
+  * its affiliates is strictly prohibited.
+  */
+
+ #include <c10/util/Half.h>
+ #include "bias_act.h"
+
+ //------------------------------------------------------------------------
+ // Helpers.
+
+ template <class T> struct InternalType;
+ template <> struct InternalType<double>    { typedef double scalar_t; };
+ template <> struct InternalType<float>     { typedef float  scalar_t; };
+ template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+ //------------------------------------------------------------------------
+ // CUDA kernel.
+
+ template <class T, int A>
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
+ {
+     typedef typename InternalType<T>::scalar_t scalar_t;
+     int G = p.grad;
+     scalar_t alpha = (scalar_t)p.alpha;
+     scalar_t gain = (scalar_t)p.gain;
+     scalar_t clamp = (scalar_t)p.clamp;
+     scalar_t one = (scalar_t)1;
+     scalar_t two = (scalar_t)2;
+     scalar_t expRange = (scalar_t)80;
+     scalar_t halfExpRange = (scalar_t)40;
+     scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+     scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+     // Loop over elements.
+     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+     {
+         // Load.
+         scalar_t x = (scalar_t)((const T*)p.x)[xi];
+         scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+         scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+         scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+         scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+         scalar_t yy = (gain != 0) ? yref / gain : 0;
+         scalar_t y = 0;
+
+         // Apply bias.
+         ((G == 0) ? x : xref) += b;
+
+         // linear
+         if (A == 1)
+         {
+             if (G == 0) y = x;
+             if (G == 1) y = x;
+         }
+
+         // relu
+         if (A == 2)
+         {
+             if (G == 0) y = (x > 0) ? x : 0;
+             if (G == 1) y = (yy > 0) ? x : 0;
+         }
+
+         // lrelu
+         if (A == 3)
+         {
+             if (G == 0) y = (x > 0) ? x : x * alpha;
+             if (G == 1) y = (yy > 0) ? x : x * alpha;
+         }
+
+         // tanh
+         if (A == 4)
+         {
+             if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+             if (G == 1) y = x * (one - yy * yy);
+             if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+         }
+
+         // sigmoid
+         if (A == 5)
+         {
+             if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+             if (G == 1) y = x * yy * (one - yy);
+             if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+         }
+
+         // elu
+         if (A == 6)
+         {
+             if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+             if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+         }
+
+         // selu
+         if (A == 7)
+         {
+             if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+             if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+         }
+
+         // softplus
+         if (A == 8)
+         {
+             if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+             if (G == 1) y = x * (one - exp(-yy));
+             if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+         }
+
+         // swish
+         if (A == 9)
+         {
+             if (G == 0)
+                 y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+             else
+             {
+                 scalar_t c = exp(xref);
+                 scalar_t d = c + one;
+                 if (G == 1)
+                     y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+                 else
+                     y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+                 yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+             }
+         }
+
+         // Apply gain.
+         y *= gain * dy;
+
+         // Clamp.
+         if (clamp >= 0)
+         {
+             if (G == 0)
+                 y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+             else
+                 y = (yref > -clamp & yref < clamp) ? y : 0;
+         }
+
+         // Store.
+         ((T*)p.y)[xi] = (T)y;
+     }
+ }
+
+ //------------------------------------------------------------------------
+ // CUDA kernel selection.
+
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+ {
+     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+     return NULL;
+ }
+
+ //------------------------------------------------------------------------
+ // Template specializations.
+
+ template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+ template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+ //------------------------------------------------------------------------
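In the kernel above, the template parameter `A` selects the activation and `G` selects what is computed: `G=0` the forward value, `G=1` the first-derivative term scaled by `dy`, `G=2` the term needed for second-order gradients. A hedged sanity check that the CUDA path agrees with the reference path (assumes a CUDA device and a built extension; illustrative only):

# Sketch only: compare custom kernel output against the reference implementation.
import torch
from torch_utils.ops.bias_act import bias_act

x = torch.randn([4, 8], device='cuda', dtype=torch.float64)
b = torch.randn([8], device='cuda', dtype=torch.float64)
y_cuda = bias_act(x, b, act='lrelu', impl='cuda')
y_ref = bias_act(x, b, act='lrelu', impl='ref')
assert torch.allclose(y_cuda, y_ref, atol=1e-6)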
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,42 @@
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+  *
+  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+  * property and proprietary rights in and to this material, related
+  * documentation and any modifications thereto. Any use, reproduction,
+  * disclosure or distribution of this material and related documentation
+  * without an express license agreement from NVIDIA CORPORATION or
+  * its affiliates is strictly prohibited.
+  */
+
+ //------------------------------------------------------------------------
+ // CUDA kernel parameters.
+
+ struct bias_act_kernel_params
+ {
+     const void* x;      // [sizeX]
+     const void* b;      // [sizeB] or NULL
+     const void* xref;   // [sizeX] or NULL
+     const void* yref;   // [sizeX] or NULL
+     const void* dy;     // [sizeX] or NULL
+     void*       y;      // [sizeX]
+
+     int   grad;
+     int   act;
+     float alpha;
+     float gain;
+     float clamp;
+
+     int sizeX;
+     int sizeB;
+     int stepB;
+     int loopX;
+ };
+
+ //------------------------------------------------------------------------
+ // CUDA kernel selection.
+
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+ //------------------------------------------------------------------------
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,211 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ """Custom PyTorch ops for efficient bias and activation."""
+
+ import os
+ import numpy as np
+ import torch
+ import dnnlib
+
+ from .. import custom_ops
+ from .. import misc
+
+ #----------------------------------------------------------------------------
+
+ activation_funcs = {
+     'linear':   dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+     'relu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+     'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+     'tanh':     dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+     'sigmoid':  dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+     'elu':      dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+     'selu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+     'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+     'swish':    dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+ }
+
+ #----------------------------------------------------------------------------
+
+ _plugin = None
+ _null_tensor = torch.empty([0])
+
+ def _init():
+     global _plugin
+     if _plugin is None:
+         _plugin = custom_ops.get_plugin(
+             module_name='bias_act_plugin',
+             sources=['bias_act.cpp', 'bias_act.cu'],
+             headers=['bias_act.h'],
+             source_dir=os.path.dirname(__file__),
+             extra_cuda_cflags=['--use_fast_math'],
+         )
+     return True
+
+ #----------------------------------------------------------------------------
+
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+     r"""Fused bias and activation function.
+
+     Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+     and scales the result by `gain`. Each of the steps is optional. In most cases,
+     the fused op is considerably more efficient than performing the same calculation
+     using standard PyTorch ops. It supports first and second order gradients,
+     but not third order gradients.
+
+     Args:
+         x:      Input activation tensor. Can be of any shape.
+         b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+                 as `x`. The shape must be known, and it must match the dimension of `x`
+                 corresponding to `dim`.
+         dim:    The dimension in `x` corresponding to the elements of `b`.
+                 The value of `dim` is ignored if `b` is not specified.
+         act:    Name of the activation function to evaluate, or `"linear"` to disable.
+                 Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+                 See `activation_funcs` for a full list. `None` is not allowed.
+         alpha:  Shape parameter for the activation function, or `None` to use the default.
+         gain:   Scaling factor for the output tensor, or `None` to use default.
+                 See `activation_funcs` for the default scaling of each activation function.
+                 If unsure, consider specifying 1.
+         clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+                 the clamping (default).
+         impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+     Returns:
+         Tensor of the same shape and datatype as `x`.
+     """
+     assert isinstance(x, torch.Tensor)
+     assert impl in ['ref', 'cuda']
+     if impl == 'cuda' and x.device.type == 'cuda' and _init():
+         return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+     return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+ #----------------------------------------------------------------------------
+
+ @misc.profiled_function
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+     """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+     """
+     assert isinstance(x, torch.Tensor)
+     assert clamp is None or clamp >= 0
+     spec = activation_funcs[act]
+     alpha = float(alpha if alpha is not None else spec.def_alpha)
+     gain = float(gain if gain is not None else spec.def_gain)
+     clamp = float(clamp if clamp is not None else -1)
+
+     # Add bias.
+     if b is not None:
+         assert isinstance(b, torch.Tensor) and b.ndim == 1
+         assert 0 <= dim < x.ndim
+         assert b.shape[0] == x.shape[dim]
+         x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+     # Evaluate activation function.
+     alpha = float(alpha)
+     x = spec.func(x, alpha=alpha)
+
+     # Scale by gain.
+     gain = float(gain)
+     if gain != 1:
+         x = x * gain
+
+     # Clamp.
+     if clamp >= 0:
+         x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+     return x
+
+ #----------------------------------------------------------------------------
+
+ _bias_act_cuda_cache = dict()
+
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+     """Fast CUDA implementation of `bias_act()` using custom ops.
+     """
+     # Parse arguments.
+     assert clamp is None or clamp >= 0
+     spec = activation_funcs[act]
+     alpha = float(alpha if alpha is not None else spec.def_alpha)
+     gain = float(gain if gain is not None else spec.def_gain)
+     clamp = float(clamp if clamp is not None else -1)
+
+     # Lookup from cache.
+     key = (dim, act, alpha, gain, clamp)
+     if key in _bias_act_cuda_cache:
+         return _bias_act_cuda_cache[key]
+
+     # Forward op.
+     class BiasActCuda(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, x, b): # pylint: disable=arguments-differ
+             ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+             x = x.contiguous(memory_format=ctx.memory_format)
+             b = b.contiguous() if b is not None else _null_tensor
+             y = x
+             if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+                 y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+             ctx.save_for_backward(
+                 x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+                 b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+                 y if 'y' in spec.ref else _null_tensor)
+             return y
+
+         @staticmethod
+         def backward(ctx, dy): # pylint: disable=arguments-differ
+             dy = dy.contiguous(memory_format=ctx.memory_format)
+             x, b, y = ctx.saved_tensors
+             dx = None
+             db = None
+
+             if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+                 dx = dy
+                 if act != 'linear' or gain != 1 or clamp >= 0:
+                     dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+             if ctx.needs_input_grad[1]:
+                 db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+             return dx, db
+
+     # Backward op.
+     class BiasActCudaGrad(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+             ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+             dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+             ctx.save_for_backward(
+                 dy if spec.has_2nd_grad else _null_tensor,
+                 x, b, y)
+             return dx
+
+         @staticmethod
+         def backward(ctx, d_dx): # pylint: disable=arguments-differ
+             d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+             dy, x, b, y = ctx.saved_tensors
+             d_dy = None
+             d_x = None
+             d_b = None
+             d_y = None
+
+             if ctx.needs_input_grad[0]:
+                 d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+             if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+                 d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+             if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+                 d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+             return d_dy, d_x, d_b, d_y
+
+     # Add to cache.
+     _bias_act_cuda_cache[key] = BiasActCuda
+     return BiasActCuda
+
+ #----------------------------------------------------------------------------
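A minimal example of the public wrapper above; `impl='ref'` sidesteps the compiled extension, so this runs on CPU. This is a hedged illustration, not repository code:

# Sketch only: fused bias + leaky ReLU with gain and clamping, reference path.
import numpy as np
import torch
from torch_utils.ops.bias_act import bias_act

x = torch.randn([2, 512])
b = torch.zeros([512])
y = bias_act(x, b, dim=1, act='lrelu', gain=np.sqrt(2), clamp=256, impl='ref')
print(y.shape)  # torch.Size([2, 512])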
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,199 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
+ arbitrarily high order gradients with zero performance penalty."""
+
+ import contextlib
+ import torch
+
+ # pylint: disable=redefined-builtin
+ # pylint: disable=arguments-differ
+ # pylint: disable=protected-access
+
+ #----------------------------------------------------------------------------
+
+ enabled = False  # Enable the custom op by setting this to true.
+ weight_gradients_disabled = False  # Forcefully disable computation of gradients with respect to the weights.
+
+ @contextlib.contextmanager
+ def no_weight_gradients(disable=True):
+     global weight_gradients_disabled
+     old = weight_gradients_disabled
+     if disable:
+         weight_gradients_disabled = True
+     yield
+     weight_gradients_disabled = old
+
+ #----------------------------------------------------------------------------
+
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+     if _should_use_custom_op(input):
+         return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+     return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+     if _should_use_custom_op(input):
+         return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+     return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+ #----------------------------------------------------------------------------
+
+ def _should_use_custom_op(input):
+     assert isinstance(input, torch.Tensor)
+     if (not enabled) or (not torch.backends.cudnn.enabled):
+         return False
+     if input.device.type != 'cuda':
+         return False
+     return True
+
+ def _tuple_of_ints(xs, ndim):
+     xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+     assert len(xs) == ndim
+     assert all(isinstance(x, int) for x in xs)
+     return xs
+
+ #----------------------------------------------------------------------------
+
+ _conv2d_gradfix_cache = dict()
+ _null_tensor = torch.empty([0])
+
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+     # Parse arguments.
+     ndim = 2
+     weight_shape = tuple(weight_shape)
+     stride = _tuple_of_ints(stride, ndim)
+     padding = _tuple_of_ints(padding, ndim)
+     output_padding = _tuple_of_ints(output_padding, ndim)
+     dilation = _tuple_of_ints(dilation, ndim)
+
+     # Lookup from cache.
+     key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+     if key in _conv2d_gradfix_cache:
+         return _conv2d_gradfix_cache[key]
+
+     # Validate arguments.
+     assert groups >= 1
+     assert len(weight_shape) == ndim + 2
+     assert all(stride[i] >= 1 for i in range(ndim))
+     assert all(padding[i] >= 0 for i in range(ndim))
+     assert all(dilation[i] >= 0 for i in range(ndim))
+     if not transpose:
+         assert all(output_padding[i] == 0 for i in range(ndim))
+     else: # transpose
+         assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+     # Helpers.
+     common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+     def calc_output_padding(input_shape, output_shape):
+         if transpose:
+             return [0, 0]
+         return [
+             input_shape[i + 2]
+             - (output_shape[i + 2] - 1) * stride[i]
+             - (1 - 2 * padding[i])
+             - dilation[i] * (weight_shape[i + 2] - 1)
+             for i in range(ndim)
+         ]
+
+     # Forward & backward.
+     class Conv2d(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, input, weight, bias):
+             assert weight.shape == weight_shape
+             ctx.save_for_backward(
+                 input if weight.requires_grad else _null_tensor,
+                 weight if input.requires_grad else _null_tensor,
+             )
+             ctx.input_shape = input.shape
+
+             # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+             if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+                 a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+                 b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+                 c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+                 c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+                 c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+                 return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+             # General case => cuDNN.
+             if transpose:
+                 return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+             return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+         @staticmethod
+         def backward(ctx, grad_output):
+             input, weight = ctx.saved_tensors
+             input_shape = ctx.input_shape
+             grad_input = None
+             grad_weight = None
+             grad_bias = None
+
+             if ctx.needs_input_grad[0]:
+                 p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+                 op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+                 grad_input = op.apply(grad_output, weight, None)
+                 assert grad_input.shape == input_shape
+
+             if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+                 grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
+                 assert grad_weight.shape == weight_shape
+
+             if ctx.needs_input_grad[2]:
+                 grad_bias = grad_output.sum([0, 2, 3])
+
+             return grad_input, grad_weight, grad_bias
+
+     # Gradient with respect to the weights.
+     class Conv2dGradWeight(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, grad_output, input, weight):
+             ctx.save_for_backward(
+                 grad_output if input.requires_grad else _null_tensor,
+                 input if grad_output.requires_grad else _null_tensor,
+             )
+             ctx.grad_output_shape = grad_output.shape
+             ctx.input_shape = input.shape
+
+             # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+             if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+                 a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+                 b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+                 c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+                 return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+             # General case => cuDNN.
+             return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
+
+         @staticmethod
+         def backward(ctx, grad2_grad_weight):
+             grad_output, input = ctx.saved_tensors
+             grad_output_shape = ctx.grad_output_shape
+             input_shape = ctx.input_shape
+             grad2_grad_output = None
+             grad2_input = None
+
+             if ctx.needs_input_grad[0]:
+                 grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+                 assert grad2_grad_output.shape == grad_output_shape
+
+             if ctx.needs_input_grad[1]:
+                 p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+                 op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+                 grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+                 assert grad2_input.shape == input_shape
+
+             return grad2_grad_output, grad2_input
+
+     _conv2d_gradfix_cache[key] = Conv2d
+     return Conv2d
+
+ #----------------------------------------------------------------------------
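The usual motivation for this module is regularizers (e.g. R1) that differentiate through the gradient of a convolution, which requires a clean double backward. A hedged sketch of that pattern; on CPU, or with `enabled` left False, the wrapper simply falls back to `torch.nn.functional.conv2d`:

# Sketch only: second-order gradients through conv2d_gradfix.
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # opt in; the custom op only engages on CUDA
x = torch.randn([1, 3, 16, 16], requires_grad=True)
w = torch.randn([8, 3, 3, 3], requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
g, = torch.autograd.grad(y.sum(), x, create_graph=True)  # first-order grad w.r.t. x
gg, = torch.autograd.grad(g.square().sum(), w)           # second-order grad w.r.t. w
print(gg.shape)  # torch.Size([8, 3, 3, 3])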