Upload 88 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +8 -0
- pix2pix3D-main/pix2pix3D-main/.gitignore +146 -0
- pix2pix3D-main/pix2pix3D-main/LICENSE +21 -0
- pix2pix3D-main/pix2pix3D-main/README.md +170 -0
- pix2pix3D-main/pix2pix3D-main/applications/demo/qt_demo_seg2cat.py +504 -0
- pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/__init__.py +0 -0
- pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/mouse_event.py +100 -0
- pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui.py +988 -0
- pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui_clean.py +797 -0
- pix2pix3D-main/pix2pix3D-main/applications/edge2cat.ipynb +0 -0
- pix2pix3D-main/pix2pix3D-main/applications/extract_mesh.py +267 -0
- pix2pix3D-main/pix2pix3D-main/applications/generate_samples.py +128 -0
- pix2pix3D-main/pix2pix3D-main/applications/generate_video.py +220 -0
- pix2pix3D-main/pix2pix3D-main/assets/demo.mp4 +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_label.png +0 -0
- pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_input.png +0 -0
- pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg +3 -0
- pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png +3 -0
- pix2pix3D-main/pix2pix3D-main/camera_utils.py +149 -0
- pix2pix3D-main/pix2pix3D-main/checkpoints/download_models.sh +5 -0
- pix2pix3D-main/pix2pix3D-main/dnnlib/__init__.py +11 -0
- pix2pix3D-main/pix2pix3D-main/dnnlib/util.py +493 -0
- pix2pix3D-main/pix2pix3D-main/environment.yml +39 -0
- pix2pix3D-main/pix2pix3D-main/examples/example_input.png +0 -0
- pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2car.png +0 -0
- pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2cat.png +0 -0
- pix2pix3D-main/pix2pix3D-main/legacy.py +325 -0
- pix2pix3D-main/pix2pix3D-main/metrics/__init__.py +11 -0
- pix2pix3D-main/pix2pix3D-main/metrics/equivariance.py +269 -0
- pix2pix3D-main/pix2pix3D-main/metrics/frechet_inception_distance.py +43 -0
- pix2pix3D-main/pix2pix3D-main/metrics/inception_score.py +40 -0
- pix2pix3D-main/pix2pix3D-main/metrics/kernel_inception_distance.py +48 -0
- pix2pix3D-main/pix2pix3D-main/metrics/metric_main.py +155 -0
- pix2pix3D-main/pix2pix3D-main/metrics/metric_utils.py +281 -0
- pix2pix3D-main/pix2pix3D-main/metrics/perceptual_path_length.py +127 -0
- pix2pix3D-main/pix2pix3D-main/metrics/precision_recall.py +64 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/__init__.py +11 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/custom_ops.py +159 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/misc.py +280 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/__init__.py +11 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cpp +103 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cu +177 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.h +42 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.py +211 -0
- pix2pix3D-main/pix2pix3D-main/torch_utils/ops/conv2d_gradfix.py +199 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg filter=lfs diff=lfs merge=lfs -text
+pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png filter=lfs diff=lfs merge=lfs -text
pix2pix3D-main/pix2pix3D-main/.gitignore
ADDED
@@ -0,0 +1,146 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+logs
+data
+wandb
+scripts
+web
+*.pyc
+.vscode
pix2pix3D-main/pix2pix3D-main/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Kangle Deng
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
pix2pix3D-main/pix2pix3D-main/README.md
ADDED
@@ -0,0 +1,170 @@
+
+
+# 3D-aware Conditional Image Synthesis (pix2pix3D)
+
+[**Project**](https://www.cs.cmu.edu/~pix2pix3D/) | [**Paper**](https://arxiv.org/abs/2302.08509)
+
+This is the official PyTorch implementation of "3D-aware Conditional Image Synthesis". Pix2pix3D synthesizes 3D objects (neural fields) given a 2D label map, such as a segmentation or edge map. We also provide an interactive 3D editing demo.
+
+https://user-images.githubusercontent.com/28395429/222578030-8bb2c727-397e-44b6-9ab1-9b0b09dd3b5b.mp4
+
+
+[3D-aware Conditional Image Synthesis](https://arxiv.org/abs/2302.08509)
+
+CVPR 2023
+
+[Kangle Deng](https://dunbar12138.github.io/),
+[Gengshan Yang](https://gengshan-y.github.io/),
+[Deva Ramanan](https://www.cs.cmu.edu/~deva/),
+[Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)
+
+Carnegie Mellon University
+
+---
+
+We propose pix2pix3D, a 3D-aware conditional generative model for controllable photorealistic image synthesis. Given a 2D label map, such as a segmentation or edge map, our model learns to synthesize a corresponding image from different viewpoints. To enable explicit 3D user control, we extend conditional generative models with neural radiance fields. Given widely-available monocular images and label map pairs, our model learns to assign a label to every 3D point in addition to color and density, which enables it to render the image and pixel-aligned label map simultaneously. Finally, we build an interactive system that allows users to edit the label map from any viewpoint and generate outputs accordingly.
+
+<p align="center">
+<img src="assets/teaser_jpg.jpg" width="720" />
+</p>
+
+## Getting Started
+
+### Dependencies
+
+We provide a conda env file that contains all the other dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
+```
+conda env create -f environment.yml
+conda activate pix2pix3d
+```
+
+### Data
+
+We provide our preprocessed datasets, including segmentation maps and edge maps. You can download the [CelebAMask](https://drive.google.com/drive/folders/1mC6i4YmdpazJSmXrW8WFSfsImAJ8a_CF?usp=sharing) dataset, [AFHQ-Cat-Seg](https://drive.google.com/drive/folders/1yjTTE57P9-hEe-IVcE-GXdh04WGo5lD9?usp=sharing) dataset, and [Shapenet-Car-Edge](https://drive.google.com/drive/folders/1XTPuu784DIvk0ie094qyLrcF-v-jMe3_?usp=sharing) dataset and put those zip files under `data/`.
+
+
+### Pre-trained Models
+
+You can download our pre-trained models using the following script:
+```
+bash checkpoints/download_models.sh
+```
+
+---
+### Inference
+
+We provide several scripts to generate the results once you download the pre-trained models.
+
+<p align="center">
+<img src="assets/teaser_gif.gif" width="720" />
+</p>
+
+#### Generate Samples
+You can generate results based on the samples in the dataset.
+```
+python applications/generate_samples.py --network <network_pkl> --outdir <output_dir> --random_seed <random_seeds list, e.g. 0 1> --cfg <configs, e.g., seg2cat, seg2face, edge2car> --input_id <sample_id in dataset>
+```
+For example:
+
+| Input Label Map | Generated Image | Generated Label Map |
+| ------------- | ------------- | -------- |
+| <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/seg2cat_1666_1_color.png" width="256" /> | <img src="assets/seg2cat_1666_1_label.png" width="256" /> |
+
+You can get the results above by running:
+```
+python applications/generate_samples.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --random_seed 1 --cfg seg2cat --input_id 1666
+```
+
+#### Render Videos
+You can render a video result based on a specified input label map.
+```
+python applications/generate_video.py --network <network_pkl> --outdir <output_dir> --random_seed <random_seeds list, e.g. 0 1> --cfg <configs, e.g., seg2cat, seg2face, edge2car> --input <input label map>
+```
+
+For example:
+| Input Label Map | Generated Image | Generated Label Map |
+| ------------- | ------------- | -------- |
+| <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/seg2cat_1.gif" width="256" /> | <img src="assets/seg2cat_1_label.gif" width="256" /> |
+
+You can get the results above using the following command:
+```
+python applications/generate_video.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --random_seed 1 --cfg seg2cat --input examples/example_input.png
+```
+
+#### Extract Semantic Mesh
+You can also extract the mesh and color it using 3D semantic labels. Some extra packages (`pyrender`, `trimesh`, and `mcubes`) are required for mesh extraction; you can install them via `pip`. The extracted mesh will be saved as `semantic_mesh.ply`.
+
+For example:
+| Input Label Map | Semantic Mesh |
+| ------------- | ------------- |
+| <img src="assets/seg2cat_1666_input.png" width="256" /> | <img src="assets/rendered_mesh_colored.gif" width="256" /> |
+
+You can get the results above with the following command:
+```
+python applications/extract_mesh.py --network checkpoints/pix2pix3d_seg2cat.pkl --outdir examples --cfg seg2cat --input examples/example_input.png
+```
+
+<!-- #### Interpolation -->
+
+
+### Training
+
+We provide an example training script at `train_scripts/afhq_seg.sh`:
+```
+python train.py --outdir=<log_dir> \
+    --cfg=afhq --data=data/afhq_v2_train_cat_512.zip \
+    --mask_data=data/afhqcat_seg_6c.zip \
+    --data_type=seg --semantic_channels=6 \
+    --render_mask=True --dis_mask=True \
+    --neural_rendering_resolution_initial=128 \
+    --resume=<EG3D-checkpoints>/afhqcats512-128.pkl \
+    --gpus=2 --batch=4 --mbstd-group=2 \
+    --gamma=5 --gen_pose_cond=True \
+    --random_c_prob=0.5 \
+    --lambda_d_semantic=0.1 \
+    --lambda_lpips=1 \
+    --lambda_cross_view=1e-4 \
+    --only_raw_recons=True \
+    --wandb_log=False
+```
+Training parameters:
+- `outdir`: The directory to save checkpoints and logs.
+- `cfg`: Choose from [afhq, celeba, shapenet].
+- `data`: RGB data file.
+- `mask_data`: Label map data file.
+- `data_type`: Choose from [seg, edge]. Specify `semantic_channels` if using `seg`.
+- `render_mask`: Whether to render label maps along with RGB.
+- `dis_mask`: Whether to use a GAN loss on rendered label maps.
+- `neural_rendering_resolution_initial`: The resolution of NeRF outputs.
+- `resume`: We partially initialize our network with EG3D pretrained checkpoints (download [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/eg3d)).
+- `gpus`, `batch`, `mbstd-group`: Parameters for batch size and multi-GPU training.
+- `gen_pose_cond`: Whether to condition the generation on camera poses.
+- `random_c_prob`: Probability of sampling random poses for training.
+- `lambda_d_semantic`: The weight of the GAN loss on label maps.
+- `lambda_lpips`: The weight of the RGB LPIPS loss.
+- `lambda_cross_view`: The weight of the cross-view consistency loss.
+- `wandb_log`: Whether to log with wandb.
+
+### Prepare your own dataset
+
+We follow the dataset format of EG3D [here](https://github.com/NVlabs/eg3d#preparing-datasets). You can obtain the segmentation masks of your own dataset by [DINO clustering](https://github.com/ShirAmir/dino-vit-features/blob/main/part_cosegmentation.py), and obtain the edge map by [pidinet](https://github.com/hellozhuo/pidinet) and [informative drawing](https://github.com/carolineec/informative-drawings).
+
+---
+
+## Citation
+
+If you find this repository useful for your research, please cite the following work.
+```
+@inproceedings{kangle2023pix2pix3d,
+  title={3D-aware Conditional Image Synthesis},
+  author={Deng, Kangle and Yang, Gengshan and Ramanan, Deva and Zhu, Jun-Yan},
+  booktitle = {CVPR},
+  year = {2023}
+}
+```
+
+---
+
+## Acknowledgments
+We thank Sheng-Yu Wang, Nupur Kumari, Gaurav Parmer, Ruihan Gao, Muyang Li, George Cazenavette, Andrew Song, Zhipeng Bao, Tamaki Kojima, Krishna Wadhwani, Takuya Narihira, and Tatsuo Fujiwara for their discussion and help. We are grateful for the support from Sony Corporation, Singapore DSTA, and the CMU Argo AI Center for Autonomous Vehicle Research.
+This codebase borrows heavily from [EG3D](https://github.com/NVlabs/eg3d) and [StyleNeRF](https://github.com/facebookresearch/StyleNeRF).
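The inference scripts referenced in this README (`generate_samples.py`, `generate_video.py`, `extract_mesh.py`) all load the network pickle the same way the interactive demo below does. As a minimal sketch of that loading step — assuming the `seg2cat` checkpoint fetched by `checkpoints/download_models.sh` and a CUDA device — it looks roughly like this:

```python
# Minimal sketch: load a pix2pix3D generator the way the repo's scripts do
# (see get_model() in applications/demo/qt_demo_seg2cat.py below).
import torch

import dnnlib
import legacy

network_pkl = 'checkpoints/pix2pix3d_seg2cat.pkl'  # from download_models.sh
device = torch.device('cuda')

with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)  # EMA generator

# G.mapping additionally expects a camera pose and a dict holding the input
# label map; generate_samples.py wires up that conditioning in full.
z = torch.randn(1, G.z_dim, device=device)  # one random latent
print('Loaded generator with z_dim =', G.z_dim)
```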
pix2pix3D-main/pix2pix3D-main/applications/demo/qt_demo_seg2cat.py
ADDED
@@ -0,0 +1,504 @@
+import sys
+sys.path.append('./')
+
+import os
+import cv2
+import time
+import numpy as np
+from PIL import Image
+
+import torch
+from torchvision.utils import save_image
+
+from ui_qt.ui_clean import Ui_Form_Seg2cat as Ui_Form
+from ui_qt.mouse_event import GraphicsScene
+
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from PyQt5.QtWidgets import *
+from PyQt5.QtPrintSupport import QPrintDialog, QPrinter
+
+import dnnlib
+import legacy
+from torch_utils import misc
+
+from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+from torch_utils import misc
+from training.triplane_cond import TriPlaneGenerator
+from training.training_loop import setup_snapshot_image_grid, get_image_grid, save_image_grid
+from train import init_conditional_dataset_kwargs
+
+from matplotlib import pyplot as plt
+
+from pathlib import Path
+
+from rich.progress import track
+import json
+
+import imageio
+
+from torch import nn
+import torch.nn.functional as F
+
+import argparse
+
+from scipy.spatial.transform import Rotation as R
+
+from training.utils import color_mask as color_mask_np
+
+color_list = [QColor(255, 255, 255), QColor(204, 0, 0), QColor(76, 153, 0), QColor(204, 204, 0), QColor(51, 51, 255), QColor(204, 0, 204), QColor(0, 255, 255), QColor(255, 204, 204), QColor(102, 51, 0), QColor(255, 0, 0), QColor(102, 204, 0), QColor(255, 255, 0), QColor(0, 0, 153), QColor(0, 0, 204), QColor(255, 51, 153), QColor(0, 204, 204), QColor(0, 51, 0), QColor(255, 153, 51), QColor(0, 204, 0)]
+
+def color_mask(m):
+    my_color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+    if len(m.shape) == 2:
+        im_base = np.zeros((m.shape[0], m.shape[1], 3))
+        for idx, color in enumerate(my_color_list):
+            im_base[m == idx] = color
+        return im_base
+    elif len(m.shape) == 3:
+        im_base = np.zeros((m.shape[0], m.shape[1], m.shape[2], 3))
+        for idx, color in enumerate(my_color_list):
+            im_base[m == idx] = color
+        return im_base
+
+def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, device='cuda'):
+    gen = model.synthesis
+    range_u, range_v = gen.C.range_u, gen.C.range_v
+    u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
+    v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
+    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
+    return cam
+
+def get_yaw_pitch(cam2world):
+    forward = cam2world[0:3, 2]
+    yaw = np.arctan2(forward[0], forward[2]) - np.pi / 2
+    phi = np.arccos(forward[1])
+    v = (1 - np.cos(phi)) / 2
+    pitch = (1 + forward[1]) / 2 * np.pi
+    return yaw, pitch
+
+def create_cam2world_fromeuler(euler, radius):
+    r = R.from_euler('zyx', euler, degrees=False)
+    cam2world = r.as_matrix()
+    cam2world = np.concatenate([cam2world, np.array([[0, 0, 0]])], axis=0)
+    cam2world = np.concatenate([cam2world, np.array([0, 0, 0, 1])[..., None]], axis=1)
+    cam2world[:3, 3] = -cam2world[:3, 2] * radius
+    return cam2world
+
+class Ex(QWidget, Ui_Form):
+    def __init__(self):
+        super(Ex, self).__init__()
+        self.size = 35
+        self.yaw = 100
+        self.pitch = 50
+        self.roll = 0
+        self.truncation = 0.75
+
+        self.setupUi(self)
+        self.show()
+        self.output_img = None
+
+        self.mat_img = None
+
+        self.mode = 0
+        self.mask = None
+        self.mask_m = np.ones((512, 512, 1), dtype=np.uint8) * 255
+        self.img = None
+
+        self.mouse_clicked = False
+        self.scene = GraphicsScene(self.mode, self.size)
+        self.graphicsView.setScene(self.scene)
+        self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+        self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+        self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+        self.result_scene = QGraphicsScene()
+        self.graphicsView_2.setScene(self.result_scene)
+        self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+        self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+        self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+        self.dlg = QColorDialog(self.graphicsView)
+        self.color = None
+        self.device = torch.device('cuda')
+
+        # Parse arguments
+        parser = argparse.ArgumentParser(description='Real-time 3D editing demo')
+        parser.add_argument('--network', help='Path to the network pickle file', required=True)
+        parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
+        args = parser.parse_args()
+
+        network_pkl = args.network
+        self.G = self.get_model(network_pkl)
+
+        # Initialize dataset.
+        data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
+        mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
+        data_type = 'seg'
+        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(data_path, mask_data, data_type)
+        self.training_data = dnnlib.util.construct_class_by_name(**dataset_kwargs)
+
+        self.input_batch = None
+        # self.ws = None
+        self.ws_texture = None
+
+        self.buffer_mask = None
+
+        focal_length = 4.2647 # shapenet has higher FOV
+        self.intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=self.device)
+
+        os.makedirs('examples/ui', exist_ok=True)
+
+
+    def open(self):
+        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
+                                                  QDir.currentPath())
+        if fileName:
+            image = QPixmap(fileName)
+            mat_img = Image.open(fileName)
+            self.img = mat_img.copy()
+            if image.isNull():
+                QMessageBox.information(self, "Image Viewer",
+                                        "Cannot load %s." % fileName)
+                return
+            image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+
+            if len(self.ref_scene.items()) > 0:
+                self.ref_scene.removeItem(self.ref_scene.items()[-1])
+            self.ref_scene.addPixmap(image)
+            if len(self.result_scene.items()) > 0:
+                self.result_scene.removeItem(self.result_scene.items()[-1])
+            self.result_scene.addPixmap(image)
+
+    def open_mask(self):
+        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
+                                                  QDir.currentPath())
+        if fileName:
+            mat_img = cv2.imread(fileName)
+            self.mask = mat_img.copy()
+            self.mask_m = mat_img
+            mat_img = mat_img.copy()
+            image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+            if image.isNull():
+                QMessageBox.information(self, "Image Viewer",
+                                        "Cannot load %s." % fileName)
+                return
+
+            for i in range(512):
+                for j in range(512):
+                    r, g, b, a = image.pixelColor(i, j).getRgb()
+                    image.setPixel(i, j, color_list[r].rgb())
+
+            pixmap = QPixmap()
+            pixmap.convertFromImage(image)
+            self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+            self.scene.reset()
+            if len(self.scene.items()) > 0:
+                self.scene.reset_items()
+            self.scene.addPixmap(self.image)
+
+
+    def bg_mode(self):
+        self.scene.mode = 0
+
+    def skin_mode(self):
+        self.scene.mode = 1
+
+    def nose_mode(self):
+        self.scene.mode = 2
+
+    def eye_g_mode(self):
+        self.scene.mode = 3
+
+    def l_eye_mode(self):
+        self.scene.mode = 4
+
+    def r_eye_mode(self):
+        self.scene.mode = 5
+
+    def l_brow_mode(self):
+        self.scene.mode = 6
+
+    def r_brow_mode(self):
+        self.scene.mode = 7
+
+    def l_ear_mode(self):
+        self.scene.mode = 8
+
+    def r_ear_mode(self):
+        self.scene.mode = 9
+
+    def mouth_mode(self):
+        self.scene.mode = 10
+
+    def u_lip_mode(self):
+        self.scene.mode = 11
+
+    def l_lip_mode(self):
+        self.scene.mode = 12
+
+    def hair_mode(self):
+        self.scene.mode = 13
+
+    def hat_mode(self):
+        self.scene.mode = 14
+
+    def ear_r_mode(self):
+        self.scene.mode = 15
+
+    def neck_l_mode(self):
+        self.scene.mode = 16
+
+    def neck_mode(self):
+        self.scene.mode = 17
+
+    def cloth_mode(self):
+        self.scene.mode = 18
+
+    def increase(self):
+        if self.scene.size < 50:
+            self.scene.size += 1
+
+    def decrease(self):
+        if self.scene.size > 1:
+            self.scene.size -= 1
+
+    def changeBrushSize(self, s):
+        self.scene.size = s
+
+    def changeYaw(self, s):
+        self.yaw = s
+        if self.ws is not None:
+            self.generate()
+
+    def changePitch(self, s):
+        self.pitch = s
+        # print('changing pitch', self.pitch)
+        if self.ws is not None:
+            self.generate()
+
+    def changeRoll(self, s):
+        self.roll = s
+        if self.ws is not None:
+            self.generate()
+
+    def changeTruncation(self, s):
+        self.truncation = s / 100
+        self.reconstruct()
+        self.generate()
+
+    def inputID(self):
+        input_id = int(self.text_inputID.toPlainText())
+        self.input_batch = self.training_data[input_id]
+
+        self.mask = self.input_batch['mask'].transpose(1,2,0).astype(np.uint8)
+        # self.mask_m = self.input_batch['mask'].transpose(1,2,0).astype(np.uint8)
+        mat_img = cv2.resize(self.mask[:,:,0], (512, 512), interpolation=cv2.INTER_NEAREST)
+        self.mask = mat_img[:,:,np.newaxis]
+        self.mask_m = self.mask.copy()
+        image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+
+        for i in range(512):
+            for j in range(512):
+                # r, g, b, a = image.pixelColor(i, j).getRgb()
+                r = mat_img[j, i]
+                # print(r)
+                image.setPixel(i, j, color_list[r].rgb())
+
+        pixmap = QPixmap()
+        pixmap.convertFromImage(image)
+        self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+        self.scene.reset()
+        if len(self.scene.items()) > 0:
+            self.scene.reset_items()
+        self.scene.addPixmap(self.image)
+
+        self.ws = None
+
+        roll, yaw, pitch = R.from_matrix(self.input_batch['pose'][:16].reshape(4,4)[:3,:3]).as_euler('zyx', degrees=False)
+        # print(yaw, pitch)
+        # print(self.input_batch['pose'][:16].reshape(4, 4))
+        pitch_range = np.pi
+        yaw_range = np.pi / 2
+        roll_range = np.pi / 4
+
+        pitch = pitch - np.pi
+        pitch = pitch + 2 * np.pi if pitch < -np.pi else pitch
+
+        self.yaw = ((yaw) / yaw_range * 100)
+        self.pitch = ((pitch) / pitch_range * 100)
+        self.roll = ((roll) / (roll_range) * 100)
+        print(self.roll, self.yaw, self.pitch)
+
+        self.intrinsics = torch.tensor(self.input_batch['pose'][16:].reshape(3,3)).float().to(self.device)
+
+
+        self.slider_yawselect.setValue(self.yaw)
+        self.slider_pitchselect.setValue(self.pitch)
+        self.slider_rollselect.setValue(self.roll)
+
+
+    def get_mask(self): # get from output
+        mat_img = self.buffer_mask[0].astype(np.uint8)
+
+        # print(self.mask.shape)
+        # mat_img = cv2.resize(mat_img, (512, 512), interpolation=cv2.INTER_NEAREST)
+        self.mask = mat_img[:,:,np.newaxis]
+        self.mask_m = self.mask.copy()
+        # print(mat_img.shape)
+        image = QImage(mat_img, 512, 512, QImage.Format_RGB888)
+
+
+        for i in range(512):
+            for j in range(512):
+                # r, g, b, a = image.pixelColor(i, j).getRgb()
+                r = mat_img[j, i]
+                # print(r)
+                image.setPixel(i, j, color_list[r].rgb())
+
+        pixmap = QPixmap()
+        pixmap.convertFromImage(image)
+        self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+        self.scene.reset()
+        if len(self.scene.items()) > 0:
+            self.scene.reset_items()
+        self.scene.addPixmap(self.image)
+
+
+
+    def generate(self):
+        ws = self.ws
+
+        pitch_range = np.pi
+        yaw_range = np.pi / 2
+        roll_range = np.pi / 4
+        pitch = self.pitch / 100 * pitch_range
+        yaw = self.yaw / 100 * yaw_range
+        roll = self.roll / 100 * roll_range
+
+        cam2world_pose = torch.tensor(create_cam2world_fromeuler([roll, yaw, pitch+np.pi], radius=2.7)).float().to(self.device)
+
+        # print(cam2world_pose)
+        pose = torch.cat([cam2world_pose.reshape(-1, 16), self.intrinsics.reshape(-1, 9)], 1)
+
+        out = self.G.synthesis(ws, pose.to(self.device), noise_mode='const', neural_rendering_resolution=128)
+        img = ((out['image'].permute(0,2,3,1).squeeze(0).cpu().numpy().clip(-1, 1) * 0.5 + 0.5) * 255).astype(np.uint8).copy()
+        if out['image'].shape[-1] != 512:
+            # print(f"Resizing {out['image'].shape[-1]} to {512}")
+            img = cv2.resize(img, (512, 512))
+
+        qim = QImage(img.data, img.shape[1], img.shape[0], img.strides[0], QImage.Format_RGB888)
+        if len(self.result_scene.items()) > 0:
+            self.result_scene.removeItem(self.result_scene.items()[-1])
+        self.result_scene.addPixmap(QPixmap.fromImage(qim))
+
+        self.buffer_mask = torch.argmax(out['semantic'].detach(), dim=1).cpu().numpy() # 1 x 512 x 512
+        self.output_img = img.copy()
+        self.get_mask()
+
+
+    def reconstruct(self):
+        if self.input_batch is None:
+            return
+        ws = self.ws
+
+        out = self.G.synthesis(ws, torch.tensor(self.input_batch['pose']).unsqueeze(0).to(self.device))
+        if out['img'].shape[-1] != 512:
+            print(f"Resizing {out['img'].shape[-1]} to {512}")
+            img = resize_image(out['img'].detach(), 512)
+        else:
+            img = out['img'].detach()
+        img = ((img.permute(0,2,3,1).squeeze(0).cpu().numpy().clip(-1, 1) * 0.5 + 0.5) * 255).astype(np.uint8).copy()
+        qim = QImage(img.data, img.shape[1], img.shape[0], img.strides[0], QImage.Format_RGB888)
+        if len(self.ref_scene.items()) > 0:
+            self.ref_scene.removeItem(self.ref_scene.items()[-1])
+        self.ref_scene.addPixmap(QPixmap.fromImage(qim))
+
+        if out['semantic'].shape[-1] != 512:
+            seg = resize_image(out['semantic'].detach(), 512)
+        else:
+            seg = out['semantic'].detach()
+        seg = color_mask(torch.argmax(seg, dim=1).cpu())[0].astype(np.uint8).copy()
+        qim = QImage(seg.data, seg.shape[1], seg.shape[0], seg.strides[0], QImage.Format_RGB888)
+        if len(self.ref_seg_scene.items()) > 0:
+            self.ref_seg_scene.removeItem(self.ref_seg_scene.items()[-1])
+        self.ref_seg_scene.addPixmap(QPixmap.fromImage(qim))
+
+    def get_ws(self):
+        z = torch.from_numpy(np.random.RandomState(int(self.text_seed.toPlainText())).randn(1, self.G.z_dim).astype('float32')).to(self.device)
+
+        for i in range(6):
+            self.mask_m = self.make_mask(self.mask_m, self.scene.mask_points[i], self.scene.size_points[i], i)
+
+        cv2.imwrite('examples/ui/mask_input.png', self.mask_m)
+
+        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(self.G.rendering_kwargs['avg_camera_pivot'], device=self.device),
+                                                          radius=self.G.rendering_kwargs['avg_camera_radius'], device=self.device)
+        focal_length = 4.2647 # shapenet has higher FOV
+        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=self.device)
+        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+        self.ws = self.G.mapping(z, forward_pose.to(self.device),
+                                 {'mask': torch.tensor(self.mask_m[None,...,0]).unsqueeze(0).to(self.device), 'pose': torch.tensor(self.input_batch['pose']).unsqueeze(0).to(self.device)})
+
+        if self.ws_texture is None:
+            self.ws_texture = self.ws[:,8:,:]
+        else:
+            self.ws[:,8:,:] = self.ws_texture
+        # print(self.ws[:,8:,:])
+
+
+    def generateAndReconstruct(self):
+        self.get_ws()
+        self.generate()
+        # self.reconstruct()
+
+
+    def make_mask(self, mask, pts, sizes, color):
+        if len(pts) > 0:
+            for idx, pt in enumerate(pts):
+                cv2.line(mask, pt['prev'], pt['curr'], (color, color, color), sizes[idx])
+        return mask
+
+    def save_img(self):
+        for i in range(6):
+            self.mask_m = self.make_mask(self.mask_m, self.scene.mask_points[i], self.scene.size_points[i], i)
+        mask_np = np.array(self.mask_m)[..., 0]
+        print(mask_np.shape)
+        cv2.imwrite('examples/ui/mask.png', mask_np)
+        cv2.imwrite('examples/ui/mask_color.png', color_mask_np(mask_np).astype(np.uint8)[...,::-1])
+        cv2.imwrite('examples/ui/output.png', self.output_img[...,::-1])
+
+
+    def undo(self):
+        self.scene.undo()
+
+    def clear(self):
+        self.mask_m = self.mask.copy()
+
+        self.scene.reset_items()
+        self.scene.reset()
+        if type(self.image):
+            self.scene.addPixmap(self.image)
+
+        if len(self.result_scene.items()) > 0:
+            self.result_scene.removeItem(self.result_scene.items()[-1])
+
+        self.ws = None
+
+    def get_model(self, network_pkl):
+        device = torch.device('cuda')
+        with dnnlib.util.open_url(network_pkl) as f:
+            G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
+
+        return G
+
+    def clear_ws(self):
+        self.ws = None
+
+if __name__ == '__main__':
+    app = QApplication(sys.argv)
+    ex = Ex()
+    sys.exit(app.exec_())
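One step in the demo worth isolating is how `generate()` turns the slider values into a camera: `create_cam2world_fromeuler` builds a rotation from `(roll, yaw, pitch)` Euler angles and then backs the camera up along its viewing axis. A self-contained sketch of the same construction (the `2.7` radius matches what the demo passes):

```python
# Standalone sketch of the demo's create_cam2world_fromeuler: rotation from
# 'zyx' Euler angles, camera translated `radius` units along the -z view axis.
import numpy as np
from scipy.spatial.transform import Rotation as R

def cam2world_from_euler(euler, radius):
    rot = R.from_euler('zyx', euler, degrees=False).as_matrix()  # 3x3 rotation
    cam2world = np.eye(4)
    cam2world[:3, :3] = rot
    cam2world[:3, 3] = -rot[:, 2] * radius  # position = -forward * radius
    return cam2world

# The demo calls this with [roll, yaw, pitch + pi] and radius=2.7, then
# flattens the 4x4 matrix and concatenates the 3x3 intrinsics into one pose.
pose = cam2world_from_euler([0.0, 0.2, np.pi], radius=2.7)
print(pose.round(3))
```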
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/__init__.py
ADDED
File without changes
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/mouse_event.py
ADDED
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from PyQt5.QtWidgets import *
+import numpy as np
+
+color_list = [QColor(255, 255, 255), QColor(204, 0, 0), QColor(76, 153, 0), QColor(204, 204, 0), QColor(51, 51, 255), QColor(204, 0, 204), QColor(0, 255, 255), QColor(255, 204, 204), QColor(102, 51, 0), QColor(255, 0, 0), QColor(102, 204, 0), QColor(255, 255, 0), QColor(0, 0, 153), QColor(0, 0, 204), QColor(255, 51, 153), QColor(0, 204, 204), QColor(0, 51, 0), QColor(255, 153, 51), QColor(0, 204, 0)]
+
+class GraphicsScene(QGraphicsScene):
+    def __init__(self, mode, size, parent=None):
+        QGraphicsScene.__init__(self, parent)
+        self.mode = mode
+        self.size = size
+        self.mouse_clicked = False
+        self.prev_pt = None
+
+        # self.masked_image = None
+
+        # save the points
+        self.mask_points = []
+        for i in range(len(color_list)):
+            self.mask_points.append([])
+
+        # save the size of points
+        self.size_points = []
+        for i in range(len(color_list)):
+            self.size_points.append([])
+
+        # save the history of edit
+        self.history = []
+
+    def reset(self):
+        # save the points
+        self.mask_points = []
+        for i in range(len(color_list)):
+            self.mask_points.append([])
+        # save the size of points
+        self.size_points = []
+        for i in range(len(color_list)):
+            self.size_points.append([])
+        # save the history of edit
+        self.history = []
+
+        self.mode = 0
+        self.prev_pt = None
+
+    def mousePressEvent(self, event):
+        self.mouse_clicked = True
+
+    def mouseReleaseEvent(self, event):
+        self.prev_pt = None
+        self.mouse_clicked = False
+
+    def mouseMoveEvent(self, event): # drawing
+        if self.mouse_clicked:
+            if self.prev_pt:
+                self.drawMask(self.prev_pt, event.scenePos(), color_list[self.mode], self.size)
+                pts = {}
+                pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
+                pts['curr'] = (int(event.scenePos().x()), int(event.scenePos().y()))
+
+                self.size_points[self.mode].append(self.size)
+                self.mask_points[self.mode].append(pts)
+                self.history.append(self.mode)
+                self.prev_pt = event.scenePos()
+            else:
+                self.prev_pt = event.scenePos()
+
+    def drawMask(self, prev_pt, curr_pt, color, size):
+        lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
+        lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect
+        self.addItem(lineItem)
+
+    def erase_prev_pt(self):
+        self.prev_pt = None
+
+    def reset_items(self):
+        for i in range(len(self.items())):
+            item = self.items()[0]
+            self.removeItem(item)
+
+    def undo(self):
+        if len(self.items()) > 1:
+            if len(self.items()) >= 9:
+                for i in range(8):
+                    item = self.items()[0]
+                    self.removeItem(item)
+                if self.history[-1] == self.mode:
+                    self.mask_points[self.mode].pop()
+                    self.size_points[self.mode].pop()
+                    self.history.pop()
+            else:
+                for i in range(len(self.items()) - 1):
+                    item = self.items()[0]
+                    self.removeItem(item)
+                if self.history[-1] == self.mode:
+                    self.mask_points[self.mode].pop()
+                    self.size_points[self.mode].pop()
+                    self.history.pop()
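`GraphicsScene` above is a plain `QGraphicsScene` subclass, so it drops into any `QGraphicsView`; the demo attaches it in `Ex.__init__`. A minimal, hypothetical harness (mode 0 and brush size 35 are the demo's defaults; assumes it is run from `applications/demo/` so the `ui_qt` import resolves):

```python
# Minimal sketch: host GraphicsScene in a bare QGraphicsView, mirroring the
# wiring in qt_demo_seg2cat.py. Mouse drags paint label-0 strokes, 35 px wide.
import sys

from PyQt5.QtWidgets import QApplication, QGraphicsView

from ui_qt.mouse_event import GraphicsScene

app = QApplication(sys.argv)
scene = GraphicsScene(mode=0, size=35)  # label index 0, 35 px brush
view = QGraphicsView()
view.setScene(scene)  # drags now add colored QGraphicsLineItems to the scene
view.show()
sys.exit(app.exec_())
```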
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui.py
ADDED
@@ -0,0 +1,988 @@
+from PyQt5 import QtCore, QtGui, QtWidgets
+from PyQt5.QtCore import Qt
+
+class Ui_Form(object):
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(1800, 660)
+        self.pushButton = QtWidgets.QPushButton(Form)
+        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
+        self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
+        self.pushButton.setObjectName("pushButton")
+        self.pushButton_2 = QtWidgets.QPushButton(Form)
+        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
+        self.pushButton_2.setObjectName("pushButton_2")
+        self.pushButton_3 = QtWidgets.QPushButton(Form)
+        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
+        self.pushButton_3.setObjectName("pushButton_3")
+        self.pushButton_4 = QtWidgets.QPushButton(Form)
+        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
+        self.pushButton_4.setObjectName("pushButton_4")
+        self.pushButton_5 = QtWidgets.QPushButton(Form)
+        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
+        self.pushButton_5.setObjectName("pushButton_5")
+        self.pushButton_6 = QtWidgets.QPushButton(Form)
+        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
+        self.pushButton_6.setObjectName("pushButton_6")
+        self.pushButton_7 = QtWidgets.QPushButton(Form)
+        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
+        self.pushButton_7.setObjectName("pushButton_7")
+        self.pushButton_8 = QtWidgets.QPushButton(Form)
+        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
+        self.pushButton_8.setObjectName("pushButton_8")
+        self.pushButton_9 = QtWidgets.QPushButton(Form)
+        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
+        self.pushButton_9.setObjectName("pushButton_9")
+        self.pushButton_10 = QtWidgets.QPushButton(Form)
+        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
+        self.pushButton_10.setObjectName("pushButton_10")
+        self.pushButton_11 = QtWidgets.QPushButton(Form)
+        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
+        self.pushButton_11.setObjectName("pushButton_11")
+        self.pushButton_12 = QtWidgets.QPushButton(Form)
+        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
+        self.pushButton_12.setObjectName("pushButton_12")
+        self.pushButton_13 = QtWidgets.QPushButton(Form)
+        self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
+        self.pushButton_13.setObjectName("pushButton_13")
+        self.pushButton_14 = QtWidgets.QPushButton(Form)
+        self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
+        self.pushButton_14.setObjectName("pushButton_14")
+        self.pushButton_15 = QtWidgets.QPushButton(Form)
+        self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
+        self.pushButton_15.setObjectName("pushButton_15")
+        self.pushButton_16 = QtWidgets.QPushButton(Form)
+        self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
+        self.pushButton_16.setObjectName("pushButton_16")
+        self.pushButton_17 = QtWidgets.QPushButton(Form)
+        self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
+        self.pushButton_17.setObjectName("pushButton_17")
+        self.pushButton_18 = QtWidgets.QPushButton(Form)
+        self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
+        self.pushButton_18.setObjectName("pushButton_18")
+        self.pushButton_19 = QtWidgets.QPushButton(Form)
+        self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
+        self.pushButton_19.setObjectName("pushButton_19")
+        self.pushButton_20 = QtWidgets.QPushButton(Form)
+        self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
+        self.pushButton_20.setObjectName("pushButton_20")
+        self.pushButton_21 = QtWidgets.QPushButton(Form)
+        self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
+        self.pushButton_21.setObjectName("pushButton_21")
+        self.pushButton_22 = QtWidgets.QPushButton(Form)
+        self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
+        self.pushButton_22.setObjectName("pushButton_22")
+        self.pushButton_23 = QtWidgets.QPushButton(Form)
+        self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
+        self.pushButton_23.setObjectName("pushButton_23")
+        self.pushButton_24 = QtWidgets.QPushButton(Form)
+        self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
+        self.pushButton_24.setObjectName("pushButton_24")
+        self.pushButton_25 = QtWidgets.QPushButton(Form)
+        self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
+        self.pushButton_25.setObjectName("pushButton_25")
+        # self.pushButton_26 = QtWidgets.QPushButton(Form)
+        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+        # self.pushButton_26.setObjectName("pushButton_26")
+        # self.pushButton_27 = QtWidgets.QPushButton(Form)
+        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+        # self.pushButton_27.setObjectName("pushButton_27")
+
+
+        self.slider_sizeselect = QtWidgets.QSlider(Form)
+        self.slider_sizeselect.setRange(10, 70)
+        self.slider_sizeselect.setOrientation(Qt.Horizontal)
+        self.slider_sizeselect.setValue(Form.size)
+        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))
+
+        self.label_sizeselect = QtWidgets.QLabel(Form)
+        self.label_sizeselect.setText("Brush Size")
+        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))
+
+        self.slider_yawselect = QtWidgets.QSlider(Form)
+        self.slider_yawselect.setRange(-100, 100)
+        self.slider_yawselect.setOrientation(Qt.Horizontal)
+        self.slider_yawselect.setValue(Form.yaw)
+        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))
+
+        self.label_yawselect = QtWidgets.QLabel(Form)
+        self.label_yawselect.setText("Yaw")
+        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))
+
+        self.slider_pitchselect = QtWidgets.QSlider(Form)
+        self.slider_pitchselect.setRange(-100, 100)
+        self.slider_pitchselect.setOrientation(Qt.Horizontal)
+        self.slider_pitchselect.setValue(Form.pitch)
+        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))
+
+        self.label_pitchselect = QtWidgets.QLabel(Form)
+        self.label_pitchselect.setText("Pitch")
+        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))
+
+        self.text_inputID = QtWidgets.QTextEdit(Form)
+        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
+        self.text_inputID.setObjectName("text_inputID")
+
+        self.pushButton_inputID = QtWidgets.QPushButton(Form)
+        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
+        self.pushButton_inputID.setObjectName("pushButton_inputID")
+
+        self.text_seed = QtWidgets.QTextEdit(Form)
+        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
+        self.text_seed.setObjectName("text_seed")
+        self.text_seed.setPlainText("0")
+
+        self.label_seed = QtWidgets.QLabel(Form)
+        self.label_seed.setText("Seed")
+        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))
+
+
+
+
+        self.graphicsView = QtWidgets.QGraphicsView(Form)
+        self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
+        self.graphicsView.setObjectName("graphicsView")
+        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+        self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
+        self.graphicsView_2.setObjectName("graphicsView_2")
+        self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+        self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
+        self.graphicsView_3.setObjectName("graphicsView_3")
+
+
+        self.retranslateUi(Form)
+        self.pushButton.clicked.connect(Form.generateAndReconstruct)
+        self.pushButton_2.clicked.connect(Form.open)
+        self.pushButton_3.clicked.connect(Form.open_mask)
+        self.pushButton_4.clicked.connect(Form.clear)
+        self.pushButton_5.clicked.connect(Form.undo)
+        self.pushButton_6.clicked.connect(Form.save_img)
+        self.pushButton_7.clicked.connect(Form.bg_mode)
+        self.pushButton_8.clicked.connect(Form.skin_mode)
+        self.pushButton_9.clicked.connect(Form.nose_mode)
+        self.pushButton_10.clicked.connect(Form.eye_g_mode)
+        self.pushButton_11.clicked.connect(Form.l_eye_mode)
+        self.pushButton_12.clicked.connect(Form.r_eye_mode)
+        self.pushButton_13.clicked.connect(Form.l_brow_mode)
+        self.pushButton_14.clicked.connect(Form.r_brow_mode)
+        self.pushButton_15.clicked.connect(Form.l_ear_mode)
+        self.pushButton_16.clicked.connect(Form.r_ear_mode)
+        self.pushButton_17.clicked.connect(Form.mouth_mode)
+        self.pushButton_18.clicked.connect(Form.u_lip_mode)
+        self.pushButton_19.clicked.connect(Form.l_lip_mode)
+        self.pushButton_20.clicked.connect(Form.hair_mode)
+        self.pushButton_21.clicked.connect(Form.hat_mode)
+        self.pushButton_22.clicked.connect(Form.ear_r_mode)
+        self.pushButton_23.clicked.connect(Form.neck_l_mode)
+        self.pushButton_24.clicked.connect(Form.neck_mode)
+        self.pushButton_25.clicked.connect(Form.cloth_mode)
+        # self.pushButton_26.clicked.connect(Form.increase)
+        # self.pushButton_27.clicked.connect(Form.decrease)
+
+
+        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
+        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
+        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
|
186 |
+
|
187 |
+
self.pushButton_inputID.clicked.connect(Form.inputID)
|
188 |
+
|
189 |
+
QtCore.QMetaObject.connectSlotsByName(Form)
|
190 |
+
|
191 |
+
def retranslateUi(self, Form):
|
192 |
+
_translate = QtCore.QCoreApplication.translate
|
193 |
+
Form.setWindowTitle(_translate("Form", "3D-GauGAN"))
|
194 |
+
self.pushButton.setText(_translate("Form", "Generate"))
|
195 |
+
self.pushButton_2.setText(_translate("Form", "Open Image"))
|
196 |
+
self.pushButton_3.setText(_translate("Form", "Open Mask"))
|
197 |
+
self.pushButton_4.setText(_translate("Form", "Clear"))
|
198 |
+
self.pushButton_5.setText(_translate("Form", "Undo"))
|
199 |
+
self.pushButton_6.setText(_translate("Form", "Save Image"))
|
200 |
+
self.pushButton_7.setText(_translate("Form", "BackGround"))
|
201 |
+
self.pushButton_8.setText(_translate("Form", "Skin"))
|
202 |
+
self.pushButton_9.setText(_translate("Form", "Nose"))
|
203 |
+
self.pushButton_10.setText(_translate("Form", "Eyeglass"))
|
204 |
+
self.pushButton_11.setText(_translate("Form", "Left Eye"))
|
205 |
+
self.pushButton_12.setText(_translate("Form", "Right Eye"))
|
206 |
+
self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
|
207 |
+
self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
|
208 |
+
self.pushButton_15.setText(_translate("Form", "Left ear"))
|
209 |
+
self.pushButton_16.setText(_translate("Form", "Right ear"))
|
210 |
+
self.pushButton_17.setText(_translate("Form", "Mouth"))
|
211 |
+
self.pushButton_18.setText(_translate("Form", "Upper Lip"))
|
212 |
+
self.pushButton_19.setText(_translate("Form", "Lower Lip"))
|
213 |
+
self.pushButton_20.setText(_translate("Form", "Hair"))
|
214 |
+
self.pushButton_21.setText(_translate("Form", "Hat"))
|
215 |
+
self.pushButton_22.setText(_translate("Form", "Earring"))
|
216 |
+
self.pushButton_23.setText(_translate("Form", "Necklace"))
|
217 |
+
self.pushButton_24.setText(_translate("Form", "Neck"))
|
218 |
+
self.pushButton_25.setText(_translate("Form", "Cloth"))
|
219 |
+
# self.pushButton_26.setText(_translate("Form", "+"))
|
220 |
+
# self.pushButton_27.setText(_translate("Form", "-"))
|
221 |
+
self.pushButton_inputID.setText(_translate("Form", "Input ID"))
|
222 |
+
|
223 |
+
|
224 |
+
class Ui_Form_Seg(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(1800, 1260)
        self.pushButton = QtWidgets.QPushButton(Form)
        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
        self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        self.pushButton_10.setObjectName("pushButton_10")
        self.pushButton_11 = QtWidgets.QPushButton(Form)
        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        self.pushButton_11.setObjectName("pushButton_11")
        self.pushButton_12 = QtWidgets.QPushButton(Form)
        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        self.pushButton_12.setObjectName("pushButton_12")
        self.pushButton_13 = QtWidgets.QPushButton(Form)
        self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        self.pushButton_13.setObjectName("pushButton_13")
        self.pushButton_14 = QtWidgets.QPushButton(Form)
        self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        self.pushButton_14.setObjectName("pushButton_14")
        self.pushButton_15 = QtWidgets.QPushButton(Form)
        self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        self.pushButton_15.setObjectName("pushButton_15")
        self.pushButton_16 = QtWidgets.QPushButton(Form)
        self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        self.pushButton_16.setObjectName("pushButton_16")
        self.pushButton_17 = QtWidgets.QPushButton(Form)
        self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        self.pushButton_17.setObjectName("pushButton_17")
        self.pushButton_18 = QtWidgets.QPushButton(Form)
        self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        self.pushButton_18.setObjectName("pushButton_18")
        self.pushButton_19 = QtWidgets.QPushButton(Form)
        self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        self.pushButton_19.setObjectName("pushButton_19")
        self.pushButton_20 = QtWidgets.QPushButton(Form)
        self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        self.pushButton_20.setObjectName("pushButton_20")
        self.pushButton_21 = QtWidgets.QPushButton(Form)
        self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        self.pushButton_21.setObjectName("pushButton_21")
        self.pushButton_22 = QtWidgets.QPushButton(Form)
        self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        self.pushButton_22.setObjectName("pushButton_22")
        self.pushButton_23 = QtWidgets.QPushButton(Form)
        self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        self.pushButton_23.setObjectName("pushButton_23")
        self.pushButton_24 = QtWidgets.QPushButton(Form)
        self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        self.pushButton_24.setObjectName("pushButton_24")
        self.pushButton_25 = QtWidgets.QPushButton(Form)
        self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        self.pushButton_inverse = QtWidgets.QPushButton(Form)
        self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        self.pushButton_inverse.setObjectName("pushButton_inverse")

        self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")

        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        self.graphicsView_3.setObjectName("graphicsView_3")

        self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        self.graphicsView_5.setObjectName("graphicsView_5")
        self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        self.graphicsView_6.setObjectName("graphicsView_6")

        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        self.pushButton_9.clicked.connect(Form.nose_mode)
        self.pushButton_10.clicked.connect(Form.eye_g_mode)
        self.pushButton_11.clicked.connect(Form.l_eye_mode)
        self.pushButton_12.clicked.connect(Form.r_eye_mode)
        self.pushButton_13.clicked.connect(Form.l_brow_mode)
        self.pushButton_14.clicked.connect(Form.r_brow_mode)
        self.pushButton_15.clicked.connect(Form.l_ear_mode)
        self.pushButton_16.clicked.connect(Form.r_ear_mode)
        self.pushButton_17.clicked.connect(Form.mouth_mode)
        self.pushButton_18.clicked.connect(Form.u_lip_mode)
        self.pushButton_19.clicked.connect(Form.l_lip_mode)
        self.pushButton_20.clicked.connect(Form.hair_mode)
        self.pushButton_21.clicked.connect(Form.hat_mode)
        self.pushButton_22.clicked.connect(Form.ear_r_mode)
        self.pushButton_23.clicked.connect(Form.neck_l_mode)
        self.pushButton_24.clicked.connect(Form.neck_mode)
        self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        self.pushButton_inverse.clicked.connect(Form.inverse)
        self.pushButton_clear_ws.clicked.connect(Form.clear_ws)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-GauGAN"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Skin"))
        self.pushButton_9.setText(_translate("Form", "Nose"))
        self.pushButton_10.setText(_translate("Form", "Eyeglass"))
        self.pushButton_11.setText(_translate("Form", "Left Eye"))
        self.pushButton_12.setText(_translate("Form", "Right Eye"))
        self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        self.pushButton_15.setText(_translate("Form", "Left ear"))
        self.pushButton_16.setText(_translate("Form", "Right ear"))
        self.pushButton_17.setText(_translate("Form", "Mouth"))
        self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        self.pushButton_20.setText(_translate("Form", "Hair"))
        self.pushButton_21.setText(_translate("Form", "Hat"))
        self.pushButton_22.setText(_translate("Form", "Earring"))
        self.pushButton_23.setText(_translate("Form", "Necklace"))
        self.pushButton_24.setText(_translate("Form", "Neck"))
        self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))

class Ui_Form_Video(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(1800, 1260)
        self.pushButton = QtWidgets.QPushButton(Form)
        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
        self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        self.pushButton_10.setObjectName("pushButton_10")
        self.pushButton_11 = QtWidgets.QPushButton(Form)
        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        self.pushButton_11.setObjectName("pushButton_11")
        self.pushButton_12 = QtWidgets.QPushButton(Form)
        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        self.pushButton_12.setObjectName("pushButton_12")
        self.pushButton_13 = QtWidgets.QPushButton(Form)
        self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        self.pushButton_13.setObjectName("pushButton_13")
        self.pushButton_14 = QtWidgets.QPushButton(Form)
        self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        self.pushButton_14.setObjectName("pushButton_14")
        self.pushButton_15 = QtWidgets.QPushButton(Form)
        self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        self.pushButton_15.setObjectName("pushButton_15")
        self.pushButton_16 = QtWidgets.QPushButton(Form)
        self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        self.pushButton_16.setObjectName("pushButton_16")
        self.pushButton_17 = QtWidgets.QPushButton(Form)
        self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        self.pushButton_17.setObjectName("pushButton_17")
        self.pushButton_18 = QtWidgets.QPushButton(Form)
        self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        self.pushButton_18.setObjectName("pushButton_18")
        self.pushButton_19 = QtWidgets.QPushButton(Form)
        self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        self.pushButton_19.setObjectName("pushButton_19")
        self.pushButton_20 = QtWidgets.QPushButton(Form)
        self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        self.pushButton_20.setObjectName("pushButton_20")
        self.pushButton_21 = QtWidgets.QPushButton(Form)
        self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        self.pushButton_21.setObjectName("pushButton_21")
        self.pushButton_22 = QtWidgets.QPushButton(Form)
        self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        self.pushButton_22.setObjectName("pushButton_22")
        self.pushButton_23 = QtWidgets.QPushButton(Form)
        self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        self.pushButton_23.setObjectName("pushButton_23")
        self.pushButton_24 = QtWidgets.QPushButton(Form)
        self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        self.pushButton_24.setObjectName("pushButton_24")
        self.pushButton_25 = QtWidgets.QPushButton(Form)
        self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        # self.pushButton_inverse = QtWidgets.QPushButton(Form)
        # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        # self.pushButton_inverse.setObjectName("pushButton_inverse")

        # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
        self.pushButton_get = QtWidgets.QPushButton(Form)
        self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
        self.pushButton_get.setObjectName("pushButton_get")

        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        self.graphicsView_3.setObjectName("graphicsView_3")

        self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        self.graphicsView_5.setObjectName("graphicsView_5")
        self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        self.graphicsView_6.setObjectName("graphicsView_6")

        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        self.pushButton_9.clicked.connect(Form.nose_mode)
        self.pushButton_10.clicked.connect(Form.eye_g_mode)
        self.pushButton_11.clicked.connect(Form.l_eye_mode)
        self.pushButton_12.clicked.connect(Form.r_eye_mode)
        self.pushButton_13.clicked.connect(Form.l_brow_mode)
        self.pushButton_14.clicked.connect(Form.r_brow_mode)
        self.pushButton_15.clicked.connect(Form.l_ear_mode)
        self.pushButton_16.clicked.connect(Form.r_ear_mode)
        self.pushButton_17.clicked.connect(Form.mouth_mode)
        self.pushButton_18.clicked.connect(Form.u_lip_mode)
        self.pushButton_19.clicked.connect(Form.l_lip_mode)
        self.pushButton_20.clicked.connect(Form.hair_mode)
        self.pushButton_21.clicked.connect(Form.hat_mode)
        self.pushButton_22.clicked.connect(Form.ear_r_mode)
        self.pushButton_23.clicked.connect(Form.neck_l_mode)
        self.pushButton_24.clicked.connect(Form.neck_mode)
        self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        # self.pushButton_inverse.clicked.connect(Form.inverse)
        # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
        self.pushButton_get.clicked.connect(Form.get_mask)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Skin"))
        self.pushButton_9.setText(_translate("Form", "Nose"))
        self.pushButton_10.setText(_translate("Form", "Eyeglass"))
        self.pushButton_11.setText(_translate("Form", "Left Eye"))
        self.pushButton_12.setText(_translate("Form", "Right Eye"))
        self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        self.pushButton_15.setText(_translate("Form", "Left ear"))
        self.pushButton_16.setText(_translate("Form", "Right ear"))
        self.pushButton_17.setText(_translate("Form", "Mouth"))
        self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        self.pushButton_20.setText(_translate("Form", "Hair"))
        self.pushButton_21.setText(_translate("Form", "Hat"))
        self.pushButton_22.setText(_translate("Form", "Earring"))
        self.pushButton_23.setText(_translate("Form", "Necklace"))
        self.pushButton_24.setText(_translate("Form", "Neck"))
        self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
        self.pushButton_get.setText(_translate("Form", "Get"))

class Ui_Form_Edge2car(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(1800, 1260)
        self.pushButton = QtWidgets.QPushButton(Form)
        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
        self.pushButton.setGeometry(QtCore.QRect(535, 360, 81, 27))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        self.pushButton_10.setObjectName("pushButton_10")
        self.pushButton_11 = QtWidgets.QPushButton(Form)
        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        self.pushButton_11.setObjectName("pushButton_11")
        self.pushButton_12 = QtWidgets.QPushButton(Form)
        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        self.pushButton_12.setObjectName("pushButton_12")
        self.pushButton_13 = QtWidgets.QPushButton(Form)
        self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        self.pushButton_13.setObjectName("pushButton_13")
        self.pushButton_14 = QtWidgets.QPushButton(Form)
        self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        self.pushButton_14.setObjectName("pushButton_14")
        self.pushButton_15 = QtWidgets.QPushButton(Form)
        self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        self.pushButton_15.setObjectName("pushButton_15")
        self.pushButton_16 = QtWidgets.QPushButton(Form)
        self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        self.pushButton_16.setObjectName("pushButton_16")
        self.pushButton_17 = QtWidgets.QPushButton(Form)
        self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        self.pushButton_17.setObjectName("pushButton_17")
        self.pushButton_18 = QtWidgets.QPushButton(Form)
        self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        self.pushButton_18.setObjectName("pushButton_18")
        self.pushButton_19 = QtWidgets.QPushButton(Form)
        self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        self.pushButton_19.setObjectName("pushButton_19")
        self.pushButton_20 = QtWidgets.QPushButton(Form)
        self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        self.pushButton_20.setObjectName("pushButton_20")
        self.pushButton_21 = QtWidgets.QPushButton(Form)
        self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        self.pushButton_21.setObjectName("pushButton_21")
        self.pushButton_22 = QtWidgets.QPushButton(Form)
        self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        self.pushButton_22.setObjectName("pushButton_22")
        self.pushButton_23 = QtWidgets.QPushButton(Form)
        self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        self.pushButton_23.setObjectName("pushButton_23")
        self.pushButton_24 = QtWidgets.QPushButton(Form)
        self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        self.pushButton_24.setObjectName("pushButton_24")
        self.pushButton_25 = QtWidgets.QPushButton(Form)
        self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        # self.pushButton_inverse = QtWidgets.QPushButton(Form)
        # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        # self.pushButton_inverse.setObjectName("pushButton_inverse")

        # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
        self.pushButton_get = QtWidgets.QPushButton(Form)
        self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
        self.pushButton_get.setObjectName("pushButton_get")

        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(20, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(620, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        self.graphicsView_3.setObjectName("graphicsView_3")

        self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        self.graphicsView_5.setObjectName("graphicsView_5")
        self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        self.graphicsView_6.setObjectName("graphicsView_6")

        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        self.pushButton_9.clicked.connect(Form.nose_mode)
        self.pushButton_10.clicked.connect(Form.eye_g_mode)
        self.pushButton_11.clicked.connect(Form.l_eye_mode)
        self.pushButton_12.clicked.connect(Form.r_eye_mode)
        self.pushButton_13.clicked.connect(Form.l_brow_mode)
        self.pushButton_14.clicked.connect(Form.r_brow_mode)
        self.pushButton_15.clicked.connect(Form.l_ear_mode)
        self.pushButton_16.clicked.connect(Form.r_ear_mode)
        self.pushButton_17.clicked.connect(Form.mouth_mode)
        self.pushButton_18.clicked.connect(Form.u_lip_mode)
        self.pushButton_19.clicked.connect(Form.l_lip_mode)
        self.pushButton_20.clicked.connect(Form.hair_mode)
        self.pushButton_21.clicked.connect(Form.hat_mode)
        self.pushButton_22.clicked.connect(Form.ear_r_mode)
        self.pushButton_23.clicked.connect(Form.neck_l_mode)
        self.pushButton_24.clicked.connect(Form.neck_mode)
        self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        # self.pushButton_inverse.clicked.connect(Form.inverse)
        # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
        self.pushButton_get.clicked.connect(Form.get_mask)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Edge2car)"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Skin"))
        self.pushButton_9.setText(_translate("Form", "Nose"))
        self.pushButton_10.setText(_translate("Form", "Eyeglass"))
        self.pushButton_11.setText(_translate("Form", "Left Eye"))
        self.pushButton_12.setText(_translate("Form", "Right Eye"))
        self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        self.pushButton_15.setText(_translate("Form", "Left ear"))
        self.pushButton_16.setText(_translate("Form", "Right ear"))
        self.pushButton_17.setText(_translate("Form", "Mouth"))
        self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        self.pushButton_20.setText(_translate("Form", "Hair"))
        self.pushButton_21.setText(_translate("Form", "Hat"))
        self.pushButton_22.setText(_translate("Form", "Earring"))
        self.pushButton_23.setText(_translate("Form", "Necklace"))
        self.pushButton_24.setText(_translate("Form", "Neck"))
        self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
        self.pushButton_get.setText(_translate("Form", "Get"))


if __name__ == "__main__":
    import sys
    app = QtWidgets.QApplication(sys.argv)
    Form = QtWidgets.QWidget()
    ui = Ui_Form()
    ui.setupUi(Form)
    Form.show()
    sys.exit(app.exec_())
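The Ui_* classes above only build widgets and wire signals: every slot they connect (generateAndReconstruct, open_mask, the *_mode handlers, changeBrushSize, changeYaw, changePitch, inputID, ...) and the integer attributes size, yaw, and pitch read by the sliders must be supplied by the Form widget handed to setupUi. That is why the __main__ block above would fail with a bare QWidget, which has none of these. Below is a minimal sketch of a working host widget, assuming it lives alongside Ui_Form in ui.py; DemoForm, SLOT_NAMES, and the print-only stub bodies are illustrative assumptions, not part of this repository.

import sys
from PyQt5 import QtWidgets

# Slot names that Ui_Form.setupUi connects; the host widget must define all of them.
SLOT_NAMES = [
    "generateAndReconstruct", "open", "open_mask", "clear", "undo", "save_img",
    "bg_mode", "skin_mode", "nose_mode", "eye_g_mode", "l_eye_mode", "r_eye_mode",
    "l_brow_mode", "r_brow_mode", "l_ear_mode", "r_ear_mode", "mouth_mode",
    "u_lip_mode", "l_lip_mode", "hair_mode", "hat_mode", "ear_r_mode",
    "neck_l_mode", "neck_mode", "cloth_mode",
    "changeBrushSize", "changeYaw", "changePitch", "inputID",
]

class DemoForm(QtWidgets.QWidget):  # hypothetical host, not from the repo
    def __init__(self):
        super().__init__()
        # setupUi reads these to initialize the three sliders; setValue(Form.size)
        # requires an int, so the attributes shadow QWidget.size() at the Python level.
        self.size = 30
        self.yaw = 0
        self.pitch = 0
        self.ui = Ui_Form()
        self.ui.setupUi(self)

# Attach print-only stubs for every slot the UI expects (illustrative only).
for name in SLOT_NAMES:
    setattr(DemoForm, name, lambda self, *args, _n=name: print(_n, args))

if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    form = DemoForm()
    form.show()
    sys.exit(app.exec_())

The wiring works the same way for Ui_Form_Seg, Ui_Form_Video, and Ui_Form_Edge2car, which additionally read Form.truncation and connect slots such as changeTruncation, get_mask, inverse, and clear_ws.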
pix2pix3D-main/pix2pix3D-main/applications/demo/ui_qt/ui_clean.py
ADDED
@@ -0,0 +1,797 @@
1 |
+
from PyQt5 import QtCore, QtGui, QtWidgets
|
2 |
+
from PyQt5.QtCore import Qt
|
3 |
+
|
4 |
+
class Ui_Form_Clean(object):
|
5 |
+
def setupUi(self, Form):
|
6 |
+
Form.setObjectName("Form")
|
7 |
+
Form.resize(1800, 660)
|
8 |
+
self.pushButton = QtWidgets.QPushButton(Form)
|
9 |
+
# self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
|
10 |
+
self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
|
11 |
+
self.pushButton.setObjectName("pushButton")
|
12 |
+
self.pushButton_2 = QtWidgets.QPushButton(Form)
|
13 |
+
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        self.pushButton_10.setObjectName("pushButton_10")
        self.pushButton_11 = QtWidgets.QPushButton(Form)
        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        self.pushButton_11.setObjectName("pushButton_11")
        self.pushButton_12 = QtWidgets.QPushButton(Form)
        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        self.pushButton_12.setObjectName("pushButton_12")
        self.pushButton_13 = QtWidgets.QPushButton(Form)
        self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        self.pushButton_13.setObjectName("pushButton_13")
        self.pushButton_14 = QtWidgets.QPushButton(Form)
        self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        self.pushButton_14.setObjectName("pushButton_14")
        self.pushButton_15 = QtWidgets.QPushButton(Form)
        self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        self.pushButton_15.setObjectName("pushButton_15")
        self.pushButton_16 = QtWidgets.QPushButton(Form)
        self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        self.pushButton_16.setObjectName("pushButton_16")
        self.pushButton_17 = QtWidgets.QPushButton(Form)
        self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        self.pushButton_17.setObjectName("pushButton_17")
        self.pushButton_18 = QtWidgets.QPushButton(Form)
        self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        self.pushButton_18.setObjectName("pushButton_18")
        self.pushButton_19 = QtWidgets.QPushButton(Form)
        self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        self.pushButton_19.setObjectName("pushButton_19")
        self.pushButton_20 = QtWidgets.QPushButton(Form)
        self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        self.pushButton_20.setObjectName("pushButton_20")
        self.pushButton_21 = QtWidgets.QPushButton(Form)
        self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        self.pushButton_21.setObjectName("pushButton_21")
        self.pushButton_22 = QtWidgets.QPushButton(Form)
        self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        self.pushButton_22.setObjectName("pushButton_22")
        self.pushButton_23 = QtWidgets.QPushButton(Form)
        self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        self.pushButton_23.setObjectName("pushButton_23")
        self.pushButton_24 = QtWidgets.QPushButton(Form)
        self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        self.pushButton_24.setObjectName("pushButton_24")
        self.pushButton_25 = QtWidgets.QPushButton(Form)
        self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

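        # Camera and brush controls: each slider is seeded from the host
        # widget's current state (Form.size, Form.yaw, Form.pitch,
        # Form.truncation) and laid out along the right edge of the window.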
        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        # self.pushButton_inverse = QtWidgets.QPushButton(Form)
        # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        # self.pushButton_inverse.setObjectName("pushButton_inverse")

        # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
        self.pushButton_get = QtWidgets.QPushButton(Form)
        self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
        self.pushButton_get.setObjectName("pushButton_get")



        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        # self.graphicsView_3.setObjectName("graphicsView_3")

        # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        # self.graphicsView_5.setObjectName("graphicsView_5")
        # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        # self.graphicsView_6.setObjectName("graphicsView_6")

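        # Wire the widgets to the host widget's callbacks; buttons 7-25 switch
        # the segmentation label that the brush paints.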
        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        self.pushButton_9.clicked.connect(Form.nose_mode)
        self.pushButton_10.clicked.connect(Form.eye_g_mode)
        self.pushButton_11.clicked.connect(Form.l_eye_mode)
        self.pushButton_12.clicked.connect(Form.r_eye_mode)
        self.pushButton_13.clicked.connect(Form.l_brow_mode)
        self.pushButton_14.clicked.connect(Form.r_brow_mode)
        self.pushButton_15.clicked.connect(Form.l_ear_mode)
        self.pushButton_16.clicked.connect(Form.r_ear_mode)
        self.pushButton_17.clicked.connect(Form.mouth_mode)
        self.pushButton_18.clicked.connect(Form.u_lip_mode)
        self.pushButton_19.clicked.connect(Form.l_lip_mode)
        self.pushButton_20.clicked.connect(Form.hair_mode)
        self.pushButton_21.clicked.connect(Form.hat_mode)
        self.pushButton_22.clicked.connect(Form.ear_r_mode)
        self.pushButton_23.clicked.connect(Form.neck_l_mode)
        self.pushButton_24.clicked.connect(Form.neck_mode)
        self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        # self.pushButton_inverse.clicked.connect(Form.inverse)
        # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
        self.pushButton_get.clicked.connect(Form.get_mask)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Skin"))
        self.pushButton_9.setText(_translate("Form", "Nose"))
        self.pushButton_10.setText(_translate("Form", "Eyeglass"))
        self.pushButton_11.setText(_translate("Form", "Left Eye"))
        self.pushButton_12.setText(_translate("Form", "Right Eye"))
        self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        self.pushButton_15.setText(_translate("Form", "Left ear"))
        self.pushButton_16.setText(_translate("Form", "Right ear"))
        self.pushButton_17.setText(_translate("Form", "Mouth"))
        self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        self.pushButton_20.setText(_translate("Form", "Hair"))
        self.pushButton_21.setText(_translate("Form", "Hat"))
        self.pushButton_22.setText(_translate("Form", "Earring"))
        self.pushButton_23.setText(_translate("Form", "Necklace"))
        self.pushButton_24.setText(_translate("Form", "Neck"))
        self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
        self.pushButton_get.setText(_translate("Form", "Get"))


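# Edge2car variant of the UI: same auto-generated layout as Ui_Form above, but
# only the "BackGround" and "Edge" drawing modes stay active, and a Roll slider
# is added for the car camera.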
class Ui_Form_Edge2car(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(1800, 660)
        self.pushButton = QtWidgets.QPushButton(Form)
        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
        self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        # self.pushButton_9 = QtWidgets.QPushButton(Form)
        # self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        # self.pushButton_9.setObjectName("pushButton_9")
        # self.pushButton_10 = QtWidgets.QPushButton(Form)
        # self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        # self.pushButton_10.setObjectName("pushButton_10")
        # self.pushButton_11 = QtWidgets.QPushButton(Form)
        # self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        # self.pushButton_11.setObjectName("pushButton_11")
        # self.pushButton_12 = QtWidgets.QPushButton(Form)
        # self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        # self.pushButton_12.setObjectName("pushButton_12")
        # self.pushButton_13 = QtWidgets.QPushButton(Form)
        # self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        # self.pushButton_13.setObjectName("pushButton_13")
        # self.pushButton_14 = QtWidgets.QPushButton(Form)
        # self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        # self.pushButton_14.setObjectName("pushButton_14")
        # self.pushButton_15 = QtWidgets.QPushButton(Form)
        # self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        # self.pushButton_15.setObjectName("pushButton_15")
        # self.pushButton_16 = QtWidgets.QPushButton(Form)
        # self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        # self.pushButton_16.setObjectName("pushButton_16")
        # self.pushButton_17 = QtWidgets.QPushButton(Form)
        # self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        # self.pushButton_17.setObjectName("pushButton_17")
        # self.pushButton_18 = QtWidgets.QPushButton(Form)
        # self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        # self.pushButton_18.setObjectName("pushButton_18")
        # self.pushButton_19 = QtWidgets.QPushButton(Form)
        # self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        # self.pushButton_19.setObjectName("pushButton_19")
        # self.pushButton_20 = QtWidgets.QPushButton(Form)
        # self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        # self.pushButton_20.setObjectName("pushButton_20")
        # self.pushButton_21 = QtWidgets.QPushButton(Form)
        # self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        # self.pushButton_21.setObjectName("pushButton_21")
        # self.pushButton_22 = QtWidgets.QPushButton(Form)
        # self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        # self.pushButton_22.setObjectName("pushButton_22")
        # self.pushButton_23 = QtWidgets.QPushButton(Form)
        # self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        # self.pushButton_23.setObjectName("pushButton_23")
        # self.pushButton_24 = QtWidgets.QPushButton(Form)
        # self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        # self.pushButton_24.setObjectName("pushButton_24")
        # self.pushButton_25 = QtWidgets.QPushButton(Form)
        # self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        # self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_rollselect = QtWidgets.QSlider(Form)
        self.slider_rollselect.setRange(0,100)
        self.slider_rollselect.setOrientation(Qt.Horizontal)
        self.slider_rollselect.setValue(Form.roll)
        self.slider_rollselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_rollselect = QtWidgets.QLabel(Form)
        self.label_rollselect.setText("Roll")
        self.label_rollselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))


        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 130, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 130, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        # self.pushButton_inverse = QtWidgets.QPushButton(Form)
        # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        # self.pushButton_inverse.setObjectName("pushButton_inverse")

        # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
        self.pushButton_get = QtWidgets.QPushButton(Form)
        self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
        self.pushButton_get.setObjectName("pushButton_get")



        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        # self.graphicsView_3.setObjectName("graphicsView_3")

        # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        # self.graphicsView_5.setObjectName("graphicsView_5")
        # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        # self.graphicsView_6.setObjectName("graphicsView_6")


        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        # self.pushButton_9.clicked.connect(Form.nose_mode)
        # self.pushButton_10.clicked.connect(Form.eye_g_mode)
        # self.pushButton_11.clicked.connect(Form.l_eye_mode)
        # self.pushButton_12.clicked.connect(Form.r_eye_mode)
        # self.pushButton_13.clicked.connect(Form.l_brow_mode)
        # self.pushButton_14.clicked.connect(Form.r_brow_mode)
        # self.pushButton_15.clicked.connect(Form.l_ear_mode)
        # self.pushButton_16.clicked.connect(Form.r_ear_mode)
        # self.pushButton_17.clicked.connect(Form.mouth_mode)
        # self.pushButton_18.clicked.connect(Form.u_lip_mode)
        # self.pushButton_19.clicked.connect(Form.l_lip_mode)
        # self.pushButton_20.clicked.connect(Form.hair_mode)
        # self.pushButton_21.clicked.connect(Form.hat_mode)
        # self.pushButton_22.clicked.connect(Form.ear_r_mode)
        # self.pushButton_23.clicked.connect(Form.neck_l_mode)
        # self.pushButton_24.clicked.connect(Form.neck_mode)
        # self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_rollselect.valueChanged.connect(Form.changeRoll)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        # self.pushButton_inverse.clicked.connect(Form.inverse)
        # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
        self.pushButton_get.clicked.connect(Form.get_mask)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Edge2car)"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Edge"))
        # self.pushButton_9.setText(_translate("Form", "Nose"))
        # self.pushButton_10.setText(_translate("Form", "Eyeglass"))
        # self.pushButton_11.setText(_translate("Form", "Left Eye"))
        # self.pushButton_12.setText(_translate("Form", "Right Eye"))
        # self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        # self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        # self.pushButton_15.setText(_translate("Form", "Left ear"))
        # self.pushButton_16.setText(_translate("Form", "Right ear"))
        # self.pushButton_17.setText(_translate("Form", "Mouth"))
        # self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        # self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        # self.pushButton_20.setText(_translate("Form", "Hair"))
        # self.pushButton_21.setText(_translate("Form", "Hat"))
        # self.pushButton_22.setText(_translate("Form", "Earring"))
        # self.pushButton_23.setText(_translate("Form", "Necklace"))
        # self.pushButton_24.setText(_translate("Form", "Neck"))
        # self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
        self.pushButton_get.setText(_translate("Form", "Get"))


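# Seg2cat variant: cat segmentation labels (Face, Ear, Mouth, Eyes, Whiskers)
# with the same slider set as the Edge2car form, including Roll.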
class Ui_Form_Seg2cat(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(1800, 660)
        self.pushButton = QtWidgets.QPushButton(Form)
        # self.pushButton.setGeometry(QtCore.QRect(1160, 360, 81, 27))
        self.pushButton.setGeometry(QtCore.QRect(685, 360, 81, 27))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(10, 10, 97, 27))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 40, 97, 27))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(130, 10, 97, 27))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(130, 40, 97, 27))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 10, 97, 27))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(250, 40, 97, 27))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(450, 10, 97, 27))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(450, 40, 97, 27))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(570, 10, 97, 27))
        self.pushButton_10.setObjectName("pushButton_10")
        self.pushButton_11 = QtWidgets.QPushButton(Form)
        self.pushButton_11.setGeometry(QtCore.QRect(570, 40, 97, 27))
        self.pushButton_11.setObjectName("pushButton_11")
        self.pushButton_12 = QtWidgets.QPushButton(Form)
        self.pushButton_12.setGeometry(QtCore.QRect(690, 10, 97, 27))
        self.pushButton_12.setObjectName("pushButton_12")
        # self.pushButton_13 = QtWidgets.QPushButton(Form)
        # self.pushButton_13.setGeometry(QtCore.QRect(690, 40, 97, 27))
        # self.pushButton_13.setObjectName("pushButton_13")
        # self.pushButton_14 = QtWidgets.QPushButton(Form)
        # self.pushButton_14.setGeometry(QtCore.QRect(810, 10, 97, 27))
        # self.pushButton_14.setObjectName("pushButton_14")
        # self.pushButton_15 = QtWidgets.QPushButton(Form)
        # self.pushButton_15.setGeometry(QtCore.QRect(810, 40, 97, 27))
        # self.pushButton_15.setObjectName("pushButton_15")
        # self.pushButton_16 = QtWidgets.QPushButton(Form)
        # self.pushButton_16.setGeometry(QtCore.QRect(930, 10, 97, 27))
        # self.pushButton_16.setObjectName("pushButton_16")
        # self.pushButton_17 = QtWidgets.QPushButton(Form)
        # self.pushButton_17.setGeometry(QtCore.QRect(930, 40, 97, 27))
        # self.pushButton_17.setObjectName("pushButton_17")
        # self.pushButton_18 = QtWidgets.QPushButton(Form)
        # self.pushButton_18.setGeometry(QtCore.QRect(1050, 10, 97, 27))
        # self.pushButton_18.setObjectName("pushButton_18")
        # self.pushButton_19 = QtWidgets.QPushButton(Form)
        # self.pushButton_19.setGeometry(QtCore.QRect(1050, 40, 97, 27))
        # self.pushButton_19.setObjectName("pushButton_19")
        # self.pushButton_20 = QtWidgets.QPushButton(Form)
        # self.pushButton_20.setGeometry(QtCore.QRect(1170, 10, 97, 27))
        # self.pushButton_20.setObjectName("pushButton_20")
        # self.pushButton_21 = QtWidgets.QPushButton(Form)
        # self.pushButton_21.setGeometry(QtCore.QRect(1170, 40, 97, 27))
        # self.pushButton_21.setObjectName("pushButton_21")
        # self.pushButton_22 = QtWidgets.QPushButton(Form)
        # self.pushButton_22.setGeometry(QtCore.QRect(1290, 10, 97, 27))
        # self.pushButton_22.setObjectName("pushButton_22")
        # self.pushButton_23 = QtWidgets.QPushButton(Form)
        # self.pushButton_23.setGeometry(QtCore.QRect(1290, 40, 97, 27))
        # self.pushButton_23.setObjectName("pushButton_23")
        # self.pushButton_24 = QtWidgets.QPushButton(Form)
        # self.pushButton_24.setGeometry(QtCore.QRect(1410, 10, 97, 27))
        # self.pushButton_24.setObjectName("pushButton_24")
        # self.pushButton_25 = QtWidgets.QPushButton(Form)
        # self.pushButton_25.setGeometry(QtCore.QRect(1410, 40, 97, 27))
        # self.pushButton_25.setObjectName("pushButton_25")
        # self.pushButton_26 = QtWidgets.QPushButton(Form)
        # self.pushButton_26.setGeometry(QtCore.QRect(1530, 10, 97, 27))
        # self.pushButton_26.setObjectName("pushButton_26")
        # self.pushButton_27 = QtWidgets.QPushButton(Form)
        # self.pushButton_27.setGeometry(QtCore.QRect(1530, 40, 97, 27))
        # self.pushButton_27.setObjectName("pushButton_27")

        self.slider_sizeselect = QtWidgets.QSlider(Form)
        self.slider_sizeselect.setRange(10,70)
        self.slider_sizeselect.setOrientation(Qt.Horizontal)
        self.slider_sizeselect.setValue(Form.size)
        self.slider_sizeselect.setGeometry(QtCore.QRect(1530, 100, 97, 27))

        self.label_sizeselect = QtWidgets.QLabel(Form)
        self.label_sizeselect.setText("Brush Size")
        self.label_sizeselect.setGeometry(QtCore.QRect(1630, 100, 97, 27))

        self.slider_yawselect = QtWidgets.QSlider(Form)
        self.slider_yawselect.setRange(-100,100)
        self.slider_yawselect.setOrientation(Qt.Horizontal)
        self.slider_yawselect.setValue(Form.yaw)
        self.slider_yawselect.setGeometry(QtCore.QRect(1530, 10, 97, 27))

        self.label_yawselect = QtWidgets.QLabel(Form)
        self.label_yawselect.setText("Yaw")
        self.label_yawselect.setGeometry(QtCore.QRect(1630, 10, 97, 27))

        self.slider_pitchselect = QtWidgets.QSlider(Form)
        self.slider_pitchselect.setRange(-100,100)
        self.slider_pitchselect.setOrientation(Qt.Horizontal)
        self.slider_pitchselect.setValue(Form.pitch)
        self.slider_pitchselect.setGeometry(QtCore.QRect(1530, 40, 97, 27))

        self.label_pitchselect = QtWidgets.QLabel(Form)
        self.label_pitchselect.setText("Pitch")
        self.label_pitchselect.setGeometry(QtCore.QRect(1630, 40, 97, 27))

        self.slider_rollselect = QtWidgets.QSlider(Form)
        self.slider_rollselect.setRange(-100,100)
        self.slider_rollselect.setOrientation(Qt.Horizontal)
        self.slider_rollselect.setValue(Form.roll)
        self.slider_rollselect.setGeometry(QtCore.QRect(1530, 70, 97, 27))

        self.label_rollselect = QtWidgets.QLabel(Form)
        self.label_rollselect.setText("Roll")
        self.label_rollselect.setGeometry(QtCore.QRect(1630, 70, 97, 27))


        self.slider_truncation = QtWidgets.QSlider(Form)
        self.slider_truncation.setRange(0,100)
        self.slider_truncation.setOrientation(Qt.Horizontal)
        self.slider_truncation.setValue(Form.truncation)
        self.slider_truncation.setGeometry(QtCore.QRect(1530, 130, 97, 27))

        self.label_truncation = QtWidgets.QLabel(Form)
        self.label_truncation.setText("Truncation")
        self.label_truncation.setGeometry(QtCore.QRect(1630, 130, 97, 27))

        self.text_inputID = QtWidgets.QTextEdit(Form)
        self.text_inputID.setGeometry(QtCore.QRect(10, 70, 40, 27))
        self.text_inputID.setObjectName("text_inputID")

        self.pushButton_inputID = QtWidgets.QPushButton(Form)
        self.pushButton_inputID.setGeometry(QtCore.QRect(60, 70, 60, 27))
        self.pushButton_inputID.setObjectName("pushButton_inputID")

        self.text_seed = QtWidgets.QTextEdit(Form)
        self.text_seed.setGeometry(QtCore.QRect(140, 70, 40, 27))
        self.text_seed.setObjectName("text_seed")
        self.text_seed.setPlainText("0")

        self.label_seed = QtWidgets.QLabel(Form)
        self.label_seed.setText("Seed")
        self.label_seed.setGeometry(QtCore.QRect(190, 70, 97, 27))

        # self.pushButton_inverse = QtWidgets.QPushButton(Form)
        # self.pushButton_inverse.setGeometry(QtCore.QRect(535, 400, 81, 27))
        # self.pushButton_inverse.setObjectName("pushButton_inverse")

        # self.pushButton_clear_ws = QtWidgets.QPushButton(Form)
        # self.pushButton_clear_ws.setGeometry(QtCore.QRect(535, 430, 81, 27))
        # self.pushButton_clear_ws.setObjectName("pushButton_clear_ws")
        self.pushButton_get = QtWidgets.QPushButton(Form)
        self.pushButton_get.setGeometry(QtCore.QRect(1500, 680 + 512 + 10, 81, 27))
        self.pushButton_get.setObjectName("pushButton_get")



        self.graphicsView = QtWidgets.QGraphicsView(Form)
        self.graphicsView.setGeometry(QtCore.QRect(120, 120, 512, 512))
        self.graphicsView.setObjectName("graphicsView")
        self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
        self.graphicsView_2.setGeometry(QtCore.QRect(820, 120, 512, 512))
        self.graphicsView_2.setObjectName("graphicsView_2")
        # self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_3.setGeometry(QtCore.QRect(1260, 120, 512, 512))
        # self.graphicsView_3.setObjectName("graphicsView_3")

        # self.graphicsView_5 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_5.setGeometry(QtCore.QRect(620, 680, 512, 512))
        # self.graphicsView_5.setObjectName("graphicsView_5")
        # self.graphicsView_6 = QtWidgets.QGraphicsView(Form)
        # self.graphicsView_6.setGeometry(QtCore.QRect(1260, 680, 512, 512))
        # self.graphicsView_6.setObjectName("graphicsView_6")


        self.retranslateUi(Form)
        self.pushButton.clicked.connect(Form.generateAndReconstruct)
        self.pushButton_2.clicked.connect(Form.open)
        self.pushButton_3.clicked.connect(Form.open_mask)
        self.pushButton_4.clicked.connect(Form.clear)
        self.pushButton_5.clicked.connect(Form.undo)
        self.pushButton_6.clicked.connect(Form.save_img)
        self.pushButton_7.clicked.connect(Form.bg_mode)
        self.pushButton_8.clicked.connect(Form.skin_mode)
        self.pushButton_9.clicked.connect(Form.nose_mode)
        self.pushButton_10.clicked.connect(Form.eye_g_mode)
        self.pushButton_11.clicked.connect(Form.l_eye_mode)
        self.pushButton_12.clicked.connect(Form.r_eye_mode)
        # self.pushButton_13.clicked.connect(Form.l_brow_mode)
        # self.pushButton_14.clicked.connect(Form.r_brow_mode)
        # self.pushButton_15.clicked.connect(Form.l_ear_mode)
        # self.pushButton_16.clicked.connect(Form.r_ear_mode)
        # self.pushButton_17.clicked.connect(Form.mouth_mode)
        # self.pushButton_18.clicked.connect(Form.u_lip_mode)
        # self.pushButton_19.clicked.connect(Form.l_lip_mode)
        # self.pushButton_20.clicked.connect(Form.hair_mode)
        # self.pushButton_21.clicked.connect(Form.hat_mode)
        # self.pushButton_22.clicked.connect(Form.ear_r_mode)
        # self.pushButton_23.clicked.connect(Form.neck_l_mode)
        # self.pushButton_24.clicked.connect(Form.neck_mode)
        # self.pushButton_25.clicked.connect(Form.cloth_mode)
        # self.pushButton_26.clicked.connect(Form.increase)
        # self.pushButton_27.clicked.connect(Form.decrease)

        self.slider_sizeselect.valueChanged.connect(Form.changeBrushSize)
        self.slider_yawselect.valueChanged.connect(Form.changeYaw)
        self.slider_pitchselect.valueChanged.connect(Form.changePitch)
        self.slider_rollselect.valueChanged.connect(Form.changeRoll)
        self.slider_truncation.valueChanged.connect(Form.changeTruncation)

        self.pushButton_inputID.clicked.connect(Form.inputID)

        # self.pushButton_inverse.clicked.connect(Form.inverse)
        # self.pushButton_clear_ws.clicked.connect(Form.clear_ws)
        self.pushButton_get.clicked.connect(Form.get_mask)

        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "3D-aware Conditional Image Synthesis (Seg2cat)"))
        self.pushButton.setText(_translate("Form", "Generate"))
        self.pushButton_2.setText(_translate("Form", "Open Image"))
        self.pushButton_3.setText(_translate("Form", "Open Mask"))
        self.pushButton_4.setText(_translate("Form", "Clear"))
        self.pushButton_5.setText(_translate("Form", "Undo"))
        self.pushButton_6.setText(_translate("Form", "Save Image"))
        self.pushButton_7.setText(_translate("Form", "BackGround"))
        self.pushButton_8.setText(_translate("Form", "Face"))
        self.pushButton_9.setText(_translate("Form", "Ear"))
        self.pushButton_10.setText(_translate("Form", "Mouth"))
        self.pushButton_11.setText(_translate("Form", "Eyes"))
        self.pushButton_12.setText(_translate("Form", "Whiskers"))
        # self.pushButton_13.setText(_translate("Form", "Left Eyebrow"))
        # self.pushButton_14.setText(_translate("Form", "Right Eyebrow"))
        # self.pushButton_15.setText(_translate("Form", "Left ear"))
        # self.pushButton_16.setText(_translate("Form", "Right ear"))
        # self.pushButton_17.setText(_translate("Form", "Mouth"))
        # self.pushButton_18.setText(_translate("Form", "Upper Lip"))
        # self.pushButton_19.setText(_translate("Form", "Lower Lip"))
        # self.pushButton_20.setText(_translate("Form", "Hair"))
        # self.pushButton_21.setText(_translate("Form", "Hat"))
        # self.pushButton_22.setText(_translate("Form", "Earring"))
        # self.pushButton_23.setText(_translate("Form", "Necklace"))
        # self.pushButton_24.setText(_translate("Form", "Neck"))
        # self.pushButton_25.setText(_translate("Form", "Cloth"))
        # self.pushButton_26.setText(_translate("Form", "+"))
        # self.pushButton_27.setText(_translate("Form", "-"))
        self.pushButton_inputID.setText(_translate("Form", "Input ID"))
        # self.pushButton_inverse.setText(_translate("Form", "Inverse"))
        # self.pushButton_clear_ws.setText(_translate("Form", "Clear ws"))
        self.pushButton_get.setText(_translate("Form", "Get"))

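# Standalone layout preview. Note that setupUi reads Form.size, Form.yaw,
# Form.pitch and Form.truncation, so the host widget is expected to provide
# those attributes; a bare QWidget here is only enough to inspect the layout.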
if __name__ == "__main__":
    import sys
    app = QtWidgets.QApplication(sys.argv)
    Form = QtWidgets.QWidget()
    ui = Ui_Form()
    ui.setupUi(Form)
    Form.show()
    sys.exit(app.exec_())

pix2pix3D-main/pix2pix3D-main/applications/edge2cat.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
pix2pix3D-main/pix2pix3D-main/applications/extract_mesh.py
ADDED
@@ -0,0 +1,267 @@
import sys
sys.path.append('./')

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm


import legacy
from camera_utils import LookAtPoseSampler

from matplotlib import pyplot as plt

from pathlib import Path

import json

from training.utils import color_mask, color_list

from tqdm import tqdm

import imageio

import argparse

import trimesh
import pyrender
import mcubes

os.environ["PYOPENGL_PLATFORM"] = "egl"

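# Build kwargs for the conditional dataset ('seg' or 'edge') and instantiate it
# once so that resolution, label usage, and dataset size are recorded
# explicitly in the returned EasyDict.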
def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
    try:
        if data_type =='seg':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        elif data_type == 'edge':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        else:
            raise click.ClickException(f'Unknown data_type: {data_type}')
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

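# Sample the generator's density (sigma) field on a resolution^3 grid inside
# the box-warp volume, evaluated in block_resolution-sized chunks to bound GPU
# memory. Returns the grid and the half-extent of the sampled cube.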
def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    # return numpy array of forwarded sigma value
    # bound = (nerf.rendering_kwargs['ray_end'] - nerf.rendering_kwargs['ray_start']) * 0.5
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    X = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    for xi, xs in enumerate(X):
        for yi, ys in enumerate(X):
            for zi, zs in enumerate(X):
                xx, yy, zz = torch.meshgrid(xs, ys, zs)
                pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)  # B, H, H, H, C
                block_shape = [1, len(xs), len(ys), len(zs)]
                out = nerf.sample_mixed(pts.reshape(1,-1,3), None, ws=styles, noise_mode='const')
                feat_out, sigma_out = out['rgb'], out['sigma']
                sigma_np[xi * block_resolution: xi * block_resolution + len(xs), \
                         yi * block_resolution: yi * block_resolution + len(ys), \
                         zi * block_resolution: zi * block_resolution + len(zs)] = sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()
                # print(feat_out.shape)

    return sigma_np, bound


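# Run marching cubes on the sampled sigma grid at the given density threshold,
# then rescale the vertices from grid indices to world coordinates.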
def extract_geometry(nerf, styles, resolution, threshold):

    # print('threshold: {}'.format(threshold))
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)
    # vertices, faces, normals, values = skimage.measure.marching_cubes(
    #     u, level=10
    # )
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([ bound,  bound,  bound])

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces


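# Pipeline: load the checkpoint, build the conditioning label map (from --input
# or a dataset --input_id), map it to latent ws, extract a mesh via marching
# cubes, color the vertices with the predicted semantics for the seg configs,
# and render a turntable GIF with pyrender.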
def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
    parser.add_argument('--network', help='Path to the network pickle file', required=True)
    parser.add_argument('--outdir', help='Directory to save the output', required=True)

    parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=False)
    parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
    parser.add_argument('--input', help='input label map', required=False)
    parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
    args = parser.parse_args()
    device = 'cuda'

    # Load the network
    with dnnlib.util.open_url(args.network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        neural_rendering_resolution = 128
        pitch_range, yaw_range = 0.25, 0.35
        data_type = 'seg'
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647 # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    elif args.cfg == 'edge2car':
        neural_rendering_resolution = 64
        pitch_range, yaw_range = np.pi / 2, np.pi
        data_type= 'edge'

        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 1.7074 # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    else:
        print('Invalid cfg')
        return

    save_dir = Path(args.outdir)

    # Load the input label map
    if args.input is not None:
        input_label = PIL.Image.open(args.input)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = np.array(input_label).astype(np.uint8)
            input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)

            # Save the visualized input label map
            PIL.Image.fromarray(color_mask(input_label[0,0].cpu().numpy()).astype(np.uint8)).save(save_dir / f'{args.cfg}_input.png')
        elif args.cfg == 'edge2car':
            input_label = np.array(input_label).astype(np.float32)[..., 0]
            input_label = -(torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
        input_pose = forward_pose.to(device)

    elif args.input_id is not None:
        if args.cfg == 'seg2cat':
            data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
            mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
        elif args.cfg == 'edge2car':
            data_path = Path(args.data_dir) / 'cars_128.zip'
            mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
        elif args.cfg == 'seg2face':
            data_path = Path(args.data_dir) / 'celebamask_test.zip'
            mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'

        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
        dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
        batch = dataset[args.input_id]

        save_dir = Path(args.outdir)

        # Save the input label map
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
        elif args.cfg == 'edge2car':
            PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')

        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car':
            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)

    # Generate videos
    z = torch.from_numpy(np.random.RandomState(int(0)).randn(1, G.z_dim).astype('float32')).to(device)

    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})

        mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=512, threshold=50.))

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':

        verts_np = np.array(mesh_trimesh.vertices)
        colors = torch.zeros((verts_np.shape[0], 3), device=device)
        semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
        samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()

        head = 0
        max_batch = 10000000
        with tqdm(total = verts_np.shape[0]) as pbar:
            with torch.no_grad():
                while head < verts_np.shape[0]:
                    torch.manual_seed(0)
                    out = G.sample_mixed(samples_color[:, head:head+max_batch], None, ws, truncation_psi=1, noise_mode='const')
                    # sigma = out['sigma']
                    colors[head:head+max_batch, :] = out['rgb'][0,:,:3]
                    seg = out['rgb'][0, :, 32:32+6]
                    semantic_colors[head:head+max_batch, :] = seg
                    # semantics[:, head:head+max_batch] = out['semantic']
                    head += max_batch
                    pbar.update(max_batch)

        semantic_colors = torch.tensor(color_list)[torch.argmax(semantic_colors, dim=-1)]

        mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

        # Save mesh.
        mesh_trimesh.export(os.path.join(save_dir, f'semantic_mesh.ply'))
    elif args.cfg == 'edge2car':
        # Save mesh.
        mesh_trimesh.export(os.path.join(save_dir, f'{args.cfg}_mesh.ply'))

    mesh = pyrender.Mesh.from_trimesh(mesh_trimesh)
    light = pyrender.SpotLight(color=np.ones(3), intensity=3.0,
                               innerConeAngle=np.pi/4)
    r = pyrender.OffscreenRenderer(512, 512)
    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        camera = pyrender.OrthographicCamera(xmag=0.3, ymag=0.3)

    elif args.cfg == 'edge2car':
        camera = pyrender.OrthographicCamera(xmag=0.6, ymag=0.6)


    frames_mesh = []
    num_frames = 120

    for frame_idx in tqdm(range(num_frames)):
        scene = pyrender.Scene()
        scene.add(mesh)

        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            camera_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                   3.14/2 -0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                   torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=1, device=device)
        elif args.cfg == 'edge2car':
            camera_pose = LookAtPoseSampler.sample(-3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                   3.14/2 -0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                   torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=1.2, device=device)
        camera_pose = camera_pose.reshape(4, 4).cpu().numpy().copy()
        camera_pose[:, 1] = -camera_pose[:, 1]
        camera_pose[:, 2] = -camera_pose[:, 2]

        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)
        color, depth = r.render(scene)
        frames_mesh.append(color)

    imageio.mimsave(os.path.join(save_dir, f'rendered_mesh.gif'), frames_mesh, fps=60)
    r.delete()



if __name__ == '__main__':
    main()
pix2pix3D-main/pix2pix3D-main/applications/generate_samples.py
ADDED
@@ -0,0 +1,128 @@
import sys
sys.path.append('./')

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm


import legacy

from matplotlib import pyplot as plt

from pathlib import Path

import json

from training.utils import color_mask, color_list

from tqdm import tqdm

import argparse

def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
    try:
        if data_type =='seg':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        elif data_type == 'edge':
            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
            return dataset_kwargs, dataset_obj.name
        else:
            raise click.ClickException(f'Unknown data_type: {data_type}')
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

def main():
|
52 |
+
# Parse arguments
|
53 |
+
parser = argparse.ArgumentParser(description='Generate samples from a trained model')
|
54 |
+
parser.add_argument('--network', help='Path to the network pickle file', required=True)
|
55 |
+
parser.add_argument('--outdir', help='Directory to save the output', required=True)
|
56 |
+
# Define an argument of a list of random seeds
|
57 |
+
parser.add_argument('--random_seed', help='Random seed', nargs="+", type=int)
|
58 |
+
|
59 |
+
parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=True)
|
60 |
+
parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
|
61 |
+
parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
|
62 |
+
args = parser.parse_args()
|
63 |
+
device = 'cuda'
|
64 |
+
|
65 |
+
if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
|
66 |
+
neural_rendering_resolution = 128
|
67 |
+
data_type = 'seg'
|
68 |
+
elif args.cfg == 'edge2car':
|
69 |
+
neural_rendering_resolution = 64
|
70 |
+
data_type= 'edge'
|
71 |
+
else:
|
72 |
+
print('Invalid cfg')
|
73 |
+
return
|
74 |
+
|
75 |
+
# Load the network
|
76 |
+
with dnnlib.util.open_url(args.network) as f:
|
77 |
+
G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
|
78 |
+
|
79 |
+
# Load the input label map
|
80 |
+
# Initialize dataset.
|
81 |
+
if args.cfg == 'seg2cat':
|
82 |
+
data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
|
83 |
+
mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
|
84 |
+
elif args.cfg == 'edge2car':
|
85 |
+
data_path = Path(args.data_dir) / 'cars_128.zip'
|
86 |
+
mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
|
87 |
+
elif args.cfg == 'seg2face':
|
88 |
+
data_path = Path(args.data_dir) / 'celebamask_test.zip'
|
89 |
+
mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'
|
90 |
+
|
91 |
+
dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
|
92 |
+
dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
|
93 |
+
batch = dataset[args.input_id]
|
94 |
+
|
95 |
+
save_dir = Path(args.outdir)
|
96 |
+
|
97 |
+
# Save the input label map
|
98 |
+
if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
|
99 |
+
PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
|
100 |
+
elif args.cfg == 'edge2car':
|
101 |
+
PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
|
102 |
+
|
103 |
+
# Generate samples
|
104 |
+
for seed in args.random_seed:
|
105 |
+
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, G.z_dim).astype('float32')).to(device)
|
106 |
+
input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
|
107 |
+
if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
|
108 |
+
input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
|
109 |
+
elif args.cfg == 'edge2car':
|
110 |
+
input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)
|
111 |
+
|
112 |
+
with torch.no_grad():
|
113 |
+
ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
|
114 |
+
out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
|
115 |
+
|
116 |
+
image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
|
117 |
+
if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
|
118 |
+
image_label = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)
|
119 |
+
elif args.cfg == 'edge2car':
|
120 |
+
image_label = ((out['semantic'][0].cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8)[0]
|
121 |
+
|
122 |
+
PIL.Image.fromarray(image_color).save(save_dir / f'{args.cfg}_{args.input_id}_{seed}_color.png')
|
123 |
+
PIL.Image.fromarray(image_label).save(save_dir / f'{args.cfg}_{args.input_id}_{seed}_label.png')
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == '__main__':
|
128 |
+
main()
|
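As a usage sketch, the arguments defined above translate into an invocation like the following (the checkpoint filename is a placeholder; substitute whichever pickle download_models.sh places under checkpoints/):

    python applications/generate_samples.py \
        --network checkpoints/pix2pix3d_seg2cat.pkl \
        --outdir out --random_seed 1 2 3 --input_id 1666 --cfg seg2cat

This writes seg2cat_1666_input.png plus one _color.png / _label.png pair per seed into out/, matching the asset filenames added elsewhere in this commit.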
pix2pix3D-main/pix2pix3D-main/applications/generate_video.py
ADDED
@@ -0,0 +1,220 @@
+import sys
+sys.path.append('./')
+
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+from tqdm import tqdm
+
+
+import legacy
+from camera_utils import LookAtPoseSampler
+
+from matplotlib import pyplot as plt
+
+from pathlib import Path
+
+import json
+
+from training.utils import color_mask, color_list
+
+from tqdm import tqdm
+
+import imageio
+
+import argparse
+
+def init_conditional_dataset_kwargs(data, mask_data, data_type, resolution=None):
+    try:
+        if data_type == 'seg':
+            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False, resolution=resolution)
+            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
+            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
+            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
+            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
+            return dataset_kwargs, dataset_obj.name
+        elif data_type == 'edge':
+            dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageEdgeFolderDataset', path=data, mask_path=mask_data, data_type=data_type, use_labels=True, max_size=None, xflip=False)
+            dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
+            dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
+            dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
+            dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
+            return dataset_kwargs, dataset_obj.name
+        else:
+            raise click.ClickException(f'Unknown data_type: {data_type}')
+    except IOError as err:
+        raise click.ClickException(f'--data: {err}')
+
+def render_video(G, ws, intrinsics, num_frames = 120, pitch_range = 0.25, yaw_range = 0.35, neural_rendering_resolution = 128, device='cuda'):
+    frames, frames_label = [], []
+
+    for frame_idx in tqdm(range(num_frames)):
+        cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
+                                                  3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
+                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+        with torch.no_grad():
+            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
+        # frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
+        image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
+        frames.append(image_color)
+        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))
+
+    return frames, frames_label
+
+def render_video_edge(G, ws, intrinsics, num_frames = 120, pitch_range = np.pi / 2, yaw_range = np.pi, neural_rendering_resolution = 64, device='cuda'):
+    frames, frames_label = [], []
+
+    for frame_idx in tqdm(range(num_frames)):
+        cam2world_pose = LookAtPoseSampler.sample(-3.14/2 + yaw_range * np.cos(2 * 3.14 * frame_idx / num_frames),
+                                                  3.14/2 - 0.05 + pitch_range * np.sin(2 * 3.14 * frame_idx / num_frames),
+                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+        with torch.no_grad():
+            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
+        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
+        frames_label.append(((out['semantic'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)[0])
+
+    return frames, frames_label
+
+def render_video_edge2cat(G, ws, intrinsics, num_frames = 120, pitch_range = np.pi / 2, yaw_range = np.pi, neural_rendering_resolution = 64, device='cuda'):
+    frames, frames_label = [], []
+
+    for frame_idx in tqdm(range(num_frames)):
+        cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
+                                                  3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
+                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+        with torch.no_grad():
+            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
+        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
+        frames_label.append(((out['semantic'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)[0])
+
+    return frames, frames_label
+
+def main():
+    # Parse arguments
+    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
+    parser.add_argument('--network', help='Path to the network pickle file', required=True)
+    parser.add_argument('--outdir', help='Directory to save the output', required=True)
+    # Define an argument of a list of random seeds
+    parser.add_argument('--random_seed', help='Random seed', nargs="+", type=int)
+
+    parser.add_argument('--input_id', type=int, default=0, help='Input label map id', required=False)
+    parser.add_argument('--data_dir', default='data/', help='Directory to the data', required=False)
+    parser.add_argument('--input', help='input label map', required=False)
+    parser.add_argument('--cfg', help='Base Configuration: seg2face, seg2cat, edge2car', required=True)
+    args = parser.parse_args()
+    device = 'cuda'
+
+    # Load the network
+    with dnnlib.util.open_url(args.network) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
+
+    if args.cfg == 'seg2cat' or args.cfg == 'seg2face' or args.cfg == 'edge2cat':
+        neural_rendering_resolution = 128
+        pitch_range, yaw_range = 0.25, 0.35
+        data_type = 'seg'
+        # Initialize pose sampler.
+        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
+                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+        focal_length = 4.2647
+        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+    elif args.cfg == 'edge2car':
+        neural_rendering_resolution = 64
+        pitch_range, yaw_range = np.pi / 2, np.pi
+        data_type = 'edge'
+
+        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
+                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+        focal_length = 1.7074 # shapenet has higher FOV
+        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+    else:
+        print('Invalid cfg')
+        return
+
+    save_dir = Path(args.outdir)
+
+    # Load the input label map
+    if args.input is not None:
+        input_label = PIL.Image.open(args.input)
+        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
+            input_label = np.array(input_label).astype(np.uint8)
+            input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
+
+            # Save the visualized input label map
+            PIL.Image.fromarray(color_mask(input_label[0,0].cpu().numpy()).astype(np.uint8)).save(save_dir / f'{args.cfg}_input.png')
+        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
+            input_label = np.array(input_label).astype(np.float32)
+            if input_label.ndim == 3:
+                input_label = input_label[:,:,0]
+            print(input_label.min(), input_label.max())
+            input_label = (torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
+            plt.imshow(input_label.cpu().numpy()[0,0], cmap='gray')
+            plt.savefig(save_dir / f'{args.cfg}_input.png')
+
+        input_pose = forward_pose.to(device)
+
+    elif args.input_id is not None:
+        if args.cfg == 'seg2cat':
+            data_path = Path(args.data_dir) / 'afhq_v2_train_cat_512.zip'
+            mask_data = Path(args.data_dir) / 'afhqcat_seg_6c.zip'
+        elif args.cfg == 'edge2car':
+            data_path = Path(args.data_dir) / 'cars_128.zip'
+            mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
+        elif args.cfg == 'seg2face':
+            # data_path = Path(args.data_dir) / 'celebamask_test.zip'
+            # mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'
+            data_path = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test.zip'
+            mask_data = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test_label.zip'
+        elif args.cfg == 'edge2cat':
+            data_path = '/data2/datasets/AFHQ_eg3d/afhq_v2_train_cat_512.zip'
+            mask_data = '/data2/datasets/AFHQ_eg3d/afhqcat_contour_pidinet.zip'
+
+        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
+        dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
+        batch = dataset[args.input_id]
+
+        save_dir = Path(args.outdir)
+
+        # Save the input label map
+        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
+            PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
+        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
+            PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
+
+        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
+        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
+            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
+        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
+            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)
+
+    # Generate videos
+    for seed in args.random_seed:
+        z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, G.z_dim).astype('float32')).to(device)
+
+        with torch.no_grad():
+            ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
+
+        # Generate the video
+        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
+            frames, frames_label = render_video(G, ws, intrinsics, num_frames = 120, pitch_range = pitch_range, yaw_range = yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)
+        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
+            frames, frames_label = render_video_edge2cat(G, ws, intrinsics, num_frames = 120, pitch_range = pitch_range, yaw_range = yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)
+
+        # Save the video
+        imageio.mimsave(save_dir / f'{args.cfg}_{seed}.gif', frames, fps=60)
+        imageio.mimsave(save_dir / f'{args.cfg}_{seed}_label.gif', frames_label, fps=60)
+
+
+
+if __name__ == '__main__':
+    main()
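A corresponding usage sketch for this script (again with a placeholder checkpoint path). --input_id picks a label map out of the dataset zips under --data_dir, while --input renders a user-supplied label map from the forward-facing pose instead:

    python applications/generate_video.py \
        --network checkpoints/pix2pix3d_seg2cat.pkl \
        --outdir out --random_seed 1 --cfg seg2cat \
        --input examples/example_input.png

Each seed produces a {cfg}_{seed}.gif and a {cfg}_{seed}_label.gif rendered from the 120-frame camera orbit defined in render_video.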
pix2pix3D-main/pix2pix3D-main/assets/demo.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:021005482c8f89c0d62b62709020c7e7e3de05152132d2c1dfa1b9256962fe59
+size 6362125
pix2pix3D-main/pix2pix3D-main/assets/rendered_mesh_colored.gif
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1.gif
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_color.png
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_1_label.png
ADDED (binary)
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1666_input.png
ADDED (binary)
pix2pix3D-main/pix2pix3D-main/assets/seg2cat_1_label.gif
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/teaser_gif.gif
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/teaser_jpg.jpg
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/assets/teaser_png.png
ADDED (binary; stored with Git LFS)
pix2pix3D-main/pix2pix3D-main/camera_utils.py
ADDED
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+
+from training.volumetric_rendering import math_utils
+
+class GaussianCameraPoseSampler:
+    """
+    Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+    Camera is specified as looking at the origin.
+    If horizontal and vertical stddev (specified in radians) are zero, gives a
+    deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+    The coordinate system is specified with y-up, z-forward, x-left.
+    Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+    vertical mean is the polar angle (angle from the y axis) in radians.
+    A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+    Example:
+        For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+        cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+    """
+    Same as GaussianCameraPoseSampler, except the
+    camera is specified as looking at 'lookat_position', a 3-vector.
+
+    Example:
+        For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+        cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+class UniformCameraPoseSampler:
+    """
+    Same as GaussianCameraPoseSampler, except the
+    pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
+
+    Example:
+        For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+        cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+    """
+
+    @staticmethod
+    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+        h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
+        v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
+        v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h
+        v = v / math.pi
+        phi = torch.arccos(1 - 2*v)
+
+        camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+        camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+        forward_vectors = math_utils.normalize_vecs(-camera_origins)
+        return create_cam2world_matrix(forward_vectors, camera_origins)
+
+def create_cam2world_matrix(forward_vector, origin):
+    """
+    Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+    Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
+    """
+
+    forward_vector = math_utils.normalize_vecs(forward_vector)
+    up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
+
+    right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
+    up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
+
+    rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+    rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
+
+    translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+    translation_matrix[:, :3, 3] = origin
+    cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+    assert(cam2world.shape[1:] == (4, 4))
+    return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees, device='cpu'):
+    """
+    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+    Assumes principal point is at image center.
+    """
+
+    focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+    return intrinsics
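The samplers and FOV_to_intrinsics compose into the 25-dimensional camera conditioning vector (16 flattened cam2world entries plus 9 flattened intrinsics entries) that the application scripts above pass to G.synthesis. A minimal sketch, with placeholder radius and FOV values:

    import math
    import torch
    from camera_utils import LookAtPoseSampler, FOV_to_intrinsics

    # Front-facing camera looking at a pivot at the origin.
    cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2,
                                         torch.tensor([0., 0., 0.]), radius=2.7)
    intrinsics = FOV_to_intrinsics(18.0)

    # Same flattened [cam2world | intrinsics] layout used in generate_video.py.
    pose = torch.cat([cam2world.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    print(pose.shape)  # torch.Size([1, 25])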
pix2pix3D-main/pix2pix3D-main/checkpoints/download_models.sh
ADDED
@@ -0,0 +1,5 @@
+cd checkpoints
+wget http://cs.cmu.edu/~pix2pix3D/release.tar
+tar -xvf release.tar
+rm release.tar
+cd ..
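Because the script cd's into checkpoints/ and back out, it is intended to be run from the repository root, e.g.:

    bash checkpoints/download_models.sh

release.tar is extracted in place and then deleted, leaving the extracted files under checkpoints/.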
pix2pix3D-main/pix2pix3D-main/dnnlib/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
pix2pix3D-main/pix2pix3D-main/dnnlib/util.py
ADDED
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: Union[str, bytes]) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if isinstance(text, bytes):
+            text = text.decode()
+        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+            self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+    global _dnnlib_cache_dir
+    _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+    if _dnnlib_cache_dir is not None:
+        return os.path.join(_dnnlib_cache_dir, *paths)
+    if 'DNNLIB_CACHE_DIR' in os.environ:
+        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+    if 'HOME' in os.environ:
+        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+    if 'USERPROFILE' in os.environ:
+        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+    else:
+        return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+        return True
+    except:
+        return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name) # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+    """Finds the python class with the given name and constructs it with the given arguments."""
+    return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    module = obj.__module__
+    if module == '__main__':
+        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+    return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    if allow_file_urls and obj.startswith('file://'):
+        return True
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert num_attempts >= 1
+    assert not (return_filename and (not cache))
+
+    # Doesn't look like an URL scheme so interpret it as a local filename.
+    if not re.match('^[a-z]+://', url):
+        return url if return_filename else open(url, "rb")
+
+    # Handle file URLs. This code handles unusual file:// patterns that
+    # arise on Windows:
+    #
+    # file:///c:/foo.txt
+    #
+    # which would translate to a local '/c:/foo.txt' filename that's
+    # invalid. Drop the forward slash for such pathnames.
+    #
+    # If you touch this code path, you should test it on both Linux and
+    # Windows.
+    #
+    # Some internet resources suggest using urllib.request.url2pathname() but
+    # but that converts forward slashes to backslashes and this causes
+    # its own set of problems.
+    if url.startswith('file://'):
+        filename = urllib.parse.urlparse(url).path
+        if re.match(r'^/[a-zA-Z]:', filename):
+            filename = filename[1:]
+        return filename if return_filename else open(filename, "rb")
+
+    assert is_url(url)
+
+    # Lookup from cache.
+    if cache_dir is None:
+        cache_dir = make_cache_dir_path('downloads')
+
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            filename = cache_files[0]
+            return filename if return_filename else open(filename, "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except KeyboardInterrupt:
+                raise
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file) # atomic
+        if return_filename:
+            return cache_file
+
+    # Return data as file object.
+    assert not return_filename
+    return io.BytesIO(url_data)
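EasyDict, construct_class_by_name, and open_url are the pieces the application scripts above lean on. A minimal sketch of the pattern (the class name and file path below are placeholders):

    import dnnlib
    import dnnlib.util

    # Attribute-style access on a plain dict.
    cfg = dnnlib.EasyDict(class_name='training.dataset.ImageSegFolderDataset', use_labels=True)
    cfg.max_size = None  # equivalent to cfg['max_size'] = None

    # Instantiate any class from its fully-qualified name with the remaining kwargs:
    # dataset = dnnlib.util.construct_class_by_name(**cfg)

    # open_url transparently handles local paths, file:// URLs, and cached HTTP downloads.
    with dnnlib.util.open_url('checkpoints/model.pkl') as f:
        raw = f.read()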
pix2pix3D-main/pix2pix3D-main/environment.yml
ADDED
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+name: pix2pix3d
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python >= 3.8
+  - pip
+  - numpy>=1.20
+  - click>=8.0
+  - pillow=8.3.1
+  - scipy=1.7.1
+  - pytorch=1.11.0
+  - cudatoolkit=11.1
+  - requests=2.26.0
+  - tqdm=4.62.2
+  - ninja=1.10.2
+  - matplotlib=3.4.2
+  - imageio=2.9.0
+  - pip:
+    - imgui==1.3.0
+    - glfw==2.2.0
+    - pyopengl==3.1.5
+    - imageio-ffmpeg==0.4.3
+    - pyspng
+    - psutil
+    - mrcfile
+    - tensorboard
+    - einops
+    - opencv-python
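The environment can be reproduced with the standard conda workflow (the env name comes from the file above):

    conda env create -f environment.yml
    conda activate pix2pix3d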
pix2pix3D-main/pix2pix3D-main/examples/example_input.png
ADDED (binary)
pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2car.png
ADDED (binary)
pix2pix3D-main/pix2pix3D-main/examples/example_input_edge2cat.png
ADDED (binary)
pix2pix3D-main/pix2pix3D-main/legacy.py
ADDED
@@ -0,0 +1,325 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Converting legacy network pickle into the new format."""

import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc

#----------------------------------------------------------------------------

def load_network_pkl(f, force_fp16=False):
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    pass

class _LegacyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    # pylint: disable=protected-access
    tf_params = dict()
    def recurse(prefix, tf_net):
        for name, value in tf_net.variables:
            tf_params[prefix + name] = value
        for name, comp in tf_net.components.items():
            recurse(prefix + name + '/', comp)
    recurse('', tf_net)
    return tf_params

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    from training import networks_stylegan2
    network_class = networks_stylegan2.Generator
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        architecture = kwarg('architecture', 'skip'),
        resample_filter = kwarg('resample_kernel', [1,3,3,1]),
        use_noise = kwarg('use_noise', True),
        activation = kwarg('nonlinearity', 'lrelu'),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('conditioning')
    kwarg('fused_modconv')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis.kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    G = network_class(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
        r'.*\.act_filter', None,
    )
    return G

#----------------------------------------------------------------------------

def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    kwarg('conditioning')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks_stylegan2
    D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
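For reference, a minimal usage sketch (not part of the diff) of how a converted pickle might be consumed downstream. The checkpoint path is a placeholder, and the call pattern assumes a StyleGAN2-style generator whose mapping network takes only z and c; pix2pix3D's own generators may expect additional conditioning inputs.

# Minimal usage sketch (assumes the repo root is on PYTHONPATH and that
# 'network-snapshot.pkl' is a placeholder path to a converted checkpoint).
import torch
import legacy

with open('network-snapshot.pkl', 'rb') as f:   # hypothetical checkpoint path
    data = legacy.load_network_pkl(f)

G = data['G_ema']                               # EMA generator, already in eval mode
z = torch.randn([1, G.z_dim])                   # random latent
c = torch.zeros([1, G.c_dim])                   # empty label for unconditional models
ws = G.mapping(z=z, c=c)                        # intermediate latents [1, num_ws, w_dim]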
pix2pix3D-main/pix2pix3D-main/metrics/__init__.py
ADDED
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# empty
pix2pix3D-main/pix2pix3D-main/metrics/equivariance.py
ADDED
@@ -0,0 +1,269 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
"Alias-Free Generative Adversarial Networks"."""

import copy
import numpy as np
import torch
import torch.fft
from torch_utils.ops import upfirdn2d
from . import metric_utils

#----------------------------------------------------------------------------
# Utilities.

def sinc(x):
    y = (x * np.pi).abs()
    z = torch.sin(y) / y.clamp(1e-30, float('inf'))
    return torch.where(y < 1e-30, torch.ones_like(x), z)

def lanczos_window(x, a):
    x = x.abs() / a
    return torch.where(x < 1, sinc(x), torch.zeros_like(x))

def rotation_matrix(angle):
    angle = torch.as_tensor(angle).to(torch.float32)
    mat = torch.eye(3, device=angle.device)
    mat[0, 0] = angle.cos()
    mat[0, 1] = angle.sin()
    mat[1, 0] = -angle.sin()
    mat[1, 1] = angle.cos()
    return mat

#----------------------------------------------------------------------------
# Apply integer translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.1.

def apply_integer_translation(x, tx, ty):
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = tx.round().to(torch.int64)
    iy = ty.round().to(torch.int64)

    z = torch.zeros_like(x)
    m = torch.zeros_like(x)
    if abs(ix) < W and abs(iy) < H:
        y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
        z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
        m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
    return z, m

#----------------------------------------------------------------------------
# Apply fractional translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.2.

def apply_fractional_translation(x, tx, ty, a=3):
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = tx.floor().to(torch.int64)
    iy = ty.floor().to(torch.int64)
    fx = tx - ix
    fy = ty - iy
    b = a - 1

    z = torch.zeros_like(x)
    zx0 = max(ix - b, 0)
    zy0 = max(iy - b, 0)
    zx1 = min(ix + a, 0) + W
    zy1 = min(iy + a, 0) + H
    if zx0 < zx1 and zy0 < zy1:
        taps = torch.arange(a * 2, device=x.device) - b
        filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
        filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
        y = x
        y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
        y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
        y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
        z[:, :, zy0:zy1, zx0:zx1] = y

    m = torch.zeros_like(x)
    mx0 = max(ix + a, 0)
    my0 = max(iy + a, 0)
    mx1 = min(ix - b, 0) + W
    my1 = min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        m[:, :, my0:my1, mx0:mx1] = 1
    return z, m

#----------------------------------------------------------------------------
# Construct an oriented low-pass filter that applies the appropriate
# bandlimit with respect to the input and output of the given affine 2D
# image transformation.

def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f

#----------------------------------------------------------------------------
# Apply the given affine transformation to a batch of 2D images.

def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2

    # Construct sampling grid.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m

#----------------------------------------------------------------------------
# Apply fractional rotation to a batch of 2D images. Corresponds to the
# operator R_\alpha in Appendix E.3.

def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    mat = rotation_matrix(angle)
    return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)

#----------------------------------------------------------------------------
# Modify the frequency content of a batch of 2D images as if they had undergone
# fractional rotation -- but without actually rotating them. Corresponds to
# the operator R^*_\alpha in Appendix E.3.

def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    mat = rotation_matrix(-angle)
    f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
    y = upfirdn2d.filter2d(x=x, f=f)
    m = torch.zeros_like(y)
    c = f.shape[0] // 2
    m[:, :, c:-c, c:-c] = 1
    return y, m

#----------------------------------------------------------------------------
# Compute the selected equivariance metrics for the given generator.

def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    assert compute_eqt_int or compute_eqt_frac or compute_eqr

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        s = []

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            t = (t * G.img_resolution).round() / G.img_resolution
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    mses = sums[0::2] / sums[1::2]
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs

#----------------------------------------------------------------------------
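As a small sanity check (not part of the diff), a sketch exercising apply_integer_translation on random data: translating by a whole number of pixels and masking should leave the overlapping region an exact copy. The import path assumes the repo root is the working directory, and the shift amounts are illustrative.

# Translate a random batch right by 0.25 * W = 16 pixels and verify that,
# inside the valid region, the output is an exact copy of the shifted input.
import torch
from metrics.equivariance import apply_integer_translation

x = torch.randn(2, 3, 64, 64)                           # batch of random "images"
z, m = apply_integer_translation(x, tx=0.25, ty=0.0)
assert z.shape == x.shape and m.shape == x.shape
print((z[:, :, :, 16:] == x[:, :, :, :-16]).all())      # tensor(True)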
pix2pix3D-main/pix2pix3D-main/metrics/frechet_inception_distance.py
ADDED
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""

import numpy as np
import scipy.linalg
from . import metric_utils

#----------------------------------------------------------------------------

def compute_fid(opts, max_real, num_gen):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()

    if opts.rank != 0:
        return float('nan')

    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)

#----------------------------------------------------------------------------
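For clarity (not part of the diff), the closed-form Frechet distance used in compute_fid, replayed on synthetic feature statistics; all shapes and values here are illustrative stand-ins for real Inception features.

# Frechet distance between two Gaussians fitted to synthetic features,
# mirroring the final computation of compute_fid() above.
import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
feat_a = rng.normal(size=(1000, 16))            # stand-in "real" features
feat_b = rng.normal(loc=0.1, size=(1000, 16))   # stand-in "generated" features

mu_a, sigma_a = feat_a.mean(0), np.cov(feat_a, rowvar=False)
mu_b, sigma_b = feat_b.mean(0), np.cov(feat_b, rowvar=False)

m = np.square(mu_b - mu_a).sum()
s, _ = scipy.linalg.sqrtm(sigma_b @ sigma_a, disp=False)
fid = np.real(m + np.trace(sigma_b + sigma_a - s * 2))
print(float(fid))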
pix2pix3D-main/pix2pix3D-main/metrics/inception_score.py
ADDED
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_is(opts, num_gen, num_splits):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for i in range(num_splits):
        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))
    return float(np.mean(scores)), float(np.std(scores))

#----------------------------------------------------------------------------
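For clarity (not part of the diff), the split-KL computation above replayed on random softmax outputs; the 1008-way output dimension mirrors the Inception detector's class count but is an assumption here, and the resulting score is meaningless beyond demonstrating the arithmetic.

# Inception-Score arithmetic on synthetic probabilities: KL between each
# sample's class distribution and the split's marginal, exponentiated.
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(5000, 1008))                  # stand-in detector outputs
gen_probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

num_splits = 10
scores = []
for i in range(num_splits):
    part = gen_probs[i * len(gen_probs) // num_splits : (i + 1) * len(gen_probs) // num_splits]
    kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
    scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
print(float(np.mean(scores)), float(np.std(scores)))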
pix2pix3D-main/pix2pix3D-main/metrics/kernel_inception_distance.py
ADDED
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    t = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    kid = t / num_subsets / m
    return float(kid)

#----------------------------------------------------------------------------
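For clarity (not part of the diff), the unbiased polynomial-kernel MMD estimator above, reduced to a single subset on random feature vectors; the feature dimension and sample counts are illustrative.

# Single-subset KID on synthetic features, matching the per-subset term in
# compute_kid(): cubic polynomial kernel (x.y/n + 1)^3 with diagonal removed.
import numpy as np

rng = np.random.default_rng(0)
n = 64                                   # feature dimension
x = rng.normal(size=(100, n))            # stand-in generated features
y = rng.normal(size=(100, n))            # stand-in real features
m = x.shape[0]

a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
b = (x @ y.T / n + 1) ** 3
kid = ((a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m) / m
print(float(kid))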
pix2pix3D-main/pix2pix3D-main/metrics/metric_main.py
ADDED
@@ -0,0 +1,155 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Main API for computing and reporting quality metrics."""

import os
import time
import json
import torch
import dnnlib

from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import perceptual_path_length
from . import inception_score
from . import equivariance

#----------------------------------------------------------------------------

_metric_dict = dict() # name => fn

def register_metric(fn):
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    return metric in _metric_dict

def list_valid_metrics():
    return list(_metric_dict.keys())

#----------------------------------------------------------------------------

def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast results.
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results = dnnlib.EasyDict(results),
        metric = metric,
        total_time = total_time,
        total_time_str = dnnlib.util.format_time(total_time),
        num_gpus = opts.num_gpus,
    )

#----------------------------------------------------------------------------

def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    if run_dir is not None and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
    print(jsonl_line)
    if run_dir is not None and os.path.isdir(run_dir):
        with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
            f.write(jsonl_line + '\n')

#----------------------------------------------------------------------------
# Recommended metrics.

@register_metric
def fid50k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return dict(fid50k_full=fid)

@register_metric
def kid50k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k_full=kid)

@register_metric
def pr50k3_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)

@register_metric
def ppl2_wend(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return dict(ppl2_wend=ppl)

@register_metric
def eqt50k_int(opts):
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
    return dict(eqt50k_int=psnr)

@register_metric
def eqt50k_frac(opts):
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
    return dict(eqt50k_frac=psnr)

@register_metric
def eqr50k(opts):
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
    return dict(eqr50k=psnr)

#----------------------------------------------------------------------------
# Legacy metrics.

@register_metric
def fid50k(opts):
    opts.dataset_kwargs.update(max_size=None)
    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return dict(fid50k=fid)

@register_metric
def kid50k(opts):
    opts.dataset_kwargs.update(max_size=None)
    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k=kid)

@register_metric
def pr50k3(opts):
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_precision=precision, pr50k3_recall=recall)

@register_metric
def is50k(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return dict(is50k_mean=mean, is50k_std=std)

#----------------------------------------------------------------------------
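For reference (not part of the diff), a sketch of how an additional metric could be registered through this API; the metric name fid10k, its sample count, and the trailing calc_metric call pattern are hypothetical illustrations, not part of the repo.

# Registering a hypothetical cheaper FID variant through register_metric.
from metrics import metric_main
from metrics import frechet_inception_distance

@metric_main.register_metric
def fid10k(opts):
    # Illustrative parameters: full real set, only 10k generated samples.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
    return dict(fid10k=fid)

# Once registered, it is computed like any built-in metric, e.g.:
# result = metric_main.calc_metric('fid10k', G=G, dataset_kwargs=..., device=...)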
pix2pix3D-main/pix2pix3D-main/metrics/metric_utils.py
ADDED
@@ -0,0 +1,281 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
+
#
|
4 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
# property and proprietary rights in and to this material, related
|
6 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
# disclosure or distribution of this material and related documentation
|
8 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
# its affiliates is strictly prohibited.
|
10 |
+
|
11 |
+
"""Miscellaneous utilities used internally by the quality metrics."""
|
12 |
+
|
13 |
+
import os
|
14 |
+
import time
|
15 |
+
import hashlib
|
16 |
+
import pickle
|
17 |
+
import copy
|
18 |
+
import uuid
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import dnnlib
|
22 |
+
|
23 |
+
#----------------------------------------------------------------------------
|
24 |
+
|
25 |
+
class MetricOptions:
|
26 |
+
def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
|
27 |
+
assert 0 <= rank < num_gpus
|
28 |
+
self.G = G
|
29 |
+
self.G_kwargs = dnnlib.EasyDict(G_kwargs)
|
30 |
+
self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
|
31 |
+
self.num_gpus = num_gpus
|
32 |
+
self.rank = rank
|
33 |
+
self.device = device if device is not None else torch.device('cuda', rank)
|
34 |
+
self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
|
35 |
+
self.cache = cache
|
36 |
+
|
37 |
+
#----------------------------------------------------------------------------
|
38 |
+
|
39 |
+
_feature_detector_cache = dict()
|
40 |
+
|
41 |
+
def get_feature_detector_name(url):
|
42 |
+
return os.path.splitext(url.split('/')[-1])[0]
|
43 |
+
|
44 |
+
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
|
45 |
+
assert 0 <= rank < num_gpus
|
46 |
+
key = (url, device)
|
47 |
+
if key not in _feature_detector_cache:
|
48 |
+
is_leader = (rank == 0)
|
49 |
+
if not is_leader and num_gpus > 1:
|
50 |
+
torch.distributed.barrier() # leader goes first
|
51 |
+
with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
|
52 |
+
_feature_detector_cache[key] = pickle.load(f).to(device)
|
53 |
+
if is_leader and num_gpus > 1:
|
54 |
+
torch.distributed.barrier() # others follow
|
55 |
+
return _feature_detector_cache[key]
|
56 |
+
|
57 |
+
#----------------------------------------------------------------------------
|
58 |
+
|
59 |
+
def iterate_random_labels(opts, batch_size):
|
60 |
+
if opts.G.c_dim == 0:
|
61 |
+
c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
|
62 |
+
while True:
|
63 |
+
yield c
|
64 |
+
else:
|
65 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
66 |
+
while True:
|
67 |
+
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
|
68 |
+
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
69 |
+
yield c
|
70 |
+
|
71 |
+
#----------------------------------------------------------------------------
|
72 |
+
|
73 |
+
class FeatureStats:
|
74 |
+
def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
|
75 |
+
self.capture_all = capture_all
|
76 |
+
self.capture_mean_cov = capture_mean_cov
|
77 |
+
self.max_items = max_items
|
78 |
+
self.num_items = 0
|
79 |
+
self.num_features = None
|
80 |
+
self.all_features = None
|
81 |
+
self.raw_mean = None
|
82 |
+
self.raw_cov = None
|
83 |
+
|
84 |
+
def set_num_features(self, num_features):
|
85 |
+
if self.num_features is not None:
|
86 |
+
assert num_features == self.num_features
|
87 |
+
else:
|
88 |
+
self.num_features = num_features
|
89 |
+
self.all_features = []
|
90 |
+
self.raw_mean = np.zeros([num_features], dtype=np.float64)
|
91 |
+
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
|
92 |
+
|
93 |
+
def is_full(self):
|
94 |
+
return (self.max_items is not None) and (self.num_items >= self.max_items)
|
95 |
+
|
96 |
+
def append(self, x):
|
97 |
+
x = np.asarray(x, dtype=np.float32)
|
98 |
+
assert x.ndim == 2
|
99 |
+
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
|
100 |
+
if self.num_items >= self.max_items:
|
101 |
+
return
|
102 |
+
x = x[:self.max_items - self.num_items]
|
103 |
+
|
104 |
+
self.set_num_features(x.shape[1])
|
105 |
+
self.num_items += x.shape[0]
|
106 |
+
if self.capture_all:
|
107 |
+
self.all_features.append(x)
|
108 |
+
if self.capture_mean_cov:
|
109 |
+
x64 = x.astype(np.float64)
|
110 |
+
self.raw_mean += x64.sum(axis=0)
|
111 |
+
self.raw_cov += x64.T @ x64
|
112 |
+
|
113 |
+
def append_torch(self, x, num_gpus=1, rank=0):
|
114 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 2
|
115 |
+
assert 0 <= rank < num_gpus
|
116 |
+
if num_gpus > 1:
|
117 |
+
ys = []
|
118 |
+
for src in range(num_gpus):
|
119 |
+
y = x.clone()
|
120 |
+
torch.distributed.broadcast(y, src=src)
|
121 |
+
ys.append(y)
|
122 |
+
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
|
123 |
+
self.append(x.cpu().numpy())
|
124 |
+
|
125 |
+
def get_all(self):
|
126 |
+
assert self.capture_all
|
127 |
+
return np.concatenate(self.all_features, axis=0)
|
128 |
+
|
129 |
+
def get_all_torch(self):
|
130 |
+
return torch.from_numpy(self.get_all())
|
131 |
+
|
132 |
+
def get_mean_cov(self):
|
133 |
+
assert self.capture_mean_cov
|
134 |
+
mean = self.raw_mean / self.num_items
|
135 |
+
cov = self.raw_cov / self.num_items
|
136 |
+
cov = cov - np.outer(mean, mean)
|
137 |
+
return mean, cov
|
138 |
+
|
139 |
+
def save(self, pkl_file):
|
140 |
+
with open(pkl_file, 'wb') as f:
|
141 |
+
pickle.dump(self.__dict__, f)
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def load(pkl_file):
|
145 |
+
with open(pkl_file, 'rb') as f:
|
146 |
+
s = dnnlib.EasyDict(pickle.load(f))
|
147 |
+
obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
|
148 |
+
obj.__dict__.update(s)
|
149 |
+
return obj
|
150 |
+
|
151 |
+
#----------------------------------------------------------------------------
|
152 |
+
|
153 |
+
class ProgressMonitor:
|
154 |
+
def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
|
155 |
+
self.tag = tag
|
156 |
+
self.num_items = num_items
|
157 |
+
self.verbose = verbose
|
158 |
+
self.flush_interval = flush_interval
|
159 |
+
self.progress_fn = progress_fn
|
160 |
+
self.pfn_lo = pfn_lo
|
161 |
+
self.pfn_hi = pfn_hi
|
162 |
+
self.pfn_total = pfn_total
|
163 |
+
self.start_time = time.time()
|
164 |
+
self.batch_time = self.start_time
|
165 |
+
self.batch_items = 0
|
166 |
+
if self.progress_fn is not None:
|
167 |
+
self.progress_fn(self.pfn_lo, self.pfn_total)
|
168 |
+
|
169 |
+
def update(self, cur_items):
|
170 |
+
assert (self.num_items is None) or (cur_items <= self.num_items)
|
171 |
+
if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
|
172 |
+
return
|
173 |
+
cur_time = time.time()
|
174 |
+
total_time = cur_time - self.start_time
|
175 |
+
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
|
176 |
+
if (self.verbose) and (self.tag is not None):
|
177 |
+
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        if (self.progress_fn is not None) and (self.num_items is not None):
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        return ProgressMonitor(
            tag             = tag,
            num_items       = num_items,
            flush_interval  = flush_interval,
            verbose         = self.verbose,
            progress_fn     = self.progress_fn,
            pfn_lo          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
            pfn_hi          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
            pfn_total       = self.pfn_total,
        )

#----------------------------------------------------------------------------

def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        # Choose cache file name.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats

#----------------------------------------------------------------------------
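The main loop above shards the dataset across processes by giving rank r every num_gpus-th index, wrapping with `% num_items` so every rank yields the same number of batches. A standalone sketch of that indexing (toy sizes, no GPUs required; the wrapped duplicates are later dropped by the `max_items` cap in `FeatureStats`):

```python
num_items, num_gpus = 10, 3
subsets = [[(i * num_gpus + rank) % num_items for i in range((num_items - 1) // num_gpus + 1)]
           for rank in range(num_gpus)]
# Each rank gets ceil(num_items / num_gpus) == 4 indices, e.g. rank 0 -> [0, 3, 6, 9].
assert all(len(s) == 4 for s in subsets)
# Together the ranks cover every item; the virtual indices 10 and 11 wrap to 0 and 1.
assert sorted(set(i for s in subsets for i in s)) == list(range(num_items))
```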

def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            images.append(img)
        images = torch.cat(images)
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats

#----------------------------------------------------------------------------
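One detail worth noting in the generator loop above: `(img * 127.5 + 128)` maps the generator's [-1, 1] output range onto [0, 255], and the extra half-step in the offset makes the truncating uint8 cast behave like round-to-nearest. A tiny endpoint check:

```python
import torch

img = torch.tensor([-1.0, 0.0, 1.0])
q = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
# -1 -> 0.5 -> 0, 0 -> 128.0 -> 128, 1 -> 255.5 -> clamped to 255
assert q.tolist() == [0, 128, 255]
```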
pix2pix3D-main/pix2pix3D-main/metrics/perceptual_path_length.py
ADDED
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
Architecture for Generative Adversarial Networks". Matches the original
implementation by Karras et al. at
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""

import copy
import numpy as np
import torch
from . import metric_utils

#----------------------------------------------------------------------------

# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    p = t * torch.acos(d)
    c = b - d * a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)
    d = d / d.norm(dim=-1, keepdim=True)
    return d

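As a quick sanity check on the interpolation above: slerp with t=0 should return the normalized `a`, t=1 should return the normalized `b`, and every output should stay on the unit sphere. A minimal standalone sketch (assumes only PyTorch; `slerp` is copied verbatim from above so the snippet runs on its own):

```python
import torch

def slerp(a, b, t):
    # Same spherical interpolation as the function above.
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    p = t * torch.acos(d)
    c = b - d * a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)
    return d / d.norm(dim=-1, keepdim=True)

a, b = torch.randn(8, 512), torch.randn(8, 512)
z0 = slerp(a, b, torch.zeros(8, 1))          # t=0 recovers the normalized first endpoint
z1 = slerp(a, b, torch.ones(8, 1))           # t=1 recovers the normalized second endpoint
zh = slerp(a, b, torch.full((8, 1), 0.5))    # midpoint stays on the unit sphere
assert torch.allclose(z0, a / a.norm(dim=-1, keepdim=True), atol=1e-5)
assert torch.allclose(z1, b / b.norm(dim=-1, keepdim=True), atol=1e-5)
assert torch.allclose(zh.norm(dim=-1), torch.ones(8), atol=1e-5)
```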
#----------------------------------------------------------------------------

class PPLSampler(torch.nn.Module):
    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)

    def forward(self, c):
        # Generate random latents and interpolation t-values.
        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)

        # Interpolate in W or Z.
        if self.space == 'w':
            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
        else: # space == 'z'
            zt0 = slerp(z0, z1, t.unsqueeze(1))
            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)

        # Randomize noise buffers.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Generate images.
        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        # Center crop.
        if self.crop:
            assert img.shape[2] == img.shape[3]
            c = img.shape[2] // 8
            img = img[:, :, c*3 : c*7, c*2 : c*6]

        # Downsample to 256x256.
        factor = self.G.img_resolution // 256
        if factor > 1:
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

        # Scale dynamic range from [-1,1] to [0,255].
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        # Evaluate differential LPIPS.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist

#----------------------------------------------------------------------------

def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler and labels.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        x = sampler(next(c_iter))
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    lo = np.percentile(dist, 1, interpolation='lower')
    hi = np.percentile(dist, 99, interpolation='higher')
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)

#----------------------------------------------------------------------------
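For reference, the quantity estimated by `compute_ppl` above is the expected scaled perceptual distance between images generated from slightly perturbed latent interpolations (Karras et al., StyleGAN):

```latex
\mathrm{PPL} = \mathbb{E}_{z_1, z_2 \sim \mathcal{N}(0, I),\; t}
\left[ \frac{1}{\epsilon^{2}}\,
d\!\left( G(\mathrm{slerp}(z_1, z_2;\, t)),\; G(\mathrm{slerp}(z_1, z_2;\, t + \epsilon)) \right) \right]
```

where d is the squared LPIPS distance, t ~ U(0, 1) for sampling='full' (t = 0 for 'end'), interpolation is a lerp in W when space='w', and the 1st/99th percentile outliers are trimmed before averaging, exactly as in the code.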
pix2pix3D-main/pix2pix3D-main/metrics/precision_recall.py
ADDED
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Precision/Recall (PR) from the paper "Improved Precision and Recall
Metric for Assessing Generative Models". Matches the original implementation
by Kynkaanniemi et al. at
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""

import torch
from . import metric_utils

#----------------------------------------------------------------------------

def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
    dist_batches = []
    for col_batch in col_batches[rank :: num_gpus]:
        dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            dist_broadcast = dist_batch.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(dist_broadcast, src=src)
            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None

#----------------------------------------------------------------------------

def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    detector_kwargs = dict(return_features=True)

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)

    results = dict()
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']

#----------------------------------------------------------------------------
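The decision rule in `compute_pr` above is the k-NN manifold test: a probe counts as covered if it lies within the distance of some manifold point to its own (nhood_size+1)-th nearest neighbour, where the +1 skips the zero self-distance. A single-GPU toy version of that rule (hypothetical random feature tensors standing in for VGG16 features; no sharding or fp16):

```python
import torch

def knn_coverage(manifold, probes, nhood_size=3):
    # Radius of each manifold point: distance to its (nhood_size+1)-th
    # nearest neighbour within the manifold itself.
    d_mm = torch.cdist(manifold, manifold)
    kth = d_mm.kthvalue(nhood_size + 1, dim=1).values           # [num_manifold]
    # A probe is covered if it falls inside at least one such ball.
    d_pm = torch.cdist(probes, manifold)                         # [num_probes, num_manifold]
    return (d_pm <= kth.unsqueeze(0)).any(dim=1).float().mean().item()

real = torch.randn(1000, 64)
fake = torch.randn(1000, 64)
precision = knn_coverage(manifold=real, probes=fake)  # fake samples inside the real manifold
recall    = knn_coverage(manifold=fake, probes=real)  # real samples inside the fake manifold
```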
pix2pix3D-main/pix2pix3D-main/torch_utils/__init__.py
ADDED
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# empty
pix2pix3D-main/pix2pix3D-main/torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid

import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------

def _get_mangled_gpu_name():
    name = torch.cuda.get_device_name().lower()
    out = []
    for c in name:
        if re.match('[a-z0-9_-]+', c):
            out.append(c)
        else:
            out.append('-')
    return ''.join(out)

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # Source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir): raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------
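The caching scheme in `get_plugin` keys the build directory on an md5 digest of all source bytes plus the mangled GPU name, so any source edit lands in a fresh directory while identical sources reuse a finished build. A minimal standalone sketch of just that naming scheme (hypothetical `build_root` path, no compilation involved):

```python
import hashlib
import os

def cached_build_dir_name(module_name, source_files, gpu_name, build_root='/tmp/torch_extensions'):
    # Digest the raw bytes of every input file, in sorted order, so the
    # directory name changes exactly when the sources change.
    digest = hashlib.md5()
    for path in sorted(source_files):
        with open(path, 'rb') as f:
            digest.update(f.read())
    # Same character mangling as _get_mangled_gpu_name() above.
    mangled_gpu = ''.join(c if c.isalnum() or c in '_-' else '-' for c in gpu_name.lower())
    return os.path.join(build_root, module_name, f'{digest.hexdigest()}-{mangled_gpu}')
```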
pix2pix3D-main/pix2pix3D-main/torch_utils/misc.py
ADDED
@@ -0,0 +1,280 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
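A quick illustration of the contract enforced by `assert_shape` above: integer entries must match exactly, while `None` wildcards a dimension. Minimal sketch (assumes the repo root is on PYTHONPATH so `torch_utils.misc` imports):

```python
import torch
from torch_utils.misc import assert_shape

x = torch.zeros(4, 3, 256, 256)
assert_shape(x, [None, 3, 256, 256])   # passes: batch size is a wildcard

try:
    assert_shape(x, [None, 1, 256, 256])
except AssertionError as e:
    print(e)                            # 'Wrong size for dimension 1: got 3, expected 1'
```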
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1

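A minimal sketch of wiring the sampler above into a DataLoader (hypothetical stand-in dataset; in real distributed training, rank and num_replicas would come from the launcher):

```python
import torch
from torch_utils.misc import InfiniteSampler  # assumes the repo root is on PYTHONPATH

dataset = torch.nn.Module  # placeholder line removed; see below
dataset = torch.utils.data.TensorDataset(torch.randn(100, 3, 64, 64))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)

# The sampler never raises StopIteration, so iterate explicitly and break.
it = iter(loader)
for _ in range(10):
    (batch,) = next(it)   # each replica sees a disjoint, endlessly reshuffled stream
```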
#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False, allow_mismatch=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name not in src_tensors and name.replace('_semantic', '') not in src_tensors:
            print(f'Warning: {name} not found in source module')
            continue
        if name not in src_tensors:
            # print(f'Warning: {name} not found in source module, using {name.replace("_semantic", "")}')
            name_src = name.replace('_semantic', '')
        else:
            name_src = name
        if src_tensors[name_src].shape != tensor.shape and allow_mismatch:
            print(f'Warning: {name_src} shape mismatch: {src_tensors[name_src].shape} vs {tensor.shape}')
            continue
        if name_src in src_tensors:
            tensor.copy_(src_tensors[name_src].detach()).requires_grad_(tensor.requires_grad)
            # print(f'Copied {name}')

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------
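A small sketch of using the summary printer above on a throwaway module (any nn.Module works; the table lists unique parameters, buffers, and output shapes per submodule, and the call returns the forward pass result):

```python
import torch
from torch_utils import misc  # assumes the repo root is on PYTHONPATH

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 3, 3, padding=1),
)
x = torch.zeros(1, 3, 32, 32)
out = misc.print_module_summary(net, [x])  # prints the table, returns net(x)
```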
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/__init__.py
ADDED
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# empty
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,103 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "bias_act.h"

//------------------------------------------------------------------------

static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    for (int64_t i = 0; i < x.dim(); i++)
    {
        if (x.size(i) != y.size(i))
            return false;
        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
            return false;
    }
    return true;
}

//------------------------------------------------------------------------

static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters.
    bias_act_kernel_params p;
    p.x     = x.data_ptr();
    p.b     = (b.numel()) ? b.data_ptr() : NULL;
    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y     = y.data_ptr();
    p.grad  = grad;
    p.act   = act;
    p.alpha = alpha;
    p.gain  = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;

    // Choose CUDA kernel.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel.
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}

//------------------------------------------------------------------------
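The launch configuration above covers every element: each thread handles up to loopX elements, so gridSize is the ceiling of sizeX / (loopX * blockSize). A quick Python check of that arithmetic:

```python
def grid_size(size_x, loop_x=4, block_size=4 * 32):
    # Same ceiling division as the C++ launch code above.
    return (size_x - 1) // (loop_x * block_size) + 1

for size_x in (1, 511, 512, 513, 10_000_000):
    g = grid_size(size_x)
    assert g * 4 * (4 * 32) >= size_x          # enough thread-iterations for every element
    assert (g - 1) * 4 * (4 * 32) < size_x     # and no entirely idle block
```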
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,177 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);

//------------------------------------------------------------------------
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,42 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,211 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
+
#
|
4 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
# property and proprietary rights in and to this material, related
|
6 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
# disclosure or distribution of this material and related documentation
|
8 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
# its affiliates is strictly prohibited.
|
10 |
+
|
11 |
+
"""Custom PyTorch ops for efficient bias and activation."""
|
12 |
+
|
13 |
+
import os
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import dnnlib
|
17 |
+
|
18 |
+
from .. import custom_ops
|
19 |
+
from .. import misc
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
activation_funcs = {
|
24 |
+
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
25 |
+
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
26 |
+
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
27 |
+
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
28 |
+
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
29 |
+
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
30 |
+
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
31 |
+
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
32 |
+
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
33 |
+
}
|
34 |
+
|
35 |
+
#----------------------------------------------------------------------------
|
36 |
+
|
37 |
+
_plugin = None
|
38 |
+
_null_tensor = torch.empty([0])
|
39 |
+
|
40 |
+
def _init():
|
41 |
+
global _plugin
|
42 |
+
if _plugin is None:
|
43 |
+
_plugin = custom_ops.get_plugin(
|
44 |
+
module_name='bias_act_plugin',
|
45 |
+
sources=['bias_act.cpp', 'bias_act.cu'],
|
46 |
+
headers=['bias_act.h'],
|
47 |
+
source_dir=os.path.dirname(__file__),
|
48 |
+
extra_cuda_cflags=['--use_fast_math'],
|
49 |
+
)
|
50 |
+
return True
|
51 |
+
|
52 |
+
#----------------------------------------------------------------------------
|
53 |
+
|
54 |
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use the default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------
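
Usage note (editorial sketch, not part of the uploaded file): how a layer would typically call the fused op. The shapes and the choice of `lrelu` are illustrative; `impl='ref'` falls back to plain PyTorch ops on any device.

import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 512)                               # activations: [batch, channels]
b = torch.zeros(512)                                  # one bias element per channel (dim=1)
y = bias_act.bias_act(x, b, act='lrelu', impl='ref')  # fused bias + leaky ReLU, default gain sqrt(2)
y_cuda = bias_act.bias_act(x.cuda(), b.cuda(), act='lrelu')  # uses the compiled CUDA plugin when a GPU is available
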
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------
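
For intuition, the reference path above reduces to the following unfused ops. A sketch assuming `act='lrelu'` with its defaults (`alpha=0.2`, `gain=sqrt(2)`); the tensors are illustrative.

import numpy as np
import torch

x = torch.randn(4, 512)
b = torch.randn(512)
manual = torch.nn.functional.leaky_relu(x + b.reshape(1, -1), 0.2) * np.sqrt(2)
# `manual` matches bias_act(x, b, act='lrelu', impl='ref') up to floating-point rounding.
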
_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------
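
The docstring above promises first- and second-order gradients. A quick numeric sanity check (editorial sketch; it uses the `'ref'` path in double precision so `torch.autograd.gradcheck` is meaningful, and `tanh`, which is flagged `has_2nd_grad=True`):

import torch
from torch_utils.ops import bias_act

x = torch.randn(2, 8, dtype=torch.float64, requires_grad=True)
b = torch.randn(8, dtype=torch.float64, requires_grad=True)
fn = lambda x, b: bias_act.bias_act(x, b, act='tanh', impl='ref')
torch.autograd.gradcheck(fn, (x, b))      # first-order gradients
torch.autograd.gradgradcheck(fn, (x, b))  # second-order gradients
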
pix2pix3D-main/pix2pix3D-main/torch_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,199 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import contextlib
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients(disable=True):
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

#----------------------------------------------------------------------------
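
A sketch of what the `no_weight_gradients()` context is for: in StyleGAN-family training loops it wraps the R1 gradient-penalty backward pass so that no weight gradients are produced along the way. The weight tensor `w` below is a hypothetical stand-in for a discriminator's convolution weights, not something defined in this file.

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True
real_img = torch.randn(4, 3, 64, 64, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)   # stand-in for discriminator weights
score = conv2d_gradfix.conv2d(real_img, w, padding=1).sum()
with conv2d_gradfix.no_weight_gradients():
    (r1_grads,) = torch.autograd.grad(outputs=score, inputs=real_img, create_graph=True)
r1_penalty = r1_grads.square().sum(dim=[1, 2, 3]).mean()
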
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------
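
A minimal sketch of calling the wrappers directly. `enabled` is `False` by default, so the custom autograd path must be switched on first; otherwise both calls fall through to `torch.nn.functional`. Shapes are illustrative.

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                        # opt in to the custom op
x = torch.randn(1, 16, 32, 32, device='cuda')
w = torch.randn(32, 16, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)           # drop-in for F.conv2d -> [1, 32, 32, 32]
z = conv2d_gradfix.conv_transpose2d(y, w, stride=2)  # drop-in for F.conv_transpose2d
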
def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    return True

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()
_null_tensor = torch.empty([0])

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            ctx.save_for_backward(
                input if weight.requires_grad else _null_tensor,
                weight if input.requires_grad else _null_tensor,
            )
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
                a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
                c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
                c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
                c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            if transpose:
                return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            input_shape = ctx.input_shape
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input_shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input, weight):
            ctx.save_for_backward(
                grad_output if input.requires_grad else _null_tensor,
                input if grad_output.requires_grad else _null_tensor,
            )
            ctx.grad_output_shape = grad_output.shape
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
                a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_output_shape = ctx.grad_output_shape
            input_shape = ctx.input_shape
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output_shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input_shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------
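
Why the two nested `autograd.Function`s above matter, as an editorial sketch: each backward pass is itself expressed through `Conv2d` / `Conv2dGradWeight` applies, so gradients compose to arbitrary order, e.g. differentiating a gradient with respect to the weights:

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True
x = torch.randn(2, 4, 8, 8, device='cuda', requires_grad=True)
w = torch.randn(4, 4, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)   # first-order gradient w.r.t. the input
(ggw,) = torch.autograd.grad(gx.square().sum(), w)           # second-order: gradient of a gradient
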