yansong1616 committed on
Commit
b177539
1 Parent(s): 5ad334d

Upload 384 files

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. .gitattributes +38 -35
  2. .gitignore +132 -0
  3. .gitmodules +3 -0
  4. .idea/.gitignore +8 -0
  5. .idea/dust3r.iml +12 -0
  6. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  7. .idea/misc.xml +7 -0
  8. .idea/modules.xml +8 -0
  9. .idea/other.xml +6 -0
  10. .idea/workspace.xml +653 -0
  11. LICENSE +7 -0
  12. NOTICE +13 -0
  13. README.md +12 -12
  14. SAM/__init__.py +15 -0
  15. SAM/__pycache__/__init__.cpython-310.pyc +0 -0
  16. SAM/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
  17. SAM/__pycache__/build_sam.cpython-310.pyc +0 -0
  18. SAM/__pycache__/predictor.cpython-310.pyc +0 -0
  19. SAM/automatic_mask_generator.py +372 -0
  20. SAM/build_sam.py +107 -0
  21. SAM/modeling/__init__.py +11 -0
  22. SAM/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  23. SAM/modeling/__pycache__/common.cpython-310.pyc +0 -0
  24. SAM/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
  25. SAM/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
  26. SAM/modeling/__pycache__/prompt_encoder.cpython-310.pyc +0 -0
  27. SAM/modeling/__pycache__/sam.cpython-310.pyc +0 -0
  28. SAM/modeling/__pycache__/transformer.cpython-310.pyc +0 -0
  29. SAM/modeling/common.py +43 -0
  30. SAM/modeling/image_encoder.py +395 -0
  31. SAM/modeling/mask_decoder.py +192 -0
  32. SAM/modeling/prompt_encoder.py +214 -0
  33. SAM/modeling/sam.py +187 -0
  34. SAM/modeling/transformer.py +240 -0
  35. SAM/predictor.py +269 -0
  36. SAM/utils/__pycache__/amg.cpython-310.pyc +0 -0
  37. SAM/utils/__pycache__/transforms.cpython-310.pyc +0 -0
  38. SAM/utils/amg.py +346 -0
  39. SAM/utils/transforms.py +102 -0
  40. __pycache__/evaluate.cpython-310.pyc +0 -0
  41. __pycache__/load_nvos.cpython-310.pyc +0 -0
  42. app.py +353 -0
  43. checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth +3 -0
  44. checkpoints/sam_vit_b_01ec64.pth +3 -0
  45. configs/default.py +119 -0
  46. configs/lerf/book_store.py +16 -0
  47. configs/lerf/bouquet.py +16 -0
  48. configs/lerf/donuts.py +16 -0
  49. configs/lerf/dozer_nerfgun_waldo.py +16 -0
  50. configs/lerf/espresso.py +16 -0
.gitattributes CHANGED
@@ -1,35 +1,38 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output/llff(sanerf-hq)/fenceflower/point_cloud_projection.png filter=lfs diff=lfs merge=lfs -text
37
+ output/llff(sanerf-hq)/mattcecsit/point_cloud_projection.png filter=lfs diff=lfs merge=lfs -text
38
+ output/llff(sanerf-hq)/mattwrite/point_cloud_projection.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "croco"]
2
+ path = croco
3
+ url = https://github.com/naver/croco
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/dust3r.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="dust3r" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="edgesam" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="dust3r" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/dust3r.iml" filepath="$PROJECT_DIR$/.idea/dust3r.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/other.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="PySciProjectComponent">
4
+ <option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
5
+ </component>
6
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="de0dddb6-4a99-4847-9050-a2cb006d71c9" name="Changes" comment="" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="FileTemplateManagerImpl">
14
+ <option name="RECENT_TEMPLATES">
15
+ <list>
16
+ <option value="Python Script" />
17
+ </list>
18
+ </option>
19
+ </component>
20
+ <component name="MarkdownSettingsMigration">
21
+ <option name="stateVersion" value="1" />
22
+ </component>
23
+ <component name="ProjectColorInfo">{
24
+ &quot;associatedIndex&quot;: 6
25
+ }</component>
26
+ <component name="ProjectId" id="2fAGUbZMWEGJzrHYLuaOE0replo" />
27
+ <component name="ProjectViewState">
28
+ <option name="hideEmptyMiddlePackages" value="true" />
29
+ <option name="showLibraryContents" value="true" />
30
+ </component>
31
+ <component name="PropertiesComponent">{
32
+ &quot;keyToString&quot;: {
33
+ &quot;Python.base_opt.executor&quot;: &quot;Debug&quot;,
34
+ &quot;Python.demo.executor&quot;: &quot;Debug&quot;,
35
+ &quot;Python.evaluate.executor&quot;: &quot;Run&quot;,
36
+ &quot;Python.gys_util.executor&quot;: &quot;Run&quot;,
37
+ &quot;Python.load_nvos.executor&quot;: &quot;Debug&quot;,
38
+ &quot;Python.prepare_prompts.executor&quot;: &quot;Debug&quot;,
39
+ &quot;Python.segment_eval_mask.executor&quot;: &quot;Run&quot;,
40
+ &quot;Python.test_vis.executor&quot;: &quot;Run&quot;,
41
+ &quot;RunOnceActivity.OpenProjectViewOnStart&quot;: &quot;true&quot;,
42
+ &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
43
+ &quot;last_opened_file_path&quot;: &quot;D:/XMU/mac/hujie/3D/DUSt3R/dust3r/data/nerf_llff_data(NVOS-all)/orchids&quot;,
44
+ &quot;node.js.detected.package.eslint&quot;: &quot;true&quot;,
45
+ &quot;node.js.detected.package.tslint&quot;: &quot;true&quot;,
46
+ &quot;node.js.selected.package.eslint&quot;: &quot;(autodetect)&quot;,
47
+ &quot;node.js.selected.package.tslint&quot;: &quot;(autodetect)&quot;,
48
+ &quot;nodejs_package_manager_path&quot;: &quot;npm&quot;,
49
+ &quot;settings.editor.selected.configurable&quot;: &quot;editor.preferences.fonts.default&quot;,
50
+ &quot;vue.rearranger.settings.migration&quot;: &quot;true&quot;
51
+ }
52
+ }</component>
53
+ <component name="RecentsManager">
54
+ <key name="CopyFile.RECENT_KEYS">
55
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\orchids" />
56
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\leaves" />
57
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\fortress" />
58
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\flower" />
59
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\flower\images_8" />
60
+ </key>
61
+ <key name="MoveFile.RECENT_KEYS">
62
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS-all)\masks\horns_center" />
63
+ <recent name="D:\XMU\mac\hujie\3D\DUSt3R\dust3r\data\nerf_llff_data(NVOS)\masks\horns_center" />
64
+ </key>
65
+ </component>
66
+ <component name="RunManager" selected="Python.segment_eval_mask">
67
+ <configuration name="base_opt" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
68
+ <module name="dust3r" />
69
+ <option name="ENV_FILES" value="" />
70
+ <option name="INTERPRETER_OPTIONS" value="" />
71
+ <option name="PARENT_ENVS" value="true" />
72
+ <envs>
73
+ <env name="PYTHONUNBUFFERED" value="1" />
74
+ </envs>
75
+ <option name="SDK_HOME" value="" />
76
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/dust3r/cloud_opt" />
77
+ <option name="IS_MODULE_SDK" value="true" />
78
+ <option name="ADD_CONTENT_ROOTS" value="true" />
79
+ <option name="ADD_SOURCE_ROOTS" value="true" />
80
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
81
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py" />
82
+ <option name="PARAMETERS" value="" />
83
+ <option name="SHOW_COMMAND_LINE" value="false" />
84
+ <option name="EMULATE_TERMINAL" value="false" />
85
+ <option name="MODULE_MODE" value="false" />
86
+ <option name="REDIRECT_INPUT" value="false" />
87
+ <option name="INPUT_FILE" value="" />
88
+ <method v="2" />
89
+ </configuration>
90
+ <configuration name="evaluate" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
91
+ <module name="dust3r" />
92
+ <option name="ENV_FILES" value="" />
93
+ <option name="INTERPRETER_OPTIONS" value="" />
94
+ <option name="PARENT_ENVS" value="true" />
95
+ <envs>
96
+ <env name="PYTHONUNBUFFERED" value="1" />
97
+ </envs>
98
+ <option name="SDK_HOME" value="" />
99
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
100
+ <option name="IS_MODULE_SDK" value="true" />
101
+ <option name="ADD_CONTENT_ROOTS" value="true" />
102
+ <option name="ADD_SOURCE_ROOTS" value="true" />
103
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
104
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/evaluate.py" />
105
+ <option name="PARAMETERS" value="" />
106
+ <option name="SHOW_COMMAND_LINE" value="false" />
107
+ <option name="EMULATE_TERMINAL" value="false" />
108
+ <option name="MODULE_MODE" value="false" />
109
+ <option name="REDIRECT_INPUT" value="false" />
110
+ <option name="INPUT_FILE" value="" />
111
+ <method v="2" />
112
+ </configuration>
113
+ <configuration name="gys_util" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
114
+ <module name="dust3r" />
115
+ <option name="ENV_FILES" value="" />
116
+ <option name="INTERPRETER_OPTIONS" value="" />
117
+ <option name="PARENT_ENVS" value="true" />
118
+ <envs>
119
+ <env name="PYTHONUNBUFFERED" value="1" />
120
+ </envs>
121
+ <option name="SDK_HOME" value="" />
122
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
123
+ <option name="IS_MODULE_SDK" value="true" />
124
+ <option name="ADD_CONTENT_ROOTS" value="true" />
125
+ <option name="ADD_SOURCE_ROOTS" value="true" />
126
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
127
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/gys_util.py" />
128
+ <option name="PARAMETERS" value="" />
129
+ <option name="SHOW_COMMAND_LINE" value="false" />
130
+ <option name="EMULATE_TERMINAL" value="false" />
131
+ <option name="MODULE_MODE" value="false" />
132
+ <option name="REDIRECT_INPUT" value="false" />
133
+ <option name="INPUT_FILE" value="" />
134
+ <method v="2" />
135
+ </configuration>
136
+ <configuration name="load_nvos" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
137
+ <module name="dust3r" />
138
+ <option name="ENV_FILES" value="" />
139
+ <option name="INTERPRETER_OPTIONS" value="" />
140
+ <option name="PARENT_ENVS" value="true" />
141
+ <envs>
142
+ <env name="PYTHONUNBUFFERED" value="1" />
143
+ </envs>
144
+ <option name="SDK_HOME" value="" />
145
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
146
+ <option name="IS_MODULE_SDK" value="true" />
147
+ <option name="ADD_CONTENT_ROOTS" value="true" />
148
+ <option name="ADD_SOURCE_ROOTS" value="true" />
149
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
150
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/load_nvos.py" />
151
+ <option name="PARAMETERS" value="" />
152
+ <option name="SHOW_COMMAND_LINE" value="false" />
153
+ <option name="EMULATE_TERMINAL" value="false" />
154
+ <option name="MODULE_MODE" value="false" />
155
+ <option name="REDIRECT_INPUT" value="false" />
156
+ <option name="INPUT_FILE" value="" />
157
+ <method v="2" />
158
+ </configuration>
159
+ <configuration name="segment_eval_mask" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
160
+ <module name="dust3r" />
161
+ <option name="ENV_FILES" value="" />
162
+ <option name="INTERPRETER_OPTIONS" value="" />
163
+ <option name="PARENT_ENVS" value="true" />
164
+ <envs>
165
+ <env name="PYTHONUNBUFFERED" value="1" />
166
+ </envs>
167
+ <option name="SDK_HOME" value="" />
168
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
169
+ <option name="IS_MODULE_SDK" value="true" />
170
+ <option name="ADD_CONTENT_ROOTS" value="true" />
171
+ <option name="ADD_SOURCE_ROOTS" value="true" />
172
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
173
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/segment_eval_mask.py" />
174
+ <option name="PARAMETERS" value="" />
175
+ <option name="SHOW_COMMAND_LINE" value="false" />
176
+ <option name="EMULATE_TERMINAL" value="false" />
177
+ <option name="MODULE_MODE" value="false" />
178
+ <option name="REDIRECT_INPUT" value="false" />
179
+ <option name="INPUT_FILE" value="" />
180
+ <method v="2" />
181
+ </configuration>
182
+ <list>
183
+ <item itemvalue="Python.base_opt" />
184
+ <item itemvalue="Python.evaluate" />
185
+ <item itemvalue="Python.gys_util" />
186
+ <item itemvalue="Python.load_nvos" />
187
+ <item itemvalue="Python.segment_eval_mask" />
188
+ </list>
189
+ <recent_temporary>
190
+ <list>
191
+ <item itemvalue="Python.segment_eval_mask" />
192
+ <item itemvalue="Python.gys_util" />
193
+ <item itemvalue="Python.base_opt" />
194
+ <item itemvalue="Python.load_nvos" />
195
+ <item itemvalue="Python.evaluate" />
196
+ </list>
197
+ </recent_temporary>
198
+ </component>
199
+ <component name="SharedIndexes">
200
+ <attachedChunks>
201
+ <set>
202
+ <option value="bundled-python-sdk-d68999036c7f-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.14475.56" />
203
+ </set>
204
+ </attachedChunks>
205
+ </component>
206
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
207
+ <component name="TaskManager">
208
+ <task active="true" id="Default" summary="Default task">
209
+ <changelist id="de0dddb6-4a99-4847-9050-a2cb006d71c9" name="Changes" comment="" />
210
+ <created>1713236486096</created>
211
+ <option name="number" value="Default" />
212
+ <option name="presentableId" value="Default" />
213
+ <updated>1713236486096</updated>
214
+ <workItem from="1713236487157" duration="10677000" />
215
+ <workItem from="1713322792937" duration="2998000" />
216
+ <workItem from="1713327850893" duration="7179000" />
217
+ <workItem from="1713345425711" duration="3092000" />
218
+ <workItem from="1713365007655" duration="3000" />
219
+ <workItem from="1713367330362" duration="10000" />
220
+ <workItem from="1713424119618" duration="7824000" />
221
+ <workItem from="1713490936554" duration="2000" />
222
+ <workItem from="1713505652781" duration="4699000" />
223
+ <workItem from="1713663410571" duration="2000" />
224
+ <workItem from="1713692836014" duration="2000" />
225
+ <workItem from="1713706421213" duration="483000" />
226
+ <workItem from="1713760410928" duration="1449000" />
227
+ <workItem from="1713859656888" duration="121000" />
228
+ <workItem from="1713868202296" duration="766000" />
229
+ <workItem from="1713871959150" duration="9000" />
230
+ <workItem from="1714029659735" duration="12549000" />
231
+ <workItem from="1714107476497" duration="2510000" />
232
+ <workItem from="1714111340455" duration="1140000" />
233
+ <workItem from="1714112654480" duration="4607000" />
234
+ <workItem from="1714306019181" duration="6300000" />
235
+ <workItem from="1714374776625" duration="9143000" />
236
+ <workItem from="1714477007344" duration="1203000" />
237
+ <workItem from="1714977472055" duration="12684000" />
238
+ <workItem from="1715235695003" duration="6444000" />
239
+ <workItem from="1715266491201" duration="657000" />
240
+ <workItem from="1715322636502" duration="14461000" />
241
+ <workItem from="1715407622615" duration="5364000" />
242
+ <workItem from="1715496542428" duration="15485000" />
243
+ <workItem from="1715578333845" duration="3525000" />
244
+ <workItem from="1715654635430" duration="19165000" />
245
+ <workItem from="1715737090799" duration="8221000" />
246
+ <workItem from="1715825644950" duration="25707000" />
247
+ <workItem from="1715912343385" duration="771000" />
248
+ <workItem from="1715913133797" duration="6272000" />
249
+ <workItem from="1715959257185" duration="1615000" />
250
+ <workItem from="1716202913497" duration="1828000" />
251
+ <workItem from="1716518387871" duration="58000" />
252
+ <workItem from="1716625270304" duration="311000" />
253
+ <workItem from="1717227431906" duration="27000" />
254
+ <workItem from="1717554542796" duration="3744000" />
255
+ <workItem from="1717639168925" duration="599000" />
256
+ <workItem from="1717723619398" duration="15918000" />
257
+ <workItem from="1717815585723" duration="5160000" />
258
+ <workItem from="1717901397527" duration="3653000" />
259
+ <workItem from="1718069302578" duration="604000" />
260
+ <workItem from="1719749124178" duration="2000" />
261
+ <workItem from="1721024783860" duration="3000" />
262
+ <workItem from="1721484421877" duration="939000" />
263
+ <workItem from="1721528085975" duration="6199000" />
264
+ <workItem from="1721613881635" duration="3849000" />
265
+ <workItem from="1722040494765" duration="13950000" />
266
+ <workItem from="1722062512246" duration="18586000" />
267
+ <workItem from="1722129426856" duration="17038000" />
268
+ <workItem from="1722215818784" duration="16344000" />
269
+ <workItem from="1722304780775" duration="18623000" />
270
+ <workItem from="1722407231490" duration="6609000" />
271
+ <workItem from="1722472667237" duration="5193000" />
272
+ <workItem from="1722657508626" duration="597000" />
273
+ <workItem from="1723690716890" duration="9000" />
274
+ <workItem from="1723793077905" duration="5340000" />
275
+ <workItem from="1723806823176" duration="6000" />
276
+ </task>
277
+ <servers />
278
+ </component>
279
+ <component name="TypeScriptGeneratedFilesManager">
280
+ <option name="version" value="3" />
281
+ </component>
282
+ <component name="XDebuggerManager">
283
+ <breakpoint-manager>
284
+ <breakpoints>
285
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
286
+ <url>file://$PROJECT_DIR$/demo.py</url>
287
+ <line>352</line>
288
+ <option name="timeStamp" value="21" />
289
+ </line-breakpoint>
290
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
291
+ <url>file://$PROJECT_DIR$/demo.py</url>
292
+ <line>350</line>
293
+ <option name="timeStamp" value="22" />
294
+ </line-breakpoint>
295
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
296
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
297
+ <line>49</line>
298
+ <option name="timeStamp" value="62" />
299
+ </line-breakpoint>
300
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
301
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
302
+ <line>347</line>
303
+ <option name="timeStamp" value="64" />
304
+ </line-breakpoint>
305
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
306
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
307
+ <line>39</line>
308
+ <option name="timeStamp" value="65" />
309
+ </line-breakpoint>
310
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
311
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
312
+ <line>307</line>
313
+ <option name="timeStamp" value="66" />
314
+ </line-breakpoint>
315
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
316
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
317
+ <line>301</line>
318
+ <option name="timeStamp" value="67" />
319
+ </line-breakpoint>
320
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
321
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
322
+ <line>64</line>
323
+ <option name="timeStamp" value="114" />
324
+ </line-breakpoint>
325
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
326
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
327
+ <line>107</line>
328
+ <option name="timeStamp" value="115" />
329
+ </line-breakpoint>
330
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
331
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
332
+ <line>106</line>
333
+ <option name="timeStamp" value="116" />
334
+ </line-breakpoint>
335
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
336
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
337
+ <line>204</line>
338
+ <option name="timeStamp" value="149" />
339
+ </line-breakpoint>
340
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
341
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
342
+ <line>75</line>
343
+ <option name="timeStamp" value="164" />
344
+ </line-breakpoint>
345
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
346
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
347
+ <line>187</line>
348
+ <option name="timeStamp" value="378" />
349
+ </line-breakpoint>
350
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
351
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
352
+ <line>179</line>
353
+ <option name="timeStamp" value="382" />
354
+ </line-breakpoint>
355
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
356
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
357
+ <line>166</line>
358
+ <option name="timeStamp" value="383" />
359
+ </line-breakpoint>
360
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
361
+ <url>file://$PROJECT_DIR$/croco/models/dpt_block.py</url>
362
+ <line>444</line>
363
+ <option name="timeStamp" value="386" />
364
+ </line-breakpoint>
365
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
366
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
367
+ <line>71</line>
368
+ <option name="timeStamp" value="388" />
369
+ </line-breakpoint>
370
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
371
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
372
+ <line>184</line>
373
+ <option name="timeStamp" value="391" />
374
+ </line-breakpoint>
375
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
376
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
377
+ <line>58</line>
378
+ <option name="timeStamp" value="421" />
379
+ </line-breakpoint>
380
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
381
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
382
+ <line>163</line>
383
+ <option name="timeStamp" value="431" />
384
+ </line-breakpoint>
385
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
386
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
387
+ <line>185</line>
388
+ <option name="timeStamp" value="432" />
389
+ </line-breakpoint>
390
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
391
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
392
+ <line>187</line>
393
+ <option name="timeStamp" value="433" />
394
+ </line-breakpoint>
395
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
396
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
397
+ <line>184</line>
398
+ <option name="timeStamp" value="434" />
399
+ </line-breakpoint>
400
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
401
+ <url>file://$PROJECT_DIR$/evaluate.py</url>
402
+ <line>90</line>
403
+ <option name="timeStamp" value="435" />
404
+ </line-breakpoint>
405
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
406
+ <url>file://$PROJECT_DIR$/evaluate.py</url>
407
+ <line>94</line>
408
+ <option name="timeStamp" value="437" />
409
+ </line-breakpoint>
410
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
411
+ <url>file://$PROJECT_DIR$/evaluate.py</url>
412
+ <line>56</line>
413
+ <option name="timeStamp" value="438" />
414
+ </line-breakpoint>
415
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
416
+ <url>file://$PROJECT_DIR$/evaluate.py</url>
417
+ <line>89</line>
418
+ <option name="timeStamp" value="441" />
419
+ </line-breakpoint>
420
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
421
+ <url>file://$PROJECT_DIR$/evaluate.py</url>
422
+ <line>95</line>
423
+ <option name="timeStamp" value="442" />
424
+ </line-breakpoint>
425
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
426
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
427
+ <line>533</line>
428
+ <option name="timeStamp" value="444" />
429
+ </line-breakpoint>
430
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
431
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
432
+ <line>171</line>
433
+ <option name="timeStamp" value="446" />
434
+ </line-breakpoint>
435
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
436
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
437
+ <line>172</line>
438
+ <option name="timeStamp" value="447" />
439
+ </line-breakpoint>
440
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
441
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
442
+ <line>167</line>
443
+ <option name="timeStamp" value="449" />
444
+ </line-breakpoint>
445
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
446
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
447
+ <line>166</line>
448
+ <option name="timeStamp" value="450" />
449
+ </line-breakpoint>
450
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
451
+ <url>file://$PROJECT_DIR$/load_nvos.py</url>
452
+ <line>170</line>
453
+ <option name="timeStamp" value="451" />
454
+ </line-breakpoint>
455
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
456
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
457
+ <line>37</line>
458
+ <option name="timeStamp" value="455" />
459
+ </line-breakpoint>
460
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
461
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
462
+ <line>128</line>
463
+ <option name="timeStamp" value="456" />
464
+ </line-breakpoint>
465
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
466
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
467
+ <line>131</line>
468
+ <option name="timeStamp" value="457" />
469
+ </line-breakpoint>
470
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
471
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
472
+ <line>146</line>
473
+ <option name="timeStamp" value="458" />
474
+ </line-breakpoint>
475
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
476
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
477
+ <line>136</line>
478
+ <option name="timeStamp" value="459" />
479
+ </line-breakpoint>
480
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
481
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
482
+ <line>140</line>
483
+ <option name="timeStamp" value="460" />
484
+ </line-breakpoint>
485
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
486
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
487
+ <line>143</line>
488
+ <option name="timeStamp" value="461" />
489
+ </line-breakpoint>
490
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
491
+ <url>file://$PROJECT_DIR$/dust3r/post_process.py</url>
492
+ <line>16</line>
493
+ <option name="timeStamp" value="465" />
494
+ </line-breakpoint>
495
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
496
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
497
+ <line>291</line>
498
+ <option name="timeStamp" value="474" />
499
+ </line-breakpoint>
500
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
501
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/init_im_poses.py</url>
502
+ <line>292</line>
503
+ <option name="timeStamp" value="475" />
504
+ </line-breakpoint>
505
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
506
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
507
+ <line>370</line>
508
+ <option name="timeStamp" value="484" />
509
+ </line-breakpoint>
510
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
511
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
512
+ <line>270</line>
513
+ <option name="timeStamp" value="486" />
514
+ </line-breakpoint>
515
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
516
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
517
+ <line>269</line>
518
+ <option name="timeStamp" value="487" />
519
+ </line-breakpoint>
520
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
521
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/optimizer.py</url>
522
+ <line>179</line>
523
+ <option name="timeStamp" value="490" />
524
+ </line-breakpoint>
525
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
526
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/optimizer.py</url>
527
+ <line>195</line>
528
+ <option name="timeStamp" value="492" />
529
+ </line-breakpoint>
530
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
531
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/optimizer.py</url>
532
+ <line>176</line>
533
+ <option name="timeStamp" value="493" />
534
+ </line-breakpoint>
535
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
536
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/optimizer.py</url>
537
+ <line>197</line>
538
+ <option name="timeStamp" value="494" />
539
+ </line-breakpoint>
540
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
541
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/optimizer.py</url>
542
+ <line>187</line>
543
+ <option name="timeStamp" value="495" />
544
+ </line-breakpoint>
545
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
546
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
547
+ <line>30</line>
548
+ <option name="timeStamp" value="497" />
549
+ </line-breakpoint>
550
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
551
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
552
+ <line>377</line>
553
+ <option name="timeStamp" value="503" />
554
+ </line-breakpoint>
555
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
556
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
557
+ <line>140</line>
558
+ <option name="timeStamp" value="504" />
559
+ </line-breakpoint>
560
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
561
+ <url>file://$PROJECT_DIR$/dust3r/cloud_opt/base_opt.py</url>
562
+ <line>139</line>
563
+ <option name="timeStamp" value="505" />
564
+ </line-breakpoint>
565
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
566
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
567
+ <line>167</line>
568
+ <option name="timeStamp" value="506" />
569
+ </line-breakpoint>
570
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
571
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
572
+ <line>173</line>
573
+ <option name="timeStamp" value="508" />
574
+ </line-breakpoint>
575
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
576
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
577
+ <line>171</line>
578
+ <option name="timeStamp" value="509" />
579
+ </line-breakpoint>
580
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
581
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
582
+ <line>162</line>
583
+ <option name="timeStamp" value="513" />
584
+ </line-breakpoint>
585
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
586
+ <url>file://$PROJECT_DIR$/SAM/predictor.py</url>
587
+ <line>162</line>
588
+ <option name="timeStamp" value="514" />
589
+ </line-breakpoint>
590
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
591
+ <url>file://$PROJECT_DIR$/SAM/predictor.py</url>
592
+ <line>153</line>
593
+ <option name="timeStamp" value="515" />
594
+ </line-breakpoint>
595
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
596
+ <url>file://$PROJECT_DIR$/SAM/predictor.py</url>
597
+ <line>237</line>
598
+ <option name="timeStamp" value="516" />
599
+ </line-breakpoint>
600
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
601
+ <url>file://$PROJECT_DIR$/SAM/predictor.py</url>
602
+ <line>239</line>
603
+ <option name="timeStamp" value="517" />
604
+ </line-breakpoint>
605
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
606
+ <url>file://$PROJECT_DIR$/gys_util.py</url>
607
+ <line>109</line>
608
+ <option name="timeStamp" value="522" />
609
+ </line-breakpoint>
610
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
611
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
612
+ <line>20</line>
613
+ <option name="timeStamp" value="526" />
614
+ </line-breakpoint>
615
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
616
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
617
+ <line>21</line>
618
+ <option name="timeStamp" value="527" />
619
+ </line-breakpoint>
620
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
621
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
622
+ <line>94</line>
623
+ <option name="timeStamp" value="530" />
624
+ </line-breakpoint>
625
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
626
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
627
+ <line>337</line>
628
+ <option name="timeStamp" value="536" />
629
+ </line-breakpoint>
630
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
631
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
632
+ <line>350</line>
633
+ <option name="timeStamp" value="537" />
634
+ </line-breakpoint>
635
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
636
+ <url>file://$PROJECT_DIR$/segment_eval_mask.py</url>
637
+ <line>505</line>
638
+ <option name="timeStamp" value="539" />
639
+ </line-breakpoint>
640
+ </breakpoints>
641
+ </breakpoint-manager>
642
+ </component>
643
+ <component name="com.intellij.coverage.CoverageDataManagerImpl">
644
+ <SUITE FILE_PATH="coverage/dust3r$test_vis.coverage" NAME="test_vis Coverage Results" MODIFIED="1714045279462" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
645
+ <SUITE FILE_PATH="coverage/dust3r$gys_util.coverage" NAME="gys_util Coverage Results" MODIFIED="1722411390675" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
646
+ <SUITE FILE_PATH="coverage/dust3r$load_nvos.coverage" NAME="load_nvos Coverage Results" MODIFIED="1722071842346" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
647
+ <SUITE FILE_PATH="coverage/dust3r$prepare_prompts.coverage" NAME="prepare_prompts Coverage Results" MODIFIED="1714108229869" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/lib" />
648
+ <SUITE FILE_PATH="coverage/dust3r$base_opt.coverage" NAME="base_opt Coverage Results" MODIFIED="1722220609861" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/dust3r/cloud_opt" />
649
+ <SUITE FILE_PATH="coverage/dust3r$segment_eval_mask.coverage" NAME="segment_eval_mask Coverage Results" MODIFIED="1723797134862" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
650
+ <SUITE FILE_PATH="coverage/dust3r$demo.coverage" NAME="demo Coverage Results" MODIFIED="1714038776406" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
651
+ <SUITE FILE_PATH="coverage/dust3r$evaluate.coverage" NAME="evaluate Coverage Results" MODIFIED="1722070716981" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
652
+ </component>
653
+ </project>
LICENSE ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
NOTICE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DUSt3R
2
+ Copyright 2024-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ naver/croco
10
+ https://github.com/naver/croco/
11
+
12
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0
13
+
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: Our3D
3
- emoji: 🏆
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.42.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: 3D
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.42.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SAM/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam,
9
+ build_sam_vit_h,
10
+ build_sam_vit_l,
11
+ build_sam_vit_b,
12
+ sam_model_registry,
13
+ )
14
+ from .predictor import SamPredictor
15
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
SAM/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (401 Bytes). View file
 
SAM/__pycache__/automatic_mask_generator.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
SAM/__pycache__/build_sam.cpython-310.pyc ADDED
Binary file (2.15 kB). View file
 
SAM/__pycache__/predictor.cpython-310.pyc ADDED
Binary file (9.94 kB). View file
 
SAM/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
class SamAutomaticMaskGenerator:
    """Generate masks for a whole image by prompting a SAM model with point grids."""

    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """

        # Exactly one of the two point sources may be supplied.
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # Only reachable when the assert above is disabled (python -O).
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        # Import optional dependencies eagerly so a missing package fails at
        # construction time rather than in the middle of generate().
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx, segmentation in enumerate(mask_data["segmentations"]):
            ann = {
                "segmentation": segmentation,
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """
        Runs mask prediction on every crop of the image and merges the results.

        Crop boxes come from generate_crop_boxes. When there is more than one
        crop, NMS is applied across crops with a score of 1 / crop-box area,
        so masks from smaller crops win. The merged data is converted with
        to_numpy() before being returned.
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Predicts masks for one crop and maps them back to the full image frame.

        Arguments:
          image (np.ndarray): The full image; it is sliced as image[y0:y1, x0:x1, :].
          crop_box (list(int)): The crop as [x0, y0, x1, y1].
          crop_layer_idx (int): Index into self.point_grids selecting the grid
            used for this crop.
          orig_size (tuple(int)): The first two entries of the full image shape.

        Returns:
          MaskData: The surviving masks after within-crop NMS, with 'boxes' and
            'points' uncropped and a 'crop_boxes' entry repeating crop_box once
            per mask.
        """
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop; the size pair is reversed before scaling the grid
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        # Repeat the crop box once per surviving mask. Using repeat() instead of
        # torch.tensor([...]) over a Python list keeps the result an integer
        # tensor of shape (N, 4) even when N == 0, so column indexing on
        # "crop_boxes" (e.g. box_area in _generate_masks) still works when a
        # crop yields no masks.
        data["crop_boxes"] = torch.tensor(crop_box).repeat(len(data["rles"]), 1)

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Runs the predictor on one batch of point prompts and filters the output.

        Masks are filtered by predicted IoU, then by stability score, then by
        whether their box is near the crop edge. The survivors are uncropped
        to the original image size and stored as RLEs; the dense masks are
        dropped before returning.

        Arguments:
          points (np.ndarray): Point prompts for this batch, in crop coordinates.
          im_size (tuple(int)): The first two entries of the cropped image shape.
          crop_box (list(int)): The crop as [x0, y0, x1, y1].
          orig_size (tuple(int)): The (height, width) of the full image.

        Returns:
          MaskData: Holds 'iou_preds', 'points', 'stability_score', 'boxes'
            and 'rles' for the masks that passed every filter.
        """
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData.
        # Each point is repeated masks.shape[1] times so it lines up with the
        # flattened (point, mask) axis.
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
SAM/build_sam.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12
+
13
+
14
def build_sam_vit_h(checkpoint=None):
    """Build SAM with the ``vit_h`` encoder settings; ``checkpoint`` is forwarded to ``_build_sam``."""
    encoder_settings = {
        "encoder_embed_dim": 1280,
        "encoder_depth": 32,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [7, 15, 23, 31],
    }
    return _build_sam(checkpoint=checkpoint, **encoder_settings)


# Generic alias: ``build_sam`` is the vit_h builder.
build_sam = build_sam_vit_h
25
+
26
+
27
def build_sam_vit_l(checkpoint=None):
    """Build SAM with the ``vit_l`` encoder settings; ``checkpoint`` is forwarded to ``_build_sam``."""
    encoder_settings = {
        "encoder_embed_dim": 1024,
        "encoder_depth": 24,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [5, 11, 17, 23],
    }
    return _build_sam(checkpoint=checkpoint, **encoder_settings)
35
+
36
+
37
def build_sam_vit_b(checkpoint=None):
    """Build SAM with the ``vit_b`` encoder settings; ``checkpoint`` is forwarded to ``_build_sam``."""
    encoder_settings = {
        "encoder_embed_dim": 768,
        "encoder_depth": 12,
        "encoder_num_heads": 12,
        "encoder_global_attn_indexes": [2, 5, 8, 11],
    }
    return _build_sam(checkpoint=checkpoint, **encoder_settings)
45
+
46
+
47
# Lookup table from a model-type string to the builder function for it.
# "default" resolves to the same builder as "vit_h".
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
53
+
54
+
55
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """
    Assemble a ``Sam`` model (image encoder, prompt encoder, mask decoder),
    switch it to eval mode and optionally load weights.

    Args:
        encoder_embed_dim: Embedding dimension passed to the image encoder.
        encoder_depth: Number of transformer blocks in the image encoder.
        encoder_num_heads: Number of attention heads per encoder block.
        encoder_global_attn_indexes: Indexes of encoder blocks that use
            global attention instead of windowed attention.
        checkpoint: Optional path to a state-dict file. When given, the
            state dict is read and loaded into the model.

    Returns:
        The constructed ``Sam`` module, in eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        # map_location="cpu" lets a checkpoint that was saved from GPU tensors
        # load on a machine without CUDA. The model has just been built on the
        # CPU, so the loaded values end up in the same place either way.
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict)
    return sam
SAM/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam
8
+ from .mask_decoder import MaskDecoder
9
+ from .prompt_encoder import PromptEncoder
10
+ from .transformer import TwoWayTransformer
11
+ from .image_encoder import ImageEncoderViT
SAM/modeling/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (388 Bytes). View file
 
SAM/modeling/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.74 kB). View file
 
SAM/modeling/__pycache__/image_encoder.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
SAM/modeling/__pycache__/mask_decoder.cpython-310.pyc ADDED
Binary file (5.94 kB). View file
 
SAM/modeling/__pycache__/prompt_encoder.cpython-310.pyc ADDED
Binary file (7.67 kB). View file
 
SAM/modeling/__pycache__/sam.cpython-310.pyc ADDED
Binary file (6.76 kB). View file
 
SAM/modeling/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.6 kB). View file
 
SAM/modeling/common.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from typing import Type
11
+
12
+
13
class MLPBlock(nn.Module):
    """Feed-forward block: linear -> activation -> linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embedding_dim (int): Feature size of the input and of the output.
            mlp_dim (int): Feature size of the hidden layer.
            act (nn.Module): Activation class, instantiated with no arguments.
        """
        super().__init__()
        # Attribute names and creation order are kept so state-dict keys and
        # random initialisation order stay the same.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
class LayerNorm2d(nn.Module):
    """Normalises over dimension 1 and applies a learned per-channel scale and shift."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        """
        Args:
            num_channels (int): Length of the learned weight and bias vectors.
            eps (float): Constant added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean and (biased) variance are taken along dim 1.
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # weight/bias are reshaped to (C, 1, 1) so they broadcast over the
        # two trailing dimensions.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
SAM/modeling/image_encoder.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
class ImageEncoderViT(nn.Module):
    """Image encoder: patch embedding, a stack of transformer blocks, then a convolutional neck."""

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple or list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        # Side length of the token grid used for the positional embedding and
        # for the blocks' input_size.
        grid_size = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Absolute positional embedding sized for the pretrain image size.
            self.pos_embed = nn.Parameter(torch.zeros(1, grid_size, grid_size, embed_dim))

        # Blocks listed in global_attn_indexes get window_size 0 (global attention).
        self.blocks = nn.ModuleList(
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=0 if i in global_attn_indexes else window_size,
                input_size=(grid_size, grid_size),
            )
            for i in range(depth)
        )

        reduce_conv = nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False)
        smooth_conv = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False)
        self.neck = nn.Sequential(
            reduce_conv,
            LayerNorm2d(out_chans),
            smooth_conv,
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens)

        # Move the last dimension to position 1 before the convolutional neck.
        return self.neck(tokens.permute(0, 3, 1, 2))
117
+
118
+
119
class Block(nn.Module):
    """Transformer block with optional windowed attention and two residual connections."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        # With windowed attention the attention layer only ever sees one
        # window, so its relative-position tables are sized for the window.
        attn_input_size = input_size if window_size == 0 else (window_size, window_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm1(x)

        windowed = self.window_size > 0
        if windowed:
            # Remember the unpadded size so the partition can be undone.
            orig_hw = (x.shape[1], x.shape[2])
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if windowed:
            x = window_unpartition(x, self.window_size, pad_hw, orig_hw)

        x = residual + x
        return x + self.mlp(self.norm2(x))
183
+
184
+
185
class Attention(nn.Module):
    """Multi-head attention over a (B, H, W, C) input, with optional relative position terms."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
                This argument is not read here; the tables are always created as zeros.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One table per axis, with 2 * size - 1 rows each.
            size_h, size_w = input_size[0], input_size[1]
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * size_h - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * size_w - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        num_tokens = H * W

        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, num_tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, num_tokens, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        attn = attn.softmax(dim=-1)

        # Merge the heads back into the channel dimension.
        out = attn @ v
        out = out.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(out)
241
+
242
+
243
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a [B, H, W, C] tensor into non-overlapping square windows, padding
    the bottom and right edges when H or W is not a multiple of the window size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Amount needed to reach the next multiple of window_size (0 if already aligned).
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        # F.pad takes pairs starting from the last dim: (C, C, W, W, H, H).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp = H + pad_h
    Wp = W + pad_w

    rows = Hp // window_size
    cols = Wp // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
265
+
266
+
267
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Reassemble windows into a [B, H, W, C] tensor and crop away any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image

    rows = Hp // window_size
    cols = Wp // window_size
    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Drop the rows/columns that only exist because of padding.
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
290
+
291
+
292
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) position pair.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions,
        indexed as [query position, key position, C].
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        rel_pos_resized = rel_pos
    else:
        # The table length does not match: linearly interpolate it along its
        # first dimension to max_rel_dist rows.
        as_channels = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        rel_pos_resized = F.interpolate(
            as_channels,
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

    # Scale the coords with short length if shapes for q and k are different.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # The offset shifts the smallest difference to index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale

    return rel_pos_resized[relative_coords.long()]
323
+
324
+
325
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add decomposed Relative Positional Embeddings from :paper:`mvitv2` to an attention map.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    q_grid = q.reshape(B, q_h, q_w, dim)
    # One bias per (query position, key row) and per (query position, key column).
    rel_h = torch.einsum("bhwc,hkc->bhwk", q_grid, height_table)
    rel_w = torch.einsum("bhwc,wkc->bhwk", q_grid, width_table)

    # View the map as (B, q_h, q_w, k_h, k_w) so the row term broadcasts over
    # key columns and the column term broadcasts over key rows.
    attn_grid = attn.view(B, q_h, q_w, k_h, k_w)
    attn_grid = attn_grid + rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)

    return attn_grid.view(B, q_h * q_w, k_h * k_w)
362
+
363
+
364
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding, implemented as a single strided convolution.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        # B C H W -> B H W C
        return projected.permute(0, 2, 3, 1)
SAM/modeling/mask_decoder.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import List, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        # One extra mask token on top of the multimask outputs.
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        quarter_dim = transformer_dim // 4
        eighth_dim = transformer_dim // 8
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, quarter_dim, kernel_size=2, stride=2),
            LayerNorm2d(quarter_dim),
            activation(),
            nn.ConvTranspose2d(quarter_dim, eighth_dim, kernel_size=2, stride=2),
            activation(),
        )
        # One small MLP per mask token.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, eighth_dim, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        batch_ind_list: List[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
          batch_ind_list (list(int) or None): when given, entry i is the number of
            times image_embeddings[i] is repeated along the batch dimension; when
            None, image_embeddings is repeated once per prompt instead.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            batch_ind_list=batch_ind_list,
        )

        # Multi-mask output skips index 0; single-mask output keeps only index 0.
        mask_slice = slice(1, None) if multimask_output else slice(0, 1)
        return masks[:, mask_slice, :, :], iou_pred[:, mask_slice]

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        batch_ind_list: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        num_prompts = int(sparse_prompt_embeddings.size(0))

        # Token sequence per prompt: [iou token, mask tokens, sparse prompt tokens]
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(num_prompts, -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if batch_ind_list is None:
            src = torch.repeat_interleave(image_embeddings, num_prompts, dim=0)
        else:
            per_image = [
                image_embeddings[i].unsqueeze(0).repeat(n, 1, 1, 1)
                for i, n in enumerate(batch_ind_list)
            ]
            src = torch.cat(per_image, dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, num_prompts, dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        flat_embedding = upscaled_embedding.view(b, c, h * w)
        masks = (hyper_in @ flat_embedding).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
166
+
167
+
168
+ # Lightly adapted from
169
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
170
class MLP(nn.Module):
    """Stack of linear layers with ReLU between them and an optional sigmoid on the output."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """
        Arguments:
          input_dim (int): feature size of the input
          hidden_dim (int): feature size of every intermediate layer
          output_dim (int): feature size of the output
          num_layers (int): total number of linear layers
          sigmoid_output (bool): if True, apply a sigmoid to the final output
        """
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            # No ReLU after the final layer.
            if index < last_index:
                x = F.relu(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
SAM/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from typing import Any, Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
def __init__(
    self,
    embed_dim: int,
    image_embedding_size: Tuple[int, int],
    input_image_size: Tuple[int, int],
    mask_in_chans: int,
    activation: Type[nn.Module] = nn.GELU,
) -> None:
    """
    Encodes prompts for input to SAM's mask decoder.

    Arguments:
      embed_dim (int): The prompts' embedding dimension
      image_embedding_size (tuple(int, int)): The spatial size of the
        image embedding, as (H, W).
      input_image_size (tuple(int, int)): The padded size of the image as
        input to the image encoder, as (H, W).
      mask_in_chans (int): The number of hidden channels used for
        encoding input masks.
      activation (nn.Module): The activation to use when encoding
        input masks.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.input_image_size = input_image_size
    self.image_embedding_size = image_embedding_size
    self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

    self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
    self.point_embeddings = nn.ModuleList(
        [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
    )
    self.not_a_point_embed = nn.Embedding(1, embed_dim)

    embed_h, embed_w = image_embedding_size[0], image_embedding_size[1]
    self.mask_input_size = (4 * embed_h, 4 * embed_w)

    # Two stride-2 convolutions followed by a 1x1 projection to embed_dim.
    reduced_chans = mask_in_chans // 4
    self.mask_downscaling = nn.Sequential(
        nn.Conv2d(1, reduced_chans, kernel_size=2, stride=2),
        LayerNorm2d(reduced_chans),
        activation(),
        nn.Conv2d(reduced_chans, mask_in_chans, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans),
        activation(),
        nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
    )
    self.no_mask_embed = nn.Embedding(1, embed_dim)
61
+
62
def get_dense_pe(self) -> torch.Tensor:
    """
    Returns the positional encoding used to encode point prompts,
    evaluated over a dense grid the shape of the image encoding.

    Returns:
      torch.Tensor: Positional encoding with shape
        1x(embed_dim)x(embedding_h)x(embedding_w)
    """
    dense_pe = self.pe_layer(self.image_embedding_size)
    # Add a leading batch dimension of size 1.
    return dense_pe.unsqueeze(0)
72
+
73
def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    """
    Embed point prompts.

    Coordinates are shifted by 0.5 (pixel centre), positionally encoded
    against self.input_image_size, and then combined with a learned
    embedding selected by each point's label:
      -1 -> positional part zeroed, then not_a_point_embed added
       0 -> point_embeddings[0] added
       1 -> point_embeddings[1] added

    Arguments:
      points (torch.Tensor): point coordinates; a (B, 1, 2) padding point
        is concatenated along dim 1 when pad is set.
      labels (torch.Tensor): one label per point.
      pad (bool): when True, append one all-zero point labelled -1 to
        every batch entry.

    Returns:
      torch.Tensor: the tensor returned by
        self.pe_layer.forward_with_coords, updated in place with the
        label-specific embeddings.
    """
    shifted = points + 0.5  # Shift to center of pixel
    if pad:
        extra_point = torch.zeros((shifted.shape[0], 1, 2), device=shifted.device)
        extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
        shifted = torch.cat([shifted, extra_point], dim=1)
        labels = torch.cat([labels, extra_label], dim=1)

    embedding = self.pe_layer.forward_with_coords(shifted, self.input_image_size)

    # "Not a point" entries carry no positional information at all.
    is_padding = labels == -1
    embedding[is_padding] = 0.0
    embedding[is_padding] += self.not_a_point_embed.weight
    embedding[labels == 0] += self.point_embeddings[0].weight
    embedding[labels == 1] += self.point_embeddings[1].weight
    return embedding
92
+
93
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """
    Embed box prompts.

    Each box is shifted by 0.5 (pixel centre), reshaped into two
    2-coordinate corners, and positionally encoded against
    self.input_image_size. The first corner then receives
    point_embeddings[2] and the second point_embeddings[3].

    Arguments:
      boxes (torch.Tensor): box coordinates, reshaped to (-1, 2, 2).

    Returns:
      torch.Tensor: the corner embeddings, two per box.
    """
    corners = (boxes + 0.5).reshape(-1, 2, 2)  # Shift to center of pixel
    embedded = self.pe_layer.forward_with_coords(corners, self.input_image_size)
    # Corner 0 / corner 1 use learned embeddings 2 / 3 respectively.
    for corner_idx, embed_idx in ((0, 2), (1, 3)):
        embedded[:, corner_idx, :] += self.point_embeddings[embed_idx].weight
    return embedded
101
+
102
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    """
    Embed mask inputs by running them through self.mask_downscaling.

    Arguments:
      masks (torch.Tensor): the mask prompts to embed.

    Returns:
      torch.Tensor: the output of the downscaling network.
    """
    return self.mask_downscaling(masks)
106
+
107
def _get_batch_size(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> int:
    """
    Work out the output batch size from whichever prompt is present.

    Prompts are consulted in priority order: points (the coordinate
    tensor), then boxes, then masks. The first one that is not None
    supplies its leading dimension; with no prompt at all the batch
    size is 1.
    """
    if points is not None:
        return points[0].shape[0]
    for prompt in (boxes, masks):
        if prompt is not None:
            return prompt.shape[0]
    return 1
124
+
125
def _get_device(self) -> torch.device:
    """Return the device holding the weight of the first point embedding."""
    first_embedding = self.point_embeddings[0]
    return first_embedding.weight.device
127
+
128
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Embed the supplied prompts into a sparse and a dense embedding.

    Arguments:
      points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
        and labels to embed.
      boxes (torch.Tensor or none): boxes to embed
      masks (torch.Tensor or none): masks to embed

    Returns:
      torch.Tensor: sparse embeddings for the points and boxes, with shape
        BxNx(embed_dim), where N is determined by the number of input points
        and boxes.
      torch.Tensor: dense embeddings for the masks, in the shape
        Bx(embed_dim)x(embed_H)x(embed_W)
    """
    batch = self._get_batch_size(points, boxes, masks)

    # Start from an empty (B, 0, C) tensor so the result is well-formed
    # even when neither points nor boxes are given.
    sparse_parts = [torch.empty((batch, 0, self.embed_dim), device=self._get_device())]
    if points is not None:
        coords, labels = points
        # A padding point is only appended when no box accompanies the points.
        sparse_parts.append(self._embed_points(coords, labels, pad=(boxes is None)))
    if boxes is not None:
        sparse_parts.append(self._embed_boxes(boxes))
    sparse_embeddings = torch.cat(sparse_parts, dim=1)

    if masks is None:
        # No mask prompt: broadcast the learned "no mask" vector over the grid.
        grid_h, grid_w = self.image_embedding_size[0], self.image_embedding_size[1]
        no_mask = self.no_mask_embed.weight.reshape(1, -1, 1, 1)
        dense_embeddings = no_mask.expand(batch, -1, grid_h, grid_w)
    else:
        dense_embeddings = self._embed_masks(masks)

    return sparse_embeddings, dense_embeddings
169
+
170
+
171
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    A fixed random Gaussian matrix of shape (2, num_pos_feats) is registered
    as a buffer. 2-D coordinates are projected through it and expanded into
    sine and cosine features, so every encoding has 2 * num_pos_feats
    channels.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """
        Arguments:
          num_pos_feats (int): number of random frequencies; the encoding
            has twice this many channels.
          scale (float or None): multiplier for the random matrix. None or
            any non-positive value falls back to 1.0.
        """
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1  # map [0, 1] onto [-1, 1]
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """
        Generate positional encoding for a grid of the specified size.

        Arguments:
          size (tuple(int, int)): grid size as (h, w).

        Returns:
          torch.Tensor: encoding of the cell centres, in C x H x W layout.
        """
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # cumsum - 0.5 yields the cell-centre coordinates 0.5, 1.5, ...
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Positionally encode points that are not normalized to [0,1].

        Arguments:
          coords_input (torch.Tensor): B x N x 2 coordinates; channel 0 is
            divided by image_size[1] and channel 1 by image_size[0]. The
            input tensor is never modified.
          image_size (tuple(int, int)): the size used for normalisation.

        Returns:
          torch.Tensor: B x N x C encoding.
        """
        coords = coords_input.clone()
        if not coords.is_floating_point():
            # Writing a fractional quotient back into an integer tensor
            # truncates it (normally to 0), so integer coordinates must be
            # converted before they are normalised in place.
            coords = coords.to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
SAM/modeling/sam.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ from .mask_decoder import MaskDecoder
14
+ from .prompt_encoder import PromptEncoder
15
+ from .image_encoder import ImageEncoderViT
16
+
17
class Sam(nn.Module):
    # Logit threshold above which a pixel counts as part of the mask.
    mask_threshold: float = 0.0
    # Channel order the model expects for input images.
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict.
        for buffer_name, values in (("pixel_mean", pixel_mean), ("pixel_std", pixel_std)):
            self.register_buffer(buffer_name, torch.Tensor(values).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        """Device on which the normalisation buffers (and so the model) live."""
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        All images are encoded in a single image-encoder call, the prompt
        embeddings of all images are concatenated for one mask-decoder
        call, and the decoder output is split back per image.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            as dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        images = []
        sparse_chunks = []
        dense_chunks = []
        prompts_per_image = []  # used to split the decoder output per image

        for record in batched_input:
            images.append(self.preprocess(record["image"]))

            points = None
            if "point_coords" in record:
                points = (record["point_coords"], record["point_labels"])

            sparse, dense = self.prompt_encoder(
                points=points,
                boxes=record.get("boxes"),
                masks=record.get("mask_inputs"),
            )
            assert len(sparse) == len(dense)
            sparse_chunks.append(sparse)
            dense_chunks.append(dense)
            prompts_per_image.append(len(sparse))

        image_embeddings = self.image_encoder(torch.stack(images, dim=0))
        all_sparse = torch.cat(sparse_chunks)
        all_dense = torch.cat(dense_chunks)
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=all_sparse,
            dense_prompt_embeddings=all_dense,
            multimask_output=multimask_output,
            batch_ind_list=prompts_per_image,
        )

        per_image_masks = torch.split(low_res_masks, prompts_per_image, dim=0)
        per_image_ious = torch.split(iou_predictions, prompts_per_image, dim=0)

        results = []
        for record, logits, ious in zip(batched_input, per_image_masks, per_image_ious):
            upscaled = self.postprocess_masks(
                logits,
                input_size=record["image"].shape[-2:],
                original_size=record["original_size"],
            )
            results.append(
                {
                    "masks": upscaled > self.mask_threshold,
                    "iou_predictions": ious,
                    "low_res_logits": logits,
                }
            )
        return results

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        side = self.image_encoder.img_size
        # 1) upsample to the padded square the encoder works on
        padded = F.interpolate(masks, (side, side), mode="bilinear", align_corners=False)
        # 2) crop away the padding that preprocess() added
        cropped = padded[..., : input_size[0], : input_size[1]]
        # 3) resize to the resolution of the original image
        return F.interpolate(cropped, original_size, mode="bilinear", align_corners=False)

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        normalized = (x - self.pixel_mean) / self.pixel_std

        # Pad on the right and bottom up to the encoder's square input size
        side = self.image_encoder.img_size
        height, width = normalized.shape[-2:]
        return F.pad(normalized, (0, side - width, 0, side - height))
SAM/modeling/transformer.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): downsample rate passed to the
            cross-attention layers and to the final attention layer
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Only the very first block skips the positional encoding in its
        # self-attention step.
        self.layers = nn.ModuleList(
            [
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(layer_idx == 0),
                )
                for layer_idx in range(depth)
            ]
        )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # Unpacking doubles as a check that the embedding is 4-dimensional.
        bs, c, h, w = image_embedding.shape
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe_tokens = image_pe.flatten(2).permute(0, 2, 1)

        # The prompt tokens are the queries; the image tokens are the keys.
        queries, keys = point_embedding, image_tokens

        # Apply transformer blocks
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe_tokens,
            )

        # Final attention layer from the points to the image, then layernorm
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding,
            k=keys + image_pe_tokens,
            v=keys,
        )
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
107
+
108
+
109
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          attention_downsample_rate (int): downsample rate used by the two
            cross-attention layers
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        # (1) self-attention over the sparse tokens
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # (2) sparse tokens attend to the image
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # (3) MLP on the sparse tokens
        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        # (4) image attends to the sparse tokens
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the four sub-layers in order and return the updated
        (queries, keys) pair.
        """
        # Self attention block. When the PE is skipped, the attention output
        # replaces the queries outright instead of being added as a residual.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            queries_with_pe = queries + query_pe
            queries = queries + self.self_attn(
                q=queries_with_pe, k=queries_with_pe, v=queries
            )
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        token_to_image = self.cross_attn_token_to_image(
            q=queries + query_pe, k=keys + key_pe, v=keys
        )
        queries = self.norm2(queries + token_to_image)

        # MLP block
        queries = self.norm3(queries + self.mlp(queries))

        # Cross attention block, image embedding attending to tokens
        # (the image tokens act as the queries here).
        image_to_token = self.cross_attn_image_to_token(
            q=keys + key_pe, k=queries + query_pe, v=queries
        )
        keys = self.norm4(keys + image_to_token)

        return queries, keys
183
+
184
+
185
class Attention(nn.Module):
    """
    Multi-head attention whose internal width may be smaller than the
    embedding width: queries, keys and values are projected down to
    embedding_dim // downsample_rate channels, and the output projection
    maps the result back to embedding_dim.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Split B x N_tokens x C into B x N_heads x N_tokens x C_per_head."""
        batch, tokens, channels = x.shape
        per_head = x.reshape(batch, tokens, num_heads, channels // num_heads)
        return per_head.transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Merge B x N_heads x N_tokens x C_per_head back into B x N_tokens x C."""
        batch, heads, tokens, head_dim = x.shape
        return x.transpose(1, 2).reshape(batch, tokens, heads * head_dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product attention of q over k, applied to v."""
        # Project the inputs and split each projection into heads
        q_heads = self._separate_heads(self.q_proj(q), self.num_heads)
        k_heads = self._separate_heads(self.k_proj(k), self.num_heads)
        v_heads = self._separate_heads(self.v_proj(v), self.num_heads)

        # Attention weights: B x N_heads x N_tokens x N_tokens
        head_dim = q_heads.shape[-1]
        scores = q_heads @ k_heads.transpose(-2, -1)
        weights = (scores / math.sqrt(head_dim)).softmax(dim=-1)

        # Weighted values, merged back and projected to embedding_dim
        merged = self._recombine_heads(weights @ v_heads)
        return self.out_proj(merged)
SAM/predictor.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from .modeling import Sam
11
+
12
+ from typing import Optional, Tuple
13
+
14
+ from .utils.transforms import ResizeLongestSide
15
+
16
+
17
class SamPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            # Reverse the channel axis to convert between RGB and BGR.
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        # HWC -> 1xCxHxW
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        self.set_torch_image(input_image_torch, image.shape[:2])

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(input_image)
        self.is_image_set = True

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.

        Raises:
          RuntimeError: if no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        # Transform input prompts into batched torch tensors (batch size 1)
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
        )

        # Drop the batch dimension again and hand back numpy arrays
        masks_np = masks[0].detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the model,
            typically coming from a previous prediction iteration. Has form Bx1xHxW,
            where for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.

        Raises:
          RuntimeError: if no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).

        Raises:
          RuntimeError: if no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        """The device of the wrapped model."""
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        # These are the size attributes set_torch_image() writes and
        # predict()/predict_torch() read; clearing them keeps a reset
        # predictor from carrying the previous image's sizes around.
        self.original_size = None
        self.input_size = None
        # Legacy attributes, kept so existing readers still find them.
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
SAM/utils/__pycache__/amg.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
SAM/utils/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
SAM/utils/amg.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ import math
11
+ from copy import deepcopy
12
+ from itertools import product
13
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
14
+
15
+
16
class MaskData:
    """
    Batched container for masks and the data attached to them.

    Values are lists, numpy arrays or torch tensors stored under string
    keys. The container offers in-place filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        for value in kwargs.values():
            assert isinstance(
                value, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        supported = isinstance(item, (list, np.ndarray, torch.Tensor))
        assert supported, "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Return a view over the stored (key, value) pairs."""
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        """
        Keep only the entries selected by ``keep`` in every stored value.

        ``keep`` is used as an index: tensors and arrays are indexed with it
        directly; lists are filtered element-wise when ``keep`` is boolean
        and gathered by index otherwise.

        Raises:
            TypeError: if a stored value has an unsupported type.
        """
        for key, value in self._stats.items():
            if value is None:
                selected = None
            elif isinstance(value, torch.Tensor):
                selected = value[torch.as_tensor(keep, device=value.device)]
            elif isinstance(value, np.ndarray):
                selected = value[keep.detach().cpu().numpy()]
            elif isinstance(value, list):
                if keep.dtype == torch.bool:
                    selected = [entry for pos, entry in enumerate(value) if keep[pos]]
                else:
                    selected = [value[pos] for pos in keep]
            else:
                raise TypeError(f"MaskData key {key} has an unsupported type {type(value)}.")
            # Only existing keys are reassigned, so iterating the dict stays valid.
            self._stats[key] = selected

    def cat(self, new_stats: "MaskData") -> None:
        """
        Append the contents of ``new_stats`` to this container, key by key.

        Keys that are missing here (or hold None) receive a deep copy of the
        incoming value.

        Raises:
            TypeError: if an incoming value has an unsupported type.
        """
        for key, incoming in new_stats.items():
            existing = self._stats.get(key)
            if existing is None:
                merged = deepcopy(incoming)
            elif isinstance(incoming, torch.Tensor):
                merged = torch.cat([existing, incoming], dim=0)
            elif isinstance(incoming, np.ndarray):
                merged = np.concatenate([existing, incoming], axis=0)
            elif isinstance(incoming, list):
                merged = existing + deepcopy(incoming)
            else:
                raise TypeError(f"MaskData key {key} has an unsupported type {type(incoming)}.")
            self._stats[key] = merged

    def to_numpy(self) -> None:
        """Convert every stored torch tensor to a numpy array, in place."""
        for key, value in self._stats.items():
            if not isinstance(value, torch.Tensor):
                continue
            self._stats[key] = value.detach().cpu().numpy()
76
+
77
+
78
def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """
    Flag boxes that touch an edge of the crop but not an edge of the original image.

    Boxes are first shifted by the crop origin, then each coordinate is
    compared (within ``atol``) against the crop box and the original image
    box. Returns one boolean per box.
    """
    device = boxes.device
    crop_edges = torch.as_tensor(crop_box, dtype=torch.float, device=device)
    image_edges = torch.as_tensor(orig_box, dtype=torch.float, device=device)
    shifted = uncrop_boxes_xyxy(boxes, crop_box).float()
    touches_crop = torch.isclose(shifted, crop_edges[None, :], atol=atol, rtol=0)
    touches_image = torch.isclose(shifted, image_edges[None, :], atol=atol, rtol=0)
    # A coordinate only counts when it sits on a crop edge that is not also an image edge.
    return (touches_crop & ~touches_image).any(dim=1)
89
+
90
+
91
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
    """
    Convert a single box from (x0, y0, x1, y1) to (x0, y0, width, height).

    The input is deep-copied; the caller's box is left untouched.
    """
    converted = deepcopy(box_xyxy)
    for far_idx, near_idx in ((2, 0), (3, 1)):
        converted[far_idx] = converted[far_idx] - converted[near_idx]
    return converted
96
+
97
+
98
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """
    Yield aligned slices of ``batch_size`` items from every sequence in ``args``.

    All sequences must have the same length; the final batch may be shorter.
    """
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    full_batches, remainder = divmod(len(args[0]), batch_size)
    n_batches = full_batches + (1 if remainder != 0 else 0)
    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        stop = start + batch_size
        yield [seq[start:stop] for seq in args]
105
+
106
+
107
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encode a batch of masks (b, h, w) as uncompressed RLE dictionaries,
    in the format expected by pycoco tools.

    Each dictionary holds ``"size": [h, w]`` and ``"counts"``, the run
    lengths in column-major (Fortran) order, starting with a run of zeros.
    """
    n_masks, height, width = tensor.shape
    # Column-major order: swap h and w before flattening.
    flat = tensor.permute(0, 2, 1).flatten(1)

    # (mask index, position) of every place where neighbouring pixels differ.
    transitions = (flat[:, 1:] ^ flat[:, :-1]).nonzero()
    first = torch.tensor([0], dtype=transitions.dtype, device=transitions.device)
    last = torch.tensor([height * width], dtype=transitions.dtype, device=transitions.device)

    encoded = []
    for mask_idx in range(n_masks):
        positions = transitions[transitions[:, 0] == mask_idx, 1]
        boundaries = torch.cat([first, positions + 1, last])
        run_lengths = boundaries[1:] - boundaries[:-1]
        # A mask that starts with a foreground pixel gets an empty leading zero-run.
        counts = [] if flat[mask_idx, 0] == 0 else [0]
        counts.extend(run_lengths.detach().cpu().tolist())
        encoded.append({"size": [height, width], "counts": counts})
    return encoded
136
+
137
+
138
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """
    Decode an uncompressed RLE dictionary into a boolean mask of shape (h, w).

    Runs alternate between background and foreground, starting with background.
    """
    height, width = rle["size"]
    flat = np.empty(height * width, dtype=bool)
    cursor = 0
    for run_idx, run_len in enumerate(rle["counts"]):
        # Even-indexed runs are background, odd-indexed runs are foreground.
        flat[cursor : cursor + run_len] = run_idx % 2 == 1
        cursor += run_len
    # The runs are column-major: rebuild as (w, h), then transpose to (h, w).
    return flat.reshape(width, height).transpose()
150
+
151
+
152
def area_from_rle(rle: Dict[str, Any]) -> int:
    """Return the mask area: the sum of the odd-indexed (foreground) runs."""
    foreground_runs = rle["counts"][1::2]
    area = 0
    for run_len in foreground_runs:
        area = area + run_len
    return area
154
+
155
+
156
def calculate_stability_score(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Compute the stability score of each mask in a batch.

    The score is the IoU between the binary masks obtained by thresholding
    the mask logits at ``mask_threshold + threshold_offset`` and at
    ``mask_threshold - threshold_offset``.
    """

    def _pixels_above(level: float) -> torch.Tensor:
        # Sum in two narrow-dtype steps to avoid an intermediate torch.int64 cast.
        return (masks > level).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)

    # The high-threshold mask is always contained in the low-threshold one,
    # so their pixel counts are the intersection and the union respectively.
    intersections = _pixels_above(mask_threshold + threshold_offset)
    unions = _pixels_above(mask_threshold - threshold_offset)
    return intersections / unions
177
+
178
+
179
def build_point_grid(n_per_side: int) -> np.ndarray:
    """
    Generate an ``n_per_side`` x ``n_per_side`` grid of points evenly spaced
    in [0,1]x[0,1], returned as an array of shape (n_per_side**2, 2) holding
    (x, y) pairs with x varying fastest.
    """
    half_cell = 1 / (2 * n_per_side)
    ticks = np.linspace(half_cell, 1 - half_cell, n_per_side)
    # Default 'xy' indexing: grid_x repeats ticks along rows, grid_y along columns.
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    return np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)
187
+
188
+
189
def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """
    Generate one point grid per crop layer (``n_layers + 1`` grids in total).

    Layer ``i`` uses ``int(n_per_side / scale_per_layer**i)`` points per side.
    """
    return [
        build_point_grid(int(n_per_side / (scale_per_layer**layer)))
        for layer in range(n_layers + 1)
    ]
198
+
199
+
200
def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generate crop boxes of different sizes for an image of size (H, W).

    The first box is the whole image (layer 0). Layer ``i`` (1-based) adds
    (2**i)**2 overlapping boxes. Boxes are [x0, y0, x1, y1] lists clipped
    to the image; the second return value gives each box's layer index.
    """
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    def crop_len(orig_len, n_crops, overlap):
        # Side length needed so n_crops crops with the given overlap span orig_len.
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    # Layer 0: the original image.
    crop_boxes = [[0, 0, im_w, im_h]]
    layer_idxs = [0]

    for layer in range(1, n_layers + 1):
        n_crops_per_side = 2 ** layer
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        x_starts = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        y_starts = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format, clipped to the image bounds.
        for x0 in x_starts:
            for y0 in y_starts:
                crop_boxes.append([x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)])
                layer_idxs.append(layer)

    return crop_boxes, layer_idxs
235
+
236
+
237
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift XYXY boxes from crop coordinates by the crop's (x0, y0) origin."""
    x0, y0, _, _ = crop_box
    shift = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # An extra (channel) dimension on the boxes needs a matching offset shape.
    has_channel_dim = boxes.dim() == 3
    if has_channel_dim:
        shift = shift[:, None]
    return boxes + shift
244
+
245
+
246
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift (x, y) points from crop coordinates by the crop's (x0, y0) origin."""
    x0, y0, _, _ = crop_box
    shift = torch.tensor([[x0, y0]], device=points.device)
    # An extra (channel) dimension on the points needs a matching offset shape.
    has_channel_dim = points.dim() == 3
    if has_channel_dim:
        shift = shift[:, None]
    return points + shift
253
+
254
+
255
def uncrop_masks(
    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
    """
    Place crop-sized masks back into the frame of the original image.

    Masks are zero-padded on every side so the crop lands at its original
    position. If the crop is the whole image the input is returned as is.
    """
    x0, y0, x1, y1 = crop_box
    covers_whole_image = x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h
    if covers_whole_image:
        return masks
    # (left, right, top, bottom) padding for the last two dimensions.
    left, top = x0, y0
    right = orig_w - (x1 - x0) - x0
    bottom = orig_h - (y1 - y0) - y0
    return torch.nn.functional.pad(masks, (left, right, top, bottom), value=0)
265
+
266
+
267
def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Remove small disconnected regions ("islands") or small holes ("holes")
    from a mask.

    Returns the resulting mask and a flag telling whether it was modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    filling_holes = mode == "holes"
    # For holes the mask is inverted so that holes become the labelled components.
    working_mask = (filling_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [label for label, size in enumerate(sizes, start=1) if size < area_thresh]
    if not small_regions:
        return mask, False

    if filling_holes:
        # Result = original foreground (label 0 of the inverted mask) plus the small holes.
        fill_labels = [0] + small_regions
    else:
        discarded = {0, *small_regions}
        fill_labels = [label for label in range(n_labels) if label not in discarded]
        # If every region is below threshold, keep largest
        if not fill_labels:
            fill_labels = [int(np.argmax(sizes)) + 1]
    return np.isin(regions, fill_labels), True
292
+
293
+
294
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert an uncompressed RLE dictionary with pycocotools' ``frPyObjects``
    and decode its ``counts`` bytes to ``str``.
    """
    from pycocotools import mask as mask_utils  # type: ignore

    height, width = uncompressed_rle["size"]
    encoded = mask_utils.frPyObjects(uncompressed_rle, height, width)
    # Necessary to serialize with json
    encoded["counts"] = encoded["counts"].decode("utf-8")
    return encoded
301
+
302
+
303
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculate XYXY boxes around masks. An empty mask yields [0, 0, 0, 0].
    For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if masks.numel() == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    orig_shape = masks.shape
    height, width = orig_shape[-2:]
    has_batch_dims = len(orig_shape) > 2
    flat = masks.flatten(0, -3) if has_batch_dims else masks.unsqueeze(0)

    def _extent(occupied, length):
        # occupied marks which rows (or columns) contain foreground pixels.
        coords = occupied * torch.arange(length, device=occupied.device)[None, :]
        upper, _ = torch.max(coords, dim=-1)
        # Push unoccupied slots to `length` so they cannot win the minimum.
        lower, _ = torch.min(coords + length * (~occupied), dim=-1)
        return lower, upper

    # Top and bottom edges
    rows_occupied, _ = torch.max(flat, dim=-1)
    top_edges, bottom_edges = _extent(rows_occupied, height)

    # Left and right edges
    cols_occupied, _ = torch.max(flat, dim=-2)
    left_edges, right_edges = _extent(cols_occupied, width)

    # For an empty mask the right edge ends up left of the left edge;
    # such boxes are replaced with [0, 0, 0, 0].
    is_empty = (right_edges < left_edges) | (bottom_edges < top_edges)
    boxes = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    boxes = boxes * (~is_empty).unsqueeze(-1)

    # Return to original shape
    if has_batch_dims:
        return boxes.reshape(*orig_shape[:-2], 4)
    return boxes[0]
SAM/utils/transforms.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
class ResizeLongestSide:
    """
    Resize images so that their longest side equals ``target_length`` and
    apply the matching scaling to coordinates and boxes. Offers variants for
    numpy arrays and for batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        height, width = image.shape[0], image.shape[1]
        new_size = self.get_preprocess_shape(height, width, self.target_length)
        pil_image = to_pil_image(image)
        return np.array(resize(pil_image, new_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Scale coordinates to the resized image. Expects a numpy array of
        length 2 in the final dimension and the original image size in
        (H, W) format. Returns a float copy.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
        scaled = deepcopy(coords).astype(float)
        scaled[..., 0] *= new_w / old_w
        scaled[..., 1] *= new_h / old_h
        return scaled

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Scale boxes to the resized image. Expects a numpy array shape Bx4 and
        the original image size in (H, W) format.
        """
        # Each box is treated as two (x, y) corner points.
        corners = boxes.reshape(-1, 2, 2)
        return self.apply_coords(corners, original_size).reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        height, width = image.shape[2], image.shape[3]
        new_size = self.get_preprocess_shape(height, width, self.target_length)
        return F.interpolate(
            image, new_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Scale coordinates to the resized image. Expects a torch tensor with
        length 2 in the last dimension and the original image size in
        (H, W) format. Returns a float copy.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
        scaled = deepcopy(coords).to(torch.float)
        scaled[..., 0] *= new_w / old_w
        scaled[..., 1] *= new_h / old_h
        return scaled

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Scale boxes to the resized image. Expects a torch tensor with shape
        Bx4 and the original image size in (H, W) format.
        """
        # Each box is treated as two (x, y) corner points.
        corners = boxes.reshape(-1, 2, 2)
        return self.apply_coords_torch(corners, original_size).reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size (H, W) given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        # Adding 0.5 before truncation rounds to the nearest integer.
        return (int(oldh * scale + 0.5), int(oldw * scale + 0.5))
__pycache__/evaluate.cpython-310.pyc ADDED
Binary file (3.59 kB). View file
 
__pycache__/load_nvos.cpython-310.pyc ADDED
Binary file (4.75 kB). View file
 
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # gradio demo
7
+ # --------------------------------------------------------
8
+ import argparse
9
+ import gradio
10
+ import os
11
+ import torch
12
+ import numpy as np
13
+ import tempfile
14
+ import functools
15
+ import trimesh
16
+ import copy
17
+ from scipy.spatial.transform import Rotation
18
+
19
+ from dust3r.inference import inference, load_model
20
+ from dust3r.image_pairs import make_pairs
21
+ from dust3r.utils.image import load_images, rgb
22
+ from dust3r.utils.device import to_numpy
23
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
24
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
25
+
26
+ import matplotlib.pyplot as plt
27
+ plt.ion()
28
+
29
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
30
+ batch_size = 1
31
+
32
def show_mask(mask, ax, random_color=False):
    """
    Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    The overlay uses a fixed blue tone, or a random colour when
    ``random_color`` is set; the alpha value is 0.6 in both cases.
    """
    base = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    color = np.concatenate([base, np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to obtain an (h, w, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * color.reshape(1, 1, -1)
    ax.imshow(overlay)
40
+
41
def show_points(coords, labels, ax, marker_size=375):
    """
    Scatter prompt points on ``ax``: points labelled 1 in green, points
    labelled 0 in red, both as white-edged stars.
    """
    selections = (
        (coords[labels==1], 'green'),
        (coords[labels==0], 'red'),
    )
    for points, colour in selections:
        ax.scatter(points[:, 0], points[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)
46
+
47
def show_box(box, ax):
    """Draw ``box`` (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)
51
+
52
from SAM import SamPredictor
from SAM.build_sam import sam_model_registry

# Segment Anything model loaded once at import time.
sam_checkpoint = "checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Use CUDA when present, otherwise fall back to CPU, so that importing this
# module does not crash on hosts without a GPU.
sam_device = 'cuda' if torch.cuda.is_available() else 'cpu'

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=sam_device)
predictor = SamPredictor(sam)
60
+
61
def get_args_parser():
    """
    Build the command-line parser of the demo.

    ``--local_network`` and ``--server_name`` are mutually exclusive;
    ``--weights`` is the only required option.
    """
    parser = argparse.ArgumentParser()

    # Only one way of choosing the server address may be given.
    address_group = parser.add_mutually_exclusive_group()
    address_group.add_argument(
        "--local_network", action='store_true', default=False,
        help="make app accessible on local network: address will be set to 0.0.0.0")
    address_group.add_argument(
        "--server_name", type=str, default=None,
        help="server url, default is 127.0.0.1")

    port_help = ("will start gradio app on this port (if available). "
                 "If None, will search for an available port starting at 7860.")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--server_port", type=int, help=port_help, default=None)
    parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
    return parser
75
+
76
+
77
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """
    Build a trimesh scene from the reconstruction and export it as a GLB file.

    The scene holds the masked geometry (one point cloud, or one merged mesh)
    plus one camera per pose, and is written to ``<outdir>/scene.glb``.

    Args:
        outdir: directory receiving ``scene.glb``.
        imgs: per-view images, used for colours and camera textures.
        pts3d: per-view 3D points.
        mask: per-view masks selecting which points are kept.
        focals: per-camera focals.
        cams2world: per-camera camera-to-world poses.
        cam_size: passed to ``add_scene_cam`` as ``screen_width``.
        cam_color: one colour for all cameras, a list with one colour per
            camera, or None to cycle through ``CAM_COLORS``.
        as_pointcloud: export a point cloud instead of a mesh.
        transparent_cams: do not attach the image to each camera.

    Returns:
        Path of the exported GLB file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        # zip, like the point-cloud branch: the assertion above allows more
        # images than point maps, and indexing pts3d/mask by image position
        # would then raise IndexError.
        meshes = [pts3d_to_trimesh(img, pts, msk) for img, pts, msk in zip(imgs, pts3d, mask)]
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # Express the scene relative to the first camera, combined with the
    # OPENGL matrix and a 180 degree rotation about the y axis.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    outfile = os.path.join(outdir, 'scene.glb')
    print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
117
+
118
+
119
def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    Extract a 3D model (GLB file) from a reconstructed scene.

    Args:
        outdir: directory in which the GLB file is written.
        scene: reconstructed scene, or None when nothing has been computed yet.
        min_conf_thr: confidence threshold; it is passed through
            ``scene.conf_trf`` and stored on the scene before the masks are read.
        as_pointcloud: export a point cloud instead of a mesh.
        mask_sky: apply the scene's ``mask_sky`` post-process.
        clean_depth: apply the scene's ``clean_pointcloud`` post-process.
        transparent_cams: export cameras without their image.
        cam_size: size of the exported cameras.

    Returns:
        Path of the exported GLB file, or None if ``scene`` is None.
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()
    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    # The threshold must be set before get_masks() so the masks reflect it.
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size)
184
+
185
+
186
def get_reconstructed_scene(outdir, model, device, image_size, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid):
    """
    Run dust3r inference and global alignment on a list of images, then
    export the 3D model with get_3D_model_from_scene.

    Returns:
        (scene, outfile, gallery), where ``gallery`` lists, for every view,
        the rgb image, the normalised depth map and the colour-mapped
        confidence map, in that order.
    """
    imgs = load_images(filelist, size=image_size)
    if len(imgs) == 1:
        # A single image is paired with a copy of itself.
        duplicate = copy.deepcopy(imgs[0])
        imgs = [imgs[0], duplicate]
        imgs[1]['idx'] = 1

    # The "swin" and "oneref" graphs carry their parameter as a suffix.
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)

    if len(imgs) > 2:
        mode = GlobalAlignerMode.PointCloudOptimizer
    else:
        mode = GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode)

    # Only the point-cloud optimizer needs the iterative alignment.
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=0.01)

    outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = plt.get_cmap('jet')
    depths_max = max([d.max() for d in depths])
    depths = [d/depths_max for d in depths]
    confs_max = max([d.max() for d in confs])
    confs = [cmap(d/confs_max) for d in confs]

    gallery = []
    for view_idx, view_rgb in enumerate(rgbimg):
        gallery += [view_rgb, rgb(depths[view_idx]), rgb(confs[view_idx])]

    return scene, outfile, gallery
234
+
235
+
236
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """
    Rebuild the window-size and reference-id sliders for the current inputs.

    The slider ranges follow the number of input files. The window-size
    slider is visible only for the "swin" scene graph and the reference-id
    slider only for "oneref". The incoming ``winsize`` and ``refid`` values
    are not used.
    """
    num_files = 1 if inputfiles is None else len(inputfiles)
    max_winsize = max(1, (num_files - 1)//2)
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1,
                            visible=scenegraph_type == "swin")
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files-1, step=1,
                          visible=scenegraph_type == "oneref")
    return winsize, refid
255
+
256
+
257
def main_demo(tmpdirname, model, device, image_size, server_name, server_port):
    """
    Build the gradio interface of the demo and launch it.

    ``tmpdirname``, ``model``, ``device`` and ``image_size`` are bound to the
    reconstruction callback; ``tmpdirname`` is bound to the export callback.
    """
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
        # The scene is kept as state so that conf_thr, cam_size... can be
        # changed without rerunning the inference.
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                schedule = gradio.Dropdown(["linear", "cosine"],
                                           value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                      label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                        minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # the two implemented post-processes
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # events: changing the scene graph type or the inputs refreshes the sliders
            for trigger in (scenegraph_type, inputfiles):
                trigger.change(set_scenegraph_options,
                               inputs=[inputfiles, winsize, refid, scenegraph_type],
                               outputs=[winsize, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])

            # Every export option re-exports the model from the stored scene.
            export_inputs = [scene, min_conf_thr, as_pointcloud, mask_sky,
                             clean_depth, transparent_cams, cam_size]
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=list(export_inputs),
                                 outputs=outmodel)
            for control in (cam_size, as_pointcloud, mask_sky, clean_depth, transparent_cams):
                control.change(fn=model_from_scene_fun,
                               inputs=list(export_inputs),
                               outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)
333
+
334
+
335
if __name__ == '__main__':
    args = get_args_parser().parse_args()

    # Make tempfile allocate inside the requested directory.
    if args.tmp_dir is not None:
        os.makedirs(args.tmp_dir, exist_ok=True)
        tempfile.tempdir = args.tmp_dir

    # An explicit --server_name wins; otherwise --local_network picks the address.
    if args.server_name is not None:
        server_name = args.server_name
    elif args.local_network:
        server_name = '0.0.0.0'
    else:
        server_name = '127.0.0.1'

    model = load_model(args.weights, args.device)
    # dust3r will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e8bbf0c4d1d6007f5343f3f45814b956ddc5bbb4d00cb66beaf73afe5c53b34
3
+ size 2285019929
checkpoints/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
configs/default.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ expname = None # experiment name
4
+ basedir = './logs/' # where to store ckpts and logs
5
+
6
+ ''' Template of data options
7
+ '''
8
+ data = dict(
9
+ datadir=None, # path to dataset root folder
10
+ dataset_type=None, # blender | nsvf | blendedmvs | tankstemple | deepvoxels | co3d
11
+ inverse_y=False, # intrinsic mode (to support blendedmvs, nsvf, tankstemple)
12
+ flip_x=False, # to support co3d
13
+ flip_y=False, # to support co3d
14
+ annot_path='', # to support co3d
15
+ split_path='', # to support co3d
16
+ sequence_name='', # to support co3d
17
+ # load2gpu_on_the_fly=False, # do not load all images into gpu (to save gpu memory)
18
+ load2gpu_on_the_fly=True, # do not load all images into gpu (to save gpu memory)
19
+ testskip=5, # subsample testset to preview results
20
+ white_bkgd=True, # use white background (note that some datasets don't provide alpha and come with a blended bg color)
21
+ rand_bkgd=False, # use random background during training
22
+ half_res=False, # [TODO]
23
+ bd_factor=.75,
24
+ movie_render_kwargs=dict(),
25
+
26
+ # Below are forward-facing llff specific settings.
27
+ ndc=False, # use ndc coordinates (only for forward-facing; not supported yet)
28
+ spherify=False, # inward-facing
29
+ factor=4, # [TODO]
30
+ width=None, # enforce image width
31
+ height=None, # enforce image height
32
+ llffhold=8, # testsplit
33
+ load_depths=False, # load depth
34
+
35
+ # Below are unbounded inward-facing specific settings.
36
+ unbounded_inward=False,
37
+ unbounded_inner_r=1.0,
38
+ )
39
+
40
+ ''' Template of training options
41
+ '''
42
+ coarse_train = dict(
43
+ N_iters=5000, # number of optimization steps
44
+ N_rand=8192, # batch size (number of random rays per optimization step)
45
+ #N_rand=1024, # batch size (number of random rays per optimization step)
46
+ lrate_density=1e-1, # lr of density voxel grid
47
+ lrate_k0=1e-1, # lr of color/feature voxel grid
48
+ lrate_rgbnet=1e-3, # lr of the mlp to predict view-dependent color
49
+ lrate_decay=20, # lr decay by 0.1 after every lrate_decay*1000 steps
50
+ pervoxel_lr=True, # view-count-based lr
51
+ pervoxel_lr_downrate=1, # downsampled image for computing view-count-based lr
52
+ ray_sampler='random', # ray sampling strategies
53
+ weight_main=1.0, # weight of photometric loss
54
+ weight_entropy_last=0.01, # weight of background entropy loss
55
+ weight_nearclip=0,
56
+ weight_distortion=0,
57
+ weight_rgbper=0.1, # weight of per-point rgb loss
58
+ tv_every=1, # count total variation loss every tv_every step
59
+ tv_after=0, # count total variation loss from the tv_after step onward
60
+ tv_before=0, # count total variation before the given number of iterations
61
+ tv_dense_before=0, # count total variation densely before the given number of iterations
62
+ weight_tv_density=0.0, # weight of total variation loss of density voxel grid
63
+ weight_tv_k0=0.0, # weight of total variation loss of color/feature voxel grid
64
+ pg_scale=[], # checkpoints for progressive scaling
65
+ decay_after_scale=1.0, # decay act_shift after scaling
66
+ skip_zero_grad_fields=[], # the variable name to skip optimizing parameters w/ zero grad in each iteration
67
+ maskout_lt_nviews=0,
68
+ )
69
+
70
+ fine_train = deepcopy(coarse_train)
71
+ fine_train.update(dict(
72
+ N_iters=20000,
73
+ pervoxel_lr=False,
74
+ ray_sampler='flatten',
75
+ weight_entropy_last=0.001,
76
+ weight_rgbper=0.01,
77
+ pg_scale=[1000, 2000, 3000, 4000],
78
+ skip_zero_grad_fields=['density', 'k0'],
79
+ ))
80
+
81
+ ''' Template of model and rendering options
82
+ '''
83
+ coarse_model_and_render = dict(
84
+ num_voxels=1024000, # expected number of voxels
85
+ num_voxels_base=1024000, # to rescale delta distance
86
+ density_type='DenseGrid', # DenseGrid, TensoRFGrid
87
+ k0_type='TensoRFGrid', # DenseGrid, TensoRFGrid
88
+ density_config=dict(),
89
+ k0_config=dict(n_comp=48),
90
+ mpi_depth=128, # the number of planes in Multiplane Image (work when ndc=True)
91
+ nearest=False, # nearest interpolation
92
+ pre_act_density=False, # pre-activated trilinear interpolation
93
+ in_act_density=False, # in-activated trilinear interpolation
94
+ bbox_thres=1e-3, # threshold to determine known free-space in the fine stage
95
+ mask_cache_thres=1e-3, # threshold to determine a tightened BBox in the fine stage
96
+ rgbnet_dim=0, # feature voxel grid dim
97
+ rgbnet_full_implicit=False, # let the colors MLP ignore feature voxel grid
98
+ rgbnet_direct=True, # set to False to treat the first 3 dim of feature voxel grid as diffuse rgb
99
+ rgbnet_depth=3, # depth of the colors MLP (there are rgbnet_depth-1 intermediate features)
100
+ rgbnet_width=128, # width of the colors MLP
101
+ alpha_init=1e-6, # set the alpha values everywhere at the beginning of training
102
+ fast_color_thres=1e-7, # threshold of alpha value to skip the fine stage sampled point
103
+ maskout_near_cam_vox=True, # mask out grid points that lie between cameras and their near planes
104
+ world_bound_scale=1, # rescale the BBox enclosing the scene
105
+ stepsize=0.5, # sampling stepsize in volume rendering
106
+ )
107
+
108
+ fine_model_and_render = deepcopy(coarse_model_and_render)
109
+ fine_model_and_render.update(dict(
110
+ num_voxels=160**3,
111
+ num_voxels_base=160**3,
112
+ rgbnet_dim=12,
113
+ alpha_init=1e-2,
114
+ fast_color_thres=1e-4,
115
+ maskout_near_cam_vox=False,
116
+ world_bound_scale=1.05,
117
+ ))
118
+
119
+ del deepcopy
configs/lerf/book_store.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './lerf_default.py'
2
+
3
+ expname = 'dcvgo_book_store'
4
+
5
+ data = dict(
6
+ datadir='./data/lerf_data/book_store',
7
+ factor=2, # 497 * 369
8
+ # factor=4,
9
+ movie_render_kwargs=dict(
10
+ shift_x=0.5, # positive right
11
+ shift_y=0.5, # negative down
12
+ shift_z=1,
13
+ scale_r=0,
14
+ pitch_deg=0, # negative look downward
15
+ ),
16
+ )
configs/lerf/bouquet.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './lerf_default.py'
2
+
3
+ expname = 'dcvgo_bouquet'
4
+
5
+ data = dict(
6
+ datadir='./data/lerf_data/bouquet',
7
+ factor=2, # 497 * 369
8
+ # factor=4,
9
+ movie_render_kwargs=dict(
10
+ shift_x=0.0, # positive right
11
+ shift_y=-0.0, # negative down
12
+ shift_z=0,
13
+ scale_r=0.2,
14
+ pitch_deg=0, # negative look downward
15
+ ),
16
+ )
configs/lerf/donuts.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './lerf_default.py'
2
+
3
+ expname = 'dcvgo_donuts'
4
+
5
+ data = dict(
6
+ datadir='./data/lerf_data/donuts',
7
+ factor=2, # 497 * 369
8
+ # factor=4,
9
+ movie_render_kwargs=dict(
10
+ shift_x=-0.2,
11
+ shift_y=0.2,
12
+ shift_z=0.1,
13
+ scale_r=1.3,
14
+ pitch_deg=60,
15
+ ),
16
+ )
configs/lerf/dozer_nerfgun_waldo.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './lerf_default.py'
2
+
3
+ expname = 'dcvgo_dozer_nerfgun_waldo'
4
+
5
+ data = dict(
6
+ datadir='./data/lerf_data/dozer_nerfgun_waldo',
7
+ factor=2, # 497 * 369
8
+ # factor=4,
9
+ # movie_render_kwargs=dict(
10
+ # shift_x=0.0, # positive right
11
+ # shift_y=-0.3, # negative down
12
+ # shift_z=0,
13
+ # scale_r=0.2,
14
+ # pitch_deg=-40, # negative look downward
15
+ # ),
16
+ )
configs/lerf/espresso.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './lerf_default.py'
2
+
3
+ expname = 'dcvgo_espresso'
4
+
5
+ data = dict(
6
+ datadir='./data/lerf_data/espresso',
7
+ factor=2, # 497 * 369
8
+ # factor=4,
9
+ # movie_render_kwargs=dict(
10
+ # shift_x=0.0, # positive right
11
+ # shift_y=-0.3, # negative down
12
+ # shift_z=0,
13
+ # scale_r=0.2,
14
+ # pitch_deg=-40, # negative look downward
15
+ # ),
16
+ )