schirrmacher commited on
Commit
656b5ab
·
verified ·
1 Parent(s): 1351f60

Upload folder using huggingface_hub

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -38,3 +38,8 @@ no-background.png filter=lfs diff=lfs merge=lfs -text
38
  examples/example1.png filter=lfs diff=lfs merge=lfs -text
39
  examples/no-background1.png filter=lfs diff=lfs merge=lfs -text
40
  examples.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
38
  examples/example1.png filter=lfs diff=lfs merge=lfs -text
39
  examples/no-background1.png filter=lfs diff=lfs merge=lfs -text
40
  examples.jpg filter=lfs diff=lfs merge=lfs -text
41
+ dataset/training/im/p_00a4eda7.png filter=lfs diff=lfs merge=lfs -text
42
+ dataset/training/im/p_00a5b702.png filter=lfs diff=lfs merge=lfs -text
43
+ dataset/validation/im/p_00a7a27c.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/image/image01.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/image/image01_no_background.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ models/*
README.md CHANGED
@@ -15,7 +15,9 @@ datasets:
15
 
16
  [>>> DEMO <<<](https://huggingface.co/spaces/schirrmacher/ormbg)
17
 
18
- ![](examples.jpg)
 
 
19
 
20
  This model is a **fully open-source background remover** optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans), [P3M-10k](https://paperswithcode.com/dataset/p3m-10k) and [AIM-500](https://paperswithcode.com/dataset/aim-500).
21
 
@@ -24,7 +26,22 @@ This model is similar to [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4), but
24
  ## Inference
25
 
26
  ```
27
- python utils/inference.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ```
29
 
30
  # Research
 
15
 
16
  [>>> DEMO <<<](https://huggingface.co/spaces/schirrmacher/ormbg)
17
 
18
+ Join our [Research Discord Group](https://discord.gg/YYZ3D66t)!
19
+
20
+ ![](examples/image/image01_no_background.png)
21
 
22
  This model is a **fully open-source background remover** optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans), [P3M-10k](https://paperswithcode.com/dataset/p3m-10k) and [AIM-500](https://paperswithcode.com/dataset/aim-500).
23
 
 
26
  ## Inference
27
 
28
  ```
29
+ python ormbg/inference.py
30
+ ```
31
+
32
+ ## Training
33
+
34
+ Install dependencies:
35
+
36
+ ```
37
+ conda env create -f environment.yaml
38
+ conda activate ormbg
39
+ ```
40
+
41
+ Replace dummy dataset with (training dataset)[https://huggingface.co/datasets/schirrmacher/humans].
42
+
43
+ ```
44
+ python3 ormbg/train_model.py
45
  ```
46
 
47
  # Research
dataset/training/gt/p_00a4eda7.png ADDED
dataset/training/gt/p_00a5b702.png ADDED
dataset/training/im/p_00a4eda7.png ADDED

Git LFS Details

  • SHA256: e226a687b5d755056076e12d7f2c24704d101ad90918554c43028e8c1e53638f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
dataset/training/im/p_00a5b702.png ADDED

Git LFS Details

  • SHA256: 184b2d97ffdbffc9d0a5d3c3b84a848938df636855d59b81f3d109445a92b0ef
  • Pointer size: 132 Bytes
  • Size of remote file: 3.46 MB
dataset/validation/gt/p_00a7a27c.png ADDED
dataset/validation/im/p_00a7a27c.png ADDED

Git LFS Details

  • SHA256: b87d59e4598ddc1078ebdc856e7101d92582315ecff2aecdadc17802e82bc8c1
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
environment.yaml ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ormbg
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - anaconda
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=5.1=1_gnu
10
+ - aom=3.6.0=h6a678d5_0
11
+ - blas=1.0=mkl
12
+ - blosc=1.21.3=h6a678d5_0
13
+ - brotli=1.0.9=h5eee18b_7
14
+ - brotli-bin=1.0.9=h5eee18b_7
15
+ - brotli-python=1.0.9=py38h6a678d5_7
16
+ - brunsli=0.1=h2531618_0
17
+ - bzip2=1.0.8=h7b6447c_0
18
+ - c-ares=1.19.1=h5eee18b_0
19
+ - ca-certificates=2023.08.22=h06a4308_0
20
+ - certifi=2023.7.22=py38h06a4308_0
21
+ - cffi=1.15.0=py38h7f8727e_0
22
+ - cfitsio=3.470=h5893167_7
23
+ - charls=2.2.0=h2531618_0
24
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
25
+ - click=8.1.7=py38h06a4308_0
26
+ - cloudpickle=2.2.1=py38h06a4308_0
27
+ - contourpy=1.0.5=py38hdb19cb5_0
28
+ - cryptography=41.0.3=py38h130f0dd_0
29
+ - cuda-cudart=11.8.89=0
30
+ - cuda-cupti=11.8.87=0
31
+ - cuda-libraries=11.8.0=0
32
+ - cuda-nvrtc=11.8.89=0
33
+ - cuda-nvtx=11.8.86=0
34
+ - cuda-runtime=11.8.0=0
35
+ - cudatoolkit=11.8.0=h6a678d5_0
36
+ - cycler=0.11.0=pyhd3eb1b0_0
37
+ - cytoolz=0.12.0=py38h5eee18b_0
38
+ - dask-core=2023.4.1=py38h06a4308_0
39
+ - dav1d=1.2.1=h5eee18b_0
40
+ - dbus=1.13.18=hb2f20db_0
41
+ - expat=2.5.0=h6a678d5_0
42
+ - ffmpeg=4.3=hf484d3e_0
43
+ - fftw=3.3.9=h27cfd23_1
44
+ - filelock=3.9.0=py38h06a4308_0
45
+ - fontconfig=2.14.1=h52c9d5c_1
46
+ - fonttools=4.25.0=pyhd3eb1b0_0
47
+ - freetype=2.12.1=h4a9f257_0
48
+ - fsspec=2023.9.2=py38h06a4308_0
49
+ - giflib=5.2.1=h5eee18b_3
50
+ - glib=2.63.1=h5a9c865_0
51
+ - gmp=6.2.1=h295c915_3
52
+ - gmpy2=2.1.2=py38heeb90bb_0
53
+ - gnutls=3.6.15=he1e5248_0
54
+ - gst-plugins-base=1.14.0=hbbd80ab_1
55
+ - gstreamer=1.14.0=hb453b48_1
56
+ - icu=58.2=he6710b0_3
57
+ - idna=3.4=py38h06a4308_0
58
+ - imagecodecs=2023.1.23=py38hc4b7b5f_0
59
+ - imageio=2.31.4=py38h06a4308_0
60
+ - importlib-metadata=6.0.0=py38h06a4308_0
61
+ - importlib_resources=6.1.0=py38h06a4308_0
62
+ - intel-openmp=2021.4.0=h06a4308_3561
63
+ - jinja2=3.1.2=py38h06a4308_0
64
+ - jpeg=9e=h5eee18b_1
65
+ - jxrlib=1.1=h7b6447c_2
66
+ - kiwisolver=1.4.4=py38h6a678d5_0
67
+ - krb5=1.20.1=h568e23c_1
68
+ - lame=3.100=h7b6447c_0
69
+ - lazy_loader=0.3=py38h06a4308_0
70
+ - lcms2=2.12=h3be6417_0
71
+ - lerc=3.0=h295c915_0
72
+ - libaec=1.0.4=he6710b0_1
73
+ - libavif=0.11.1=h5eee18b_0
74
+ - libbrotlicommon=1.0.9=h5eee18b_7
75
+ - libbrotlidec=1.0.9=h5eee18b_7
76
+ - libbrotlienc=1.0.9=h5eee18b_7
77
+ - libcublas=11.11.3.6=0
78
+ - libcufft=10.9.0.58=0
79
+ - libcufile=1.8.1.2=0
80
+ - libcurand=10.3.4.101=0
81
+ - libcurl=7.88.1=h91b91d3_2
82
+ - libcusolver=11.4.1.48=0
83
+ - libcusparse=11.7.5.86=0
84
+ - libdeflate=1.17=h5eee18b_1
85
+ - libedit=3.1.20221030=h5eee18b_0
86
+ - libev=4.33=h7f8727e_1
87
+ - libffi=3.2.1=hf484d3e_1007
88
+ - libgcc-ng=11.2.0=h1234567_1
89
+ - libgfortran-ng=11.2.0=h00389a5_1
90
+ - libgfortran5=11.2.0=h1234567_1
91
+ - libgomp=11.2.0=h1234567_1
92
+ - libiconv=1.16=h7f8727e_2
93
+ - libidn2=2.3.4=h5eee18b_0
94
+ - libjpeg-turbo=2.0.0=h9bf148f_0
95
+ - libnghttp2=1.52.0=ha637b67_1
96
+ - libnpp=11.8.0.86=0
97
+ - libnvjpeg=11.9.0.86=0
98
+ - libpng=1.6.39=h5eee18b_0
99
+ - libssh2=1.10.0=h37d81fd_2
100
+ - libstdcxx-ng=11.2.0=h1234567_1
101
+ - libtasn1=4.19.0=h5eee18b_0
102
+ - libtiff=4.5.1=h6a678d5_0
103
+ - libunistring=0.9.10=h27cfd23_0
104
+ - libuuid=1.41.5=h5eee18b_0
105
+ - libwebp=1.3.2=h11a3e52_0
106
+ - libwebp-base=1.3.2=h5eee18b_0
107
+ - libxcb=1.15=h7f8727e_0
108
+ - libxml2=2.9.14=h74e7548_0
109
+ - libzopfli=1.0.3=he6710b0_0
110
+ - llvm-openmp=14.0.6=h9e868ea_0
111
+ - locket=1.0.0=py38h06a4308_0
112
+ - lz4-c=1.9.4=h6a678d5_0
113
+ - markupsafe=2.1.1=py38h7f8727e_0
114
+ - matplotlib=3.7.2=py38h06a4308_0
115
+ - matplotlib-base=3.7.2=py38h1128e8f_0
116
+ - mkl=2021.4.0=h06a4308_640
117
+ - mkl-service=2.4.0=py38h7f8727e_0
118
+ - mkl_fft=1.3.1=py38hd3c417c_0
119
+ - mkl_random=1.2.2=py38h51133e4_0
120
+ - mpc=1.1.0=h10f8cd9_1
121
+ - mpfr=4.0.2=hb69a4c5_1
122
+ - mpmath=1.3.0=py38h06a4308_0
123
+ - munkres=1.1.4=py_0
124
+ - ncurses=6.4=h6a678d5_0
125
+ - nettle=3.7.3=hbbd107a_1
126
+ - networkx=3.1=py38h06a4308_0
127
+ - openh264=2.1.1=h4ff587b_0
128
+ - openjpeg=2.4.0=h3ad879b_0
129
+ - openssl=1.1.1w=h7f8727e_0
130
+ - packaging=23.1=py38h06a4308_0
131
+ - partd=1.4.1=py38h06a4308_0
132
+ - pcre=8.45=h295c915_0
133
+ - pillow=10.0.1=py38ha6cbd5a_0
134
+ - pip=23.3=py38h06a4308_0
135
+ - pycparser=2.21=pyhd3eb1b0_0
136
+ - pyopenssl=23.2.0=py38h06a4308_0
137
+ - pyparsing=3.0.9=py38h06a4308_0
138
+ - pyqt=5.9.2=py38h05f1152_4
139
+ - pysocks=1.7.1=py38h06a4308_0
140
+ - python=3.8.0=h0371630_2
141
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
142
+ - pytorch=2.1.1=py3.8_cuda11.8_cudnn8.7.0_0
143
+ - pytorch-cuda=11.8=h7e8668a_5
144
+ - pytorch-mutex=1.0=cuda
145
+ - pywavelets=1.4.1=py38h5eee18b_0
146
+ - pyyaml=6.0.1=py38h5eee18b_0
147
+ - qt=5.9.7=h5867ecd_1
148
+ - readline=7.0=h7b6447c_5
149
+ - requests=2.31.0=py38h06a4308_0
150
+ - setuptools=68.0.0=py38h06a4308_0
151
+ - sip=4.19.13=py38h295c915_0
152
+ - six=1.16.0=pyhd3eb1b0_1
153
+ - snappy=1.1.9=h295c915_0
154
+ - sqlite=3.33.0=h62c20be_0
155
+ - sympy=1.11.1=py38h06a4308_0
156
+ - tifffile=2023.4.12=py38h06a4308_0
157
+ - tk=8.6.12=h1ccaba5_0
158
+ - toolz=0.12.0=py38h06a4308_0
159
+ - torchaudio=2.1.1=py38_cu118
160
+ - torchtriton=2.1.0=py38
161
+ - torchvision=0.16.1=py38_cu118
162
+ - tornado=6.3.3=py38h5eee18b_0
163
+ - tqdm=4.65.0=py38hb070fc8_0
164
+ - urllib3=1.26.18=py38h06a4308_0
165
+ - wheel=0.41.2=py38h06a4308_0
166
+ - xz=5.4.2=h5eee18b_0
167
+ - yaml=0.2.5=h7b6447c_0
168
+ - zfp=1.0.0=h6a678d5_0
169
+ - zipp=3.11.0=py38h06a4308_0
170
+ - zlib=1.2.13=h5eee18b_0
171
+ - zstd=1.5.5=hc292b87_0
172
+ - pip:
173
+ - albucore==0.0.12
174
+ - albumentations==1.4.11
175
+ - annotated-types==0.7.0
176
+ - appdirs==1.4.4
177
+ - conda-pack==0.7.1
178
+ - docker-pycreds==0.4.0
179
+ - eval-type-backport==0.2.0
180
+ - gitdb==4.0.11
181
+ - gitpython==3.1.40
182
+ - joblib==1.4.2
183
+ - numpy==1.24.4
184
+ - opencv-python-headless==4.10.0.84
185
+ - protobuf==4.25.1
186
+ - psutil==5.9.6
187
+ - pydantic==2.8.2
188
+ - pydantic-core==2.20.1
189
+ - scikit-image==0.21.0
190
+ - scikit-learn==1.3.2
191
+ - scipy==1.10.1
192
+ - sentry-sdk==1.35.0
193
+ - setproctitle==1.3.3
194
+ - smmap==5.0.1
195
+ - threadpoolctl==3.5.0
196
+ - tomli==2.0.1
197
+ - typing-extensions==4.12.2
198
+ - wandb==0.16.0
199
+ prefix: /home/macher/miniconda3/envs/ormbg
examples/.DS_Store ADDED
Binary file (6.15 kB). View file
 
examples/image/image01.png ADDED

Git LFS Details

  • SHA256: 1c6d54789fc0d8816231ca9f061b19af50bdbfb59a4fed7fa6c7bd3168591b0e
  • Pointer size: 133 Bytes
  • Size of remote file: 16.7 MB
examples/image/image01_no_background.png ADDED

Git LFS Details

  • SHA256: 9290ced416914386458bded92614b3b620bf82fc9dc7b06b4015fc6791d34cc3
  • Pointer size: 133 Bytes
  • Size of remote file: 21.4 MB
examples/loss/gt.png ADDED
examples/loss/loss01.png ADDED
examples/loss/loss02.png ADDED
examples/loss/loss03.png ADDED
examples/loss/loss04.png ADDED
examples/loss/loss05.png ADDED
examples/loss/orginal.jpg ADDED
ormbg/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ormbg/basics.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
4
+ from skimage import io, transform
5
+ import torch
6
+ import torchvision
7
+ from torch.autograd import Variable
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ import torch.optim as optim
13
+
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ from PIL import Image
17
+ import glob
18
+
19
+
20
+ def mae_torch(pred, gt):
21
+
22
+ h, w = gt.shape[0:2]
23
+ sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
24
+ maeError = torch.divide(sumError, float(h) * float(w) * 255.0 + 1e-4)
25
+
26
+ return maeError
27
+
28
+
29
+ def f1score_torch(pd, gt):
30
+
31
+ # print(gt.shape)
32
+ gtNum = torch.sum((gt > 128).float() * 1) ## number of ground truth pixels
33
+
34
+ pp = pd[gt > 128]
35
+ nn = pd[gt <= 128]
36
+
37
+ pp_hist = torch.histc(pp, bins=255, min=0, max=255)
38
+ nn_hist = torch.histc(nn, bins=255, min=0, max=255)
39
+
40
+ pp_hist_flip = torch.flipud(pp_hist)
41
+ nn_hist_flip = torch.flipud(nn_hist)
42
+
43
+ pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
44
+ nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
45
+
46
+ precision = (pp_hist_flip_cum) / (
47
+ pp_hist_flip_cum + nn_hist_flip_cum + 1e-4
48
+ ) # torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
49
+ recall = (pp_hist_flip_cum) / (gtNum + 1e-4)
50
+ f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 1e-4)
51
+
52
+ return (
53
+ torch.reshape(precision, (1, precision.shape[0])),
54
+ torch.reshape(recall, (1, recall.shape[0])),
55
+ torch.reshape(f1, (1, f1.shape[0])),
56
+ )
57
+
58
+
59
+ def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
60
+
61
+ import time
62
+
63
+ tic = time.time()
64
+
65
+ if len(gt.shape) > 2:
66
+ gt = gt[:, :, 0]
67
+
68
+ pre, rec, f1 = f1score_torch(pred, gt)
69
+ mae = mae_torch(pred, gt)
70
+
71
+ print(valid_dataset.dataset["im_name"][idx] + ".png")
72
+ print("time for evaluation : ", time.time() - tic)
73
+
74
+ return (
75
+ pre.cpu().data.numpy(),
76
+ rec.cpu().data.numpy(),
77
+ f1.cpu().data.numpy(),
78
+ mae.cpu().data.numpy(),
79
+ )
ormbg/data_loader_cache.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## data loader
2
+ ## Ackownledgement:
3
+ ## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
4
+ ## for his helps in implementing cache machanism of our DIS dataloader.
5
+ from __future__ import print_function, division
6
+
7
+ import albumentations as A
8
+ import numpy as np
9
+ import random
10
+ from copy import deepcopy
11
+ import json
12
+ from tqdm import tqdm
13
+ from skimage import io
14
+ import os
15
+ from glob import glob
16
+
17
+ import torch
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torchvision import transforms
20
+ from torchvision.transforms.functional import normalize
21
+ import torch.nn.functional as F
22
+
23
+ #### --------------------- DIS dataloader cache ---------------------####
24
+
25
+
26
+ def get_im_gt_name_dict(datasets, flag="valid"):
27
+ print("------------------------------", flag, "--------------------------------")
28
+ name_im_gt_list = []
29
+ for i in range(len(datasets)):
30
+ print(
31
+ "--->>>",
32
+ flag,
33
+ " dataset ",
34
+ i,
35
+ "/",
36
+ len(datasets),
37
+ " ",
38
+ datasets[i]["name"],
39
+ "<<<---",
40
+ )
41
+ tmp_im_list, tmp_gt_list = [], []
42
+ im_dir = datasets[i]["im_dir"]
43
+ gt_dir = datasets[i]["gt_dir"]
44
+ tmp_im_list = glob(os.path.join(im_dir, "*" + "*.[jp][pn]g"))
45
+ tmp_gt_list = glob(os.path.join(gt_dir, "*" + "*.[jp][pn]g"))
46
+
47
+ print(
48
+ "-im-", datasets[i]["name"], datasets[i]["im_dir"], ": ", len(tmp_im_list)
49
+ )
50
+
51
+ print(
52
+ "-gt-",
53
+ datasets[i]["name"],
54
+ datasets[i]["gt_dir"],
55
+ ": ",
56
+ len(tmp_gt_list),
57
+ )
58
+
59
+ if flag == "train": ## combine multiple training sets into one dataset
60
+ if len(name_im_gt_list) == 0:
61
+ name_im_gt_list.append(
62
+ {
63
+ "dataset_name": datasets[i]["name"],
64
+ "im_path": tmp_im_list,
65
+ "gt_path": tmp_gt_list,
66
+ "im_ext": datasets[i]["im_ext"],
67
+ "gt_ext": datasets[i]["gt_ext"],
68
+ "cache_dir": datasets[i]["cache_dir"],
69
+ }
70
+ )
71
+ else:
72
+ name_im_gt_list[0]["dataset_name"] = (
73
+ name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
74
+ )
75
+ name_im_gt_list[0]["im_path"] = (
76
+ name_im_gt_list[0]["im_path"] + tmp_im_list
77
+ )
78
+ name_im_gt_list[0]["gt_path"] = (
79
+ name_im_gt_list[0]["gt_path"] + tmp_gt_list
80
+ )
81
+ if datasets[i]["im_ext"] != ".jpg" or datasets[i]["gt_ext"] != ".png":
82
+ print(
83
+ "Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!"
84
+ )
85
+ exit()
86
+ name_im_gt_list[0]["im_ext"] = ".jpg"
87
+ name_im_gt_list[0]["gt_ext"] = ".png"
88
+ name_im_gt_list[0]["cache_dir"] = (
89
+ os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])
90
+ + os.sep
91
+ + name_im_gt_list[0]["dataset_name"]
92
+ )
93
+ else: ## keep different validation or inference datasets as separate ones
94
+ name_im_gt_list.append(
95
+ {
96
+ "dataset_name": datasets[i]["name"],
97
+ "im_path": tmp_im_list,
98
+ "gt_path": tmp_gt_list,
99
+ "im_ext": datasets[i]["im_ext"],
100
+ "gt_ext": datasets[i]["gt_ext"],
101
+ "cache_dir": datasets[i]["cache_dir"],
102
+ }
103
+ )
104
+
105
+ return name_im_gt_list
106
+
107
+
108
+ def create_dataloaders(
109
+ name_im_gt_list,
110
+ cache_size=[],
111
+ cache_boost=True,
112
+ my_transforms=[],
113
+ batch_size=1,
114
+ shuffle=False,
115
+ ):
116
+ ## model="train": return one dataloader for training
117
+ ## model="valid": return a list of dataloaders for validation or testing
118
+
119
+ gos_dataloaders = []
120
+ gos_datasets = []
121
+
122
+ if len(name_im_gt_list) == 0:
123
+ return gos_dataloaders, gos_datasets
124
+
125
+ num_workers_ = 1
126
+ if batch_size > 1:
127
+ num_workers_ = 2
128
+ if batch_size > 4:
129
+ num_workers_ = 4
130
+ if batch_size > 8:
131
+ num_workers_ = 8
132
+
133
+ for i in range(0, len(name_im_gt_list)):
134
+ gos_dataset = GOSDatasetCache(
135
+ [name_im_gt_list[i]],
136
+ cache_size=cache_size,
137
+ cache_path=name_im_gt_list[i]["cache_dir"],
138
+ cache_boost=cache_boost,
139
+ transform=transforms.Compose(my_transforms),
140
+ )
141
+ gos_dataloaders.append(
142
+ DataLoader(
143
+ gos_dataset,
144
+ batch_size=batch_size,
145
+ shuffle=shuffle,
146
+ num_workers=num_workers_,
147
+ )
148
+ )
149
+ gos_datasets.append(gos_dataset)
150
+
151
+ return gos_dataloaders, gos_datasets
152
+
153
+
154
+ def im_reader(im_path):
155
+ return io.imread(im_path)
156
+
157
+
158
+ def im_preprocess(im, size):
159
+ if len(im.shape) < 3:
160
+ im = im[:, :, np.newaxis]
161
+ if im.shape[2] == 1:
162
+ im = np.repeat(im, 3, axis=2)
163
+ im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
164
+ im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1)
165
+ if len(size) < 2:
166
+ return im_tensor, im.shape[0:2]
167
+ else:
168
+ im_tensor = torch.unsqueeze(im_tensor, 0)
169
+ im_tensor = F.upsample(im_tensor, size, mode="bilinear")
170
+ im_tensor = torch.squeeze(im_tensor, 0)
171
+
172
+ return im_tensor.type(torch.uint8), im.shape[0:2]
173
+
174
+
175
+ def gt_preprocess(gt, size):
176
+ if len(gt.shape) > 2:
177
+ gt = gt[:, :, 0]
178
+
179
+ gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8), 0)
180
+
181
+ if len(size) < 2:
182
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
183
+ else:
184
+ gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32), 0)
185
+ gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
186
+ gt_tensor = torch.squeeze(gt_tensor, 0)
187
+
188
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
189
+ # return gt_tensor, gt.shape[0:2]
190
+
191
+
192
+ class GOSGridDropout(object):
193
+ def __init__(
194
+ self,
195
+ ratio=0.5,
196
+ unit_size_min=100,
197
+ unit_size_max=100,
198
+ holes_number_x=None,
199
+ holes_number_y=None,
200
+ shift_x=0,
201
+ shift_y=0,
202
+ random_offset=True,
203
+ fill_value=0,
204
+ mask_fill_value=None,
205
+ always_apply=None,
206
+ p=1.0,
207
+ ):
208
+ self.transform = A.GridDropout(
209
+ ratio=ratio,
210
+ unit_size_min=unit_size_min,
211
+ unit_size_max=unit_size_max,
212
+ holes_number_x=holes_number_x,
213
+ holes_number_y=holes_number_y,
214
+ shift_x=shift_x,
215
+ shift_y=shift_y,
216
+ random_offset=random_offset,
217
+ fill_value=fill_value,
218
+ mask_fill_value=mask_fill_value,
219
+ always_apply=always_apply,
220
+ p=p,
221
+ )
222
+
223
+ def __call__(self, sample):
224
+ imidx, image, label, shape = (
225
+ sample["imidx"],
226
+ sample["image"],
227
+ sample["label"],
228
+ sample["shape"],
229
+ )
230
+
231
+ # Convert the torch tensors to numpy arrays
232
+ image_np = image.permute(1, 2, 0).numpy()
233
+
234
+ augmented = self.transform(image=image_np)
235
+
236
+ # Convert the numpy arrays back to torch tensors
237
+ image = torch.tensor(augmented["image"]).permute(2, 0, 1)
238
+
239
+ return {"imidx": imidx, "image": image, "label": label, "shape": shape}
240
+
241
+
242
+ class GOSRandomHFlip(object):
243
+ def __init__(self, prob=0.5):
244
+ self.prob = prob
245
+
246
+ def __call__(self, sample):
247
+ imidx, image, label, shape = (
248
+ sample["imidx"],
249
+ sample["image"],
250
+ sample["label"],
251
+ sample["shape"],
252
+ )
253
+
254
+ # random horizontal flip
255
+ if random.random() >= self.prob:
256
+ image = torch.flip(image, dims=[2])
257
+ label = torch.flip(label, dims=[2])
258
+
259
+ return {"imidx": imidx, "image": image, "label": label, "shape": shape}
260
+
261
+
262
+ class GOSDatasetCache(Dataset):
263
+
264
+ def __init__(
265
+ self,
266
+ name_im_gt_list,
267
+ cache_size=[],
268
+ cache_path="./cache",
269
+ cache_file_name="dataset.json",
270
+ cache_boost=False,
271
+ transform=None,
272
+ ):
273
+
274
+ self.cache_size = cache_size
275
+ self.cache_path = cache_path
276
+ self.cache_file_name = cache_file_name
277
+ self.cache_boost_name = ""
278
+
279
+ self.cache_boost = cache_boost
280
+ # self.ims_npy = None
281
+ # self.gts_npy = None
282
+
283
+ ## cache all the images and ground truth into a single pytorch tensor
284
+ self.ims_pt = None
285
+ self.gts_pt = None
286
+
287
+ ## we will cache the npy as well regardless of the cache_boost
288
+ # if(self.cache_boost):
289
+ self.cache_boost_name = cache_file_name.split(".json")[0]
290
+
291
+ self.transform = transform
292
+
293
+ self.dataset = {}
294
+
295
+ ## combine different datasets into one
296
+ dataset_names = []
297
+ dt_name_list = [] # dataset name per image
298
+ im_name_list = [] # image name
299
+ im_path_list = [] # im path
300
+ gt_path_list = [] # gt path
301
+ im_ext_list = [] # im ext
302
+ gt_ext_list = [] # gt ext
303
+ for i in range(0, len(name_im_gt_list)):
304
+ dataset_names.append(name_im_gt_list[i]["dataset_name"])
305
+ # dataset name repeated based on the number of images in this dataset
306
+ dt_name_list.extend(
307
+ [
308
+ name_im_gt_list[i]["dataset_name"]
309
+ for x in name_im_gt_list[i]["im_path"]
310
+ ]
311
+ )
312
+ im_name_list.extend(
313
+ [
314
+ x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0]
315
+ for x in name_im_gt_list[i]["im_path"]
316
+ ]
317
+ )
318
+ im_path_list.extend(name_im_gt_list[i]["im_path"])
319
+ gt_path_list.extend(name_im_gt_list[i]["gt_path"])
320
+ im_ext_list.extend(
321
+ [name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]
322
+ )
323
+ gt_ext_list.extend(
324
+ [name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]
325
+ )
326
+
327
+ self.dataset["data_name"] = dt_name_list
328
+ self.dataset["im_name"] = im_name_list
329
+ self.dataset["im_path"] = im_path_list
330
+ self.dataset["ori_im_path"] = deepcopy(im_path_list)
331
+ self.dataset["gt_path"] = gt_path_list
332
+ self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
333
+ self.dataset["im_shp"] = []
334
+ self.dataset["gt_shp"] = []
335
+ self.dataset["im_ext"] = im_ext_list
336
+ self.dataset["gt_ext"] = gt_ext_list
337
+
338
+ self.dataset["ims_pt_dir"] = ""
339
+ self.dataset["gts_pt_dir"] = ""
340
+
341
+ self.dataset = self.manage_cache(dataset_names)
342
+
343
+ def manage_cache(self, dataset_names):
344
+ if not os.path.exists(self.cache_path): # create the folder for cache
345
+ os.makedirs(self.cache_path)
346
+ cache_folder = os.path.join(
347
+ self.cache_path,
348
+ "_".join(dataset_names) + "_" + "x".join([str(x) for x in self.cache_size]),
349
+ )
350
+ if not os.path.exists(
351
+ cache_folder
352
+ ): # check if the cache files are there, if not then cache
353
+ return self.cache(cache_folder)
354
+ return self.load_cache(cache_folder)
355
+
356
+ def cache(self, cache_folder):
357
+ os.mkdir(cache_folder)
358
+ cached_dataset = deepcopy(self.dataset)
359
+
360
+ # ims_list = []
361
+ # gts_list = []
362
+ ims_pt_list = []
363
+ gts_pt_list = []
364
+ for i, im_path in tqdm(
365
+ enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])
366
+ ):
367
+
368
+ im_id = cached_dataset["im_name"][i]
369
+ print("im_path: ", im_path)
370
+ im = im_reader(im_path)
371
+ im, im_shp = im_preprocess(im, self.cache_size)
372
+ im_cache_file = os.path.join(
373
+ cache_folder, self.dataset["data_name"][i] + "_" + im_id + "_im.pt"
374
+ )
375
+ torch.save(im, im_cache_file)
376
+
377
+ cached_dataset["im_path"][i] = im_cache_file
378
+ if self.cache_boost:
379
+ ims_pt_list.append(torch.unsqueeze(im, 0))
380
+ # ims_list.append(im.cpu().data.numpy().astype(np.uint8))
381
+
382
+ gt = np.zeros(im.shape[0:2])
383
+ if len(self.dataset["gt_path"]) != 0:
384
+ gt = im_reader(self.dataset["gt_path"][i])
385
+ gt, gt_shp = gt_preprocess(gt, self.cache_size)
386
+ gt_cache_file = os.path.join(
387
+ cache_folder, self.dataset["data_name"][i] + "_" + im_id + "_gt.pt"
388
+ )
389
+ torch.save(gt, gt_cache_file)
390
+ if len(self.dataset["gt_path"]) > 0:
391
+ cached_dataset["gt_path"][i] = gt_cache_file
392
+ else:
393
+ cached_dataset["gt_path"].append(gt_cache_file)
394
+ if self.cache_boost:
395
+ gts_pt_list.append(torch.unsqueeze(gt, 0))
396
+ # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
397
+
398
+ # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
399
+ # torch.save(gt_shp, shp_cache_file)
400
+ cached_dataset["im_shp"].append(im_shp)
401
+ # self.dataset["im_shp"].append(im_shp)
402
+
403
+ # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
404
+ # torch.save(gt_shp, shp_cache_file)
405
+ cached_dataset["gt_shp"].append(gt_shp)
406
+ # self.dataset["gt_shp"].append(gt_shp)
407
+
408
+ if self.cache_boost:
409
+ cached_dataset["ims_pt_dir"] = os.path.join(
410
+ cache_folder, self.cache_boost_name + "_ims.pt"
411
+ )
412
+ cached_dataset["gts_pt_dir"] = os.path.join(
413
+ cache_folder, self.cache_boost_name + "_gts.pt"
414
+ )
415
+ self.ims_pt = torch.cat(ims_pt_list, dim=0)
416
+ self.gts_pt = torch.cat(gts_pt_list, dim=0)
417
+ torch.save(torch.cat(ims_pt_list, dim=0), cached_dataset["ims_pt_dir"])
418
+ torch.save(torch.cat(gts_pt_list, dim=0), cached_dataset["gts_pt_dir"])
419
+
420
+ try:
421
+ json_file = open(os.path.join(cache_folder, self.cache_file_name), "w")
422
+ json.dump(cached_dataset, json_file)
423
+ json_file.close()
424
+ except Exception:
425
+ raise FileNotFoundError("Cannot create JSON")
426
+ return cached_dataset
427
+
428
+ def load_cache(self, cache_folder):
429
+ json_file = open(os.path.join(cache_folder, self.cache_file_name), "r")
430
+ dataset = json.load(json_file)
431
+ json_file.close()
432
+ ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
433
+ ## otherwise the pytorch tensor will be loaded
434
+ if self.cache_boost:
435
+ # self.ims_npy = np.load(dataset["ims_npy_dir"])
436
+ # self.gts_npy = np.load(dataset["gts_npy_dir"])
437
+ self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location="cpu")
438
+ self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location="cpu")
439
+ return dataset
440
+
441
+ def __len__(self):
442
+ return len(self.dataset["im_path"])
443
+
444
+ def __getitem__(self, idx):
445
+
446
+ im = None
447
+ gt = None
448
+ if self.cache_boost and self.ims_pt is not None:
449
+
450
+ # start = time.time()
451
+ im = self.ims_pt[idx] # .type(torch.float32)
452
+ gt = self.gts_pt[idx] # .type(torch.float32)
453
+ # print(idx, 'time for pt loading: ', time.time()-start)
454
+
455
+ else:
456
+ # import time
457
+ # start = time.time()
458
+ # print("tensor***")
459
+ im_pt_path = os.path.join(
460
+ self.cache_path,
461
+ os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]),
462
+ )
463
+ im = torch.load(im_pt_path) # (self.dataset["im_path"][idx])
464
+ gt_pt_path = os.path.join(
465
+ self.cache_path,
466
+ os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]),
467
+ )
468
+ gt = torch.load(gt_pt_path) # (self.dataset["gt_path"][idx])
469
+ # print(idx,'time for tensor loading: ', time.time()-start)
470
+
471
+ im_shp = self.dataset["im_shp"][idx]
472
+ # print("time for loading im and gt: ", time.time()-start)
473
+
474
+ # start_time = time.time()
475
+ im = torch.divide(im, 255.0)
476
+ gt = torch.divide(gt, 255.0)
477
+ # print(idx, 'time for normalize torch divide: ', time.time()-start_time)
478
+
479
+ sample = {
480
+ "imidx": torch.from_numpy(np.array(idx)),
481
+ "image": im,
482
+ "label": gt,
483
+ "shape": torch.from_numpy(np.array(im_shp)),
484
+ }
485
+
486
+ if self.transform:
487
+ sample = self.transform(sample)
488
+
489
+ return sample
ormbg/inference.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import numpy as np
5
+ from PIL import Image
6
+ from skimage import io
7
+ from models.ormbg import ORMBG
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser(
13
+ description="Remove background from images using ORMBG model."
14
+ )
15
+ parser.add_argument(
16
+ "--image",
17
+ type=str,
18
+ default=os.path.join("examples", "image", "image01.png"),
19
+ help="Path to the input image file.",
20
+ )
21
+ parser.add_argument(
22
+ "--output",
23
+ type=str,
24
+ default=os.path.join("image01_no_background.png"),
25
+ help="Path to the output image file.",
26
+ )
27
+ parser.add_argument(
28
+ "--model-path",
29
+ type=str,
30
+ default=os.path.join("models", "ormbg.pth"),
31
+ help="Path to the model file.",
32
+ )
33
+ parser.add_argument(
34
+ "--compare",
35
+ action="store_false",
36
+ help="Flag to save the original and processed images side by side.",
37
+ )
38
+ return parser.parse_args()
39
+
40
+
41
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
42
+ if len(im.shape) < 3:
43
+ im = im[:, :, np.newaxis]
44
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
45
+ im_tensor = F.interpolate(
46
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
47
+ ).type(torch.uint8)
48
+ image = torch.divide(im_tensor, 255.0)
49
+ return image
50
+
51
+
52
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
53
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
54
+ ma = torch.max(result)
55
+ mi = torch.min(result)
56
+ result = (result - mi) / (ma - mi)
57
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
58
+ im_array = np.squeeze(im_array)
59
+ return im_array
60
+
61
+
62
+ def inference(args):
63
+ image_path = args.image
64
+ result_name = args.output
65
+ model_path = args.model_path
66
+ compare = args.compare
67
+
68
+ net = ORMBG()
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+
71
+ if torch.cuda.is_available():
72
+ net.load_state_dict(torch.load(model_path))
73
+ net = net.cuda()
74
+ else:
75
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
76
+ net.eval()
77
+
78
+ model_input_size = [1024, 1024]
79
+ orig_im = io.imread(image_path)
80
+ orig_im_size = orig_im.shape[0:2]
81
+ image = preprocess_image(orig_im, model_input_size).to(device)
82
+
83
+ result = net(image)
84
+
85
+ # post process
86
+ result_image = postprocess_image(result[0][0], orig_im_size)
87
+
88
+ # save result
89
+ pil_im = Image.fromarray(result_image)
90
+
91
+ if pil_im.mode == "RGBA":
92
+ pil_im = pil_im.convert("RGB")
93
+
94
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
95
+ orig_image = Image.open(image_path)
96
+ no_bg_image.paste(orig_image, mask=pil_im)
97
+
98
+ if compare:
99
+ combined_width = orig_image.width + no_bg_image.width
100
+ combined_image = Image.new("RGBA", (combined_width, orig_image.height))
101
+ combined_image.paste(orig_image, (0, 0))
102
+ combined_image.paste(no_bg_image, (orig_image.width, 0))
103
+ stacked_output_path = os.path.splitext(result_name)[0] + ".png"
104
+ combined_image.save(stacked_output_path)
105
+ else:
106
+ no_bg_image.save(result_name)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ inference(parse_args())
ormbg/models/ormbg.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
6
+
7
+
8
+ class REBNCONV(nn.Module):
9
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
10
+ super(REBNCONV, self).__init__()
11
+
12
+ self.conv_s1 = nn.Conv2d(
13
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
14
+ )
15
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
16
+ self.relu_s1 = nn.ReLU(inplace=True)
17
+
18
+ def forward(self, x):
19
+
20
+ hx = x
21
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
22
+
23
+ return xout
24
+
25
+
26
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
27
+ def _upsample_like(src, tar):
28
+
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+
31
+ return src
32
+
33
+
34
+ ### RSU-7 ###
35
+ class RSU7(nn.Module):
36
+
37
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
38
+ super(RSU7, self).__init__()
39
+
40
+ self.in_ch = in_ch
41
+ self.mid_ch = mid_ch
42
+ self.out_ch = out_ch
43
+
44
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
45
+
46
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
47
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
48
+
49
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
50
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
+
52
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
53
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
+
55
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
56
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+
63
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
64
+
65
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
71
+
72
+ def forward(self, x):
73
+ b, c, h, w = x.shape
74
+
75
+ hx = x
76
+ hxin = self.rebnconvin(hx)
77
+
78
+ hx1 = self.rebnconv1(hxin)
79
+ hx = self.pool1(hx1)
80
+
81
+ hx2 = self.rebnconv2(hx)
82
+ hx = self.pool2(hx2)
83
+
84
+ hx3 = self.rebnconv3(hx)
85
+ hx = self.pool3(hx3)
86
+
87
+ hx4 = self.rebnconv4(hx)
88
+ hx = self.pool4(hx4)
89
+
90
+ hx5 = self.rebnconv5(hx)
91
+ hx = self.pool5(hx5)
92
+
93
+ hx6 = self.rebnconv6(hx)
94
+
95
+ hx7 = self.rebnconv7(hx6)
96
+
97
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
98
+ hx6dup = _upsample_like(hx6d, hx5)
99
+
100
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
101
+ hx5dup = _upsample_like(hx5d, hx4)
102
+
103
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
104
+ hx4dup = _upsample_like(hx4d, hx3)
105
+
106
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
107
+ hx3dup = _upsample_like(hx3d, hx2)
108
+
109
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
110
+ hx2dup = _upsample_like(hx2d, hx1)
111
+
112
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
113
+
114
+ return hx1d + hxin
115
+
116
+
117
+ ### RSU-6 ###
118
+ class RSU6(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
121
+ super(RSU6, self).__init__()
122
+
123
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
124
+
125
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
126
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
136
+
137
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
138
+
139
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
140
+
141
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
145
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
146
+
147
+ def forward(self, x):
148
+
149
+ hx = x
150
+
151
+ hxin = self.rebnconvin(hx)
152
+
153
+ hx1 = self.rebnconv1(hxin)
154
+ hx = self.pool1(hx1)
155
+
156
+ hx2 = self.rebnconv2(hx)
157
+ hx = self.pool2(hx2)
158
+
159
+ hx3 = self.rebnconv3(hx)
160
+ hx = self.pool3(hx3)
161
+
162
+ hx4 = self.rebnconv4(hx)
163
+ hx = self.pool4(hx4)
164
+
165
+ hx5 = self.rebnconv5(hx)
166
+
167
+ hx6 = self.rebnconv6(hx5)
168
+
169
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
170
+ hx5dup = _upsample_like(hx5d, hx4)
171
+
172
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
173
+ hx4dup = _upsample_like(hx4d, hx3)
174
+
175
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
176
+ hx3dup = _upsample_like(hx3d, hx2)
177
+
178
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
179
+ hx2dup = _upsample_like(hx2d, hx1)
180
+
181
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
182
+
183
+ return hx1d + hxin
184
+
185
+
186
+ ### RSU-5 ###
187
+ class RSU5(nn.Module):
188
+
189
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
190
+ super(RSU5, self).__init__()
191
+
192
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
193
+
194
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
195
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
196
+
197
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
198
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
199
+
200
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
201
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
202
+
203
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
204
+
205
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
206
+
207
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
208
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
209
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
211
+
212
+ def forward(self, x):
213
+
214
+ hx = x
215
+
216
+ hxin = self.rebnconvin(hx)
217
+
218
+ hx1 = self.rebnconv1(hxin)
219
+ hx = self.pool1(hx1)
220
+
221
+ hx2 = self.rebnconv2(hx)
222
+ hx = self.pool2(hx2)
223
+
224
+ hx3 = self.rebnconv3(hx)
225
+ hx = self.pool3(hx3)
226
+
227
+ hx4 = self.rebnconv4(hx)
228
+
229
+ hx5 = self.rebnconv5(hx4)
230
+
231
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
232
+ hx4dup = _upsample_like(hx4d, hx3)
233
+
234
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
235
+ hx3dup = _upsample_like(hx3d, hx2)
236
+
237
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
238
+ hx2dup = _upsample_like(hx2d, hx1)
239
+
240
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
241
+
242
+ return hx1d + hxin
243
+
244
+
245
+ ### RSU-4 ###
246
+ class RSU4(nn.Module):
247
+
248
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
249
+ super(RSU4, self).__init__()
250
+
251
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
252
+
253
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
254
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
255
+
256
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
257
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
258
+
259
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
260
+
261
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
262
+
263
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
264
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
265
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
266
+
267
+ def forward(self, x):
268
+
269
+ hx = x
270
+
271
+ hxin = self.rebnconvin(hx)
272
+
273
+ hx1 = self.rebnconv1(hxin)
274
+ hx = self.pool1(hx1)
275
+
276
+ hx2 = self.rebnconv2(hx)
277
+ hx = self.pool2(hx2)
278
+
279
+ hx3 = self.rebnconv3(hx)
280
+
281
+ hx4 = self.rebnconv4(hx3)
282
+
283
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
284
+ hx3dup = _upsample_like(hx3d, hx2)
285
+
286
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
287
+ hx2dup = _upsample_like(hx2d, hx1)
288
+
289
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
290
+
291
+ return hx1d + hxin
292
+
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F, self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
311
+
312
+ def forward(self, x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(
333
+ self,
334
+ in_ch=3,
335
+ out_ch=1,
336
+ kernel_size=3,
337
+ stride=1,
338
+ padding=1,
339
+ dilation=1,
340
+ groups=1,
341
+ ):
342
+ super(myrebnconv, self).__init__()
343
+
344
+ self.conv = nn.Conv2d(
345
+ in_ch,
346
+ out_ch,
347
+ kernel_size=kernel_size,
348
+ stride=stride,
349
+ padding=padding,
350
+ dilation=dilation,
351
+ groups=groups,
352
+ )
353
+ self.bn = nn.BatchNorm2d(out_ch)
354
+ self.rl = nn.ReLU(inplace=True)
355
+
356
+ def forward(self, x):
357
+ return self.rl(self.bn(self.conv(x)))
358
+
359
+
360
+ bce_loss = nn.BCELoss(size_average=True)
361
+
362
+
363
+ class ORMBG(nn.Module):
364
+
365
+ def __init__(self, in_ch=3, out_ch=1):
366
+ super(ORMBG, self).__init__()
367
+
368
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
369
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage1 = RSU7(64, 32, 64)
372
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage2 = RSU6(64, 32, 128)
375
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
376
+
377
+ self.stage3 = RSU5(128, 64, 256)
378
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
379
+
380
+ self.stage4 = RSU4(256, 128, 512)
381
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
382
+
383
+ self.stage5 = RSU4F(512, 256, 512)
384
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
385
+
386
+ self.stage6 = RSU4F(512, 256, 512)
387
+
388
+ # decoder
389
+ self.stage5d = RSU4F(1024, 256, 512)
390
+ self.stage4d = RSU4(1024, 128, 256)
391
+ self.stage3d = RSU5(512, 64, 128)
392
+ self.stage2d = RSU6(256, 32, 64)
393
+ self.stage1d = RSU7(128, 16, 64)
394
+
395
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
396
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
397
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
398
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
399
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
400
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
401
+
402
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
403
+
404
+ def compute_loss(self, predictions, ground_truth):
405
+ loss0, loss = 0.0, 0.0
406
+ for i in range(0, len(predictions)):
407
+ loss = loss + bce_loss(predictions[i], ground_truth)
408
+ if i == 0:
409
+ loss0 = loss
410
+ return loss0, loss
411
+
412
+ def forward(self, x):
413
+
414
+ hx = x
415
+
416
+ hxin = self.conv_in(hx)
417
+ # hx = self.pool_in(hxin)
418
+
419
+ # stage 1
420
+ hx1 = self.stage1(hxin)
421
+ hx = self.pool12(hx1)
422
+
423
+ # stage 2
424
+ hx2 = self.stage2(hx)
425
+ hx = self.pool23(hx2)
426
+
427
+ # stage 3
428
+ hx3 = self.stage3(hx)
429
+ hx = self.pool34(hx3)
430
+
431
+ # stage 4
432
+ hx4 = self.stage4(hx)
433
+ hx = self.pool45(hx4)
434
+
435
+ # stage 5
436
+ hx5 = self.stage5(hx)
437
+ hx = self.pool56(hx5)
438
+
439
+ # stage 6
440
+ hx6 = self.stage6(hx)
441
+ hx6up = _upsample_like(hx6, hx5)
442
+
443
+ # -------------------- decoder --------------------
444
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
445
+ hx5dup = _upsample_like(hx5d, hx4)
446
+
447
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
448
+ hx4dup = _upsample_like(hx4d, hx3)
449
+
450
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
451
+ hx3dup = _upsample_like(hx3d, hx2)
452
+
453
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
454
+ hx2dup = _upsample_like(hx2d, hx1)
455
+
456
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
457
+
458
+ # side output
459
+ d1 = self.side1(hx1d)
460
+ d1 = _upsample_like(d1, x)
461
+
462
+ d2 = self.side2(hx2d)
463
+ d2 = _upsample_like(d2, x)
464
+
465
+ d3 = self.side3(hx3d)
466
+ d3 = _upsample_like(d3, x)
467
+
468
+ d4 = self.side4(hx4d)
469
+ d4 = _upsample_like(d4, x)
470
+
471
+ d5 = self.side5(hx5d)
472
+ d5 = _upsample_like(d5, x)
473
+
474
+ d6 = self.side6(hx6)
475
+ d6 = _upsample_like(d6, x)
476
+
477
+ return [
478
+ F.sigmoid(d1),
479
+ F.sigmoid(d2),
480
+ F.sigmoid(d3),
481
+ F.sigmoid(d4),
482
+ F.sigmoid(d5),
483
+ F.sigmoid(d6),
484
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
ormbg/train_model.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import torch, gc
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.autograd import Variable
8
+ import torch.nn.functional as F
9
+
10
+ import numpy as np
11
+
12
+ from pathlib import Path
13
+
14
+ from models.ormbg import ORMBG
15
+
16
+ from skimage import io
17
+
18
+ from basics import f1_mae_torch
19
+
20
+ from data_loader_cache import (
21
+ get_im_gt_name_dict,
22
+ create_dataloaders,
23
+ GOSGridDropout,
24
+ GOSRandomHFlip,
25
+ )
26
+
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+
30
+ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
31
+ net.eval()
32
+ print("Validating...")
33
+ epoch_num = hypar["max_epoch_num"]
34
+
35
+ val_loss = 0.0
36
+ tar_loss = 0.0
37
+ val_cnt = 0.0
38
+
39
+ tmp_f1 = []
40
+ tmp_mae = []
41
+ tmp_time = []
42
+
43
+ start_valid = time.time()
44
+
45
+ for k in range(len(valid_dataloaders)):
46
+
47
+ valid_dataloader = valid_dataloaders[k]
48
+ valid_dataset = valid_datasets[k]
49
+
50
+ val_num = valid_dataset.__len__()
51
+ mybins = np.arange(0, 256)
52
+ PRE = np.zeros((val_num, len(mybins) - 1))
53
+ REC = np.zeros((val_num, len(mybins) - 1))
54
+ F1 = np.zeros((val_num, len(mybins) - 1))
55
+ MAE = np.zeros((val_num))
56
+
57
+ for i_val, data_val in enumerate(valid_dataloader):
58
+ val_cnt = val_cnt + 1.0
59
+ imidx_val, inputs_val, labels_val, shapes_val = (
60
+ data_val["imidx"],
61
+ data_val["image"],
62
+ data_val["label"],
63
+ data_val["shape"],
64
+ )
65
+
66
+ if hypar["model_digit"] == "full":
67
+ inputs_val = inputs_val.type(torch.FloatTensor)
68
+ labels_val = labels_val.type(torch.FloatTensor)
69
+ else:
70
+ inputs_val = inputs_val.type(torch.HalfTensor)
71
+ labels_val = labels_val.type(torch.HalfTensor)
72
+
73
+ # wrap them in Variable
74
+ if torch.cuda.is_available():
75
+ inputs_val_v, labels_val_v = Variable(
76
+ inputs_val.cuda(), requires_grad=False
77
+ ), Variable(labels_val.cuda(), requires_grad=False)
78
+ else:
79
+ inputs_val_v, labels_val_v = Variable(
80
+ inputs_val, requires_grad=False
81
+ ), Variable(labels_val, requires_grad=False)
82
+
83
+ t_start = time.time()
84
+ ds_val = net(inputs_val_v)[0]
85
+ t_end = time.time() - t_start
86
+ tmp_time.append(t_end)
87
+
88
+ # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
89
+ loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
90
+
91
+ # compute F measure
92
+ for t in range(hypar["batch_size_valid"]):
93
+ i_test = imidx_val[t].data.numpy()
94
+
95
+ pred_val = ds_val[0][t, :, :, :] # B x 1 x H x W
96
+
97
+ ## recover the prediction spatial size to the orignal image size
98
+ pred_val = torch.squeeze(
99
+ F.upsample(
100
+ torch.unsqueeze(pred_val, 0),
101
+ (shapes_val[t][0], shapes_val[t][1]),
102
+ mode="bilinear",
103
+ )
104
+ )
105
+
106
+ # pred_val = normPRED(pred_val)
107
+ ma = torch.max(pred_val)
108
+ mi = torch.min(pred_val)
109
+ pred_val = (pred_val - mi) / (ma - mi) # max = 1
110
+
111
+ if len(valid_dataset.dataset["ori_gt_path"]) != 0:
112
+ gt = np.squeeze(
113
+ io.imread(valid_dataset.dataset["ori_gt_path"][i_test])
114
+ ) # max = 255
115
+ if gt.max() == 1:
116
+ gt = gt * 255
117
+ else:
118
+ gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))
119
+ with torch.no_grad():
120
+ gt = torch.tensor(gt).to(device)
121
+
122
+ pre, rec, f1, mae = f1_mae_torch(
123
+ pred_val * 255, gt, valid_dataset, i_test, mybins, hypar
124
+ )
125
+
126
+ PRE[i_test, :] = pre
127
+ REC[i_test, :] = rec
128
+ F1[i_test, :] = f1
129
+ MAE[i_test] = mae
130
+
131
+ del ds_val, gt
132
+ gc.collect()
133
+ torch.cuda.empty_cache()
134
+
135
+ # if(loss_val.data[0]>1):
136
+ val_loss += loss_val.item() # data[0]
137
+ tar_loss += loss2_val.item() # data[0]
138
+
139
+ print(
140
+ "[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"
141
+ % (
142
+ i_val,
143
+ val_num,
144
+ val_loss / (i_val + 1),
145
+ tar_loss / (i_val + 1),
146
+ np.amax(F1[i_test, :]),
147
+ MAE[i_test],
148
+ t_end,
149
+ )
150
+ )
151
+
152
+ del loss2_val, loss_val
153
+
154
+ print("============================")
155
+ PRE_m = np.mean(PRE, 0)
156
+ REC_m = np.mean(REC, 0)
157
+ f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)
158
+
159
+ tmp_f1.append(np.amax(f1_m))
160
+ tmp_mae.append(np.mean(MAE))
161
+
162
+ return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
163
+
164
+
165
+ def train(
166
+ net,
167
+ optimizer,
168
+ train_dataloaders,
169
+ train_datasets,
170
+ valid_dataloaders,
171
+ valid_datasets,
172
+ hypar,
173
+ ):
174
+
175
+ model_path = hypar["model_path"]
176
+ model_save_fre = hypar["model_save_fre"]
177
+ max_ite = hypar["max_ite"]
178
+ batch_size_train = hypar["batch_size_train"]
179
+ batch_size_valid = hypar["batch_size_valid"]
180
+
181
+ if not os.path.exists(model_path):
182
+ os.mkdir(model_path)
183
+
184
+ ite_num = hypar["start_ite"] # count the toal iteration number
185
+ ite_num4val = 0 #
186
+ running_loss = 0.0 # count the toal loss
187
+ running_tar_loss = 0.0 # count the target output loss
188
+ last_f1 = [0 for x in range(len(valid_dataloaders))]
189
+
190
+ train_num = train_datasets[0].__len__()
191
+
192
+ net.train()
193
+
194
+ start_last = time.time()
195
+ gos_dataloader = train_dataloaders[0]
196
+ epoch_num = hypar["max_epoch_num"]
197
+ notgood_cnt = 0
198
+
199
+ for epoch in range(epoch_num):
200
+
201
+ for i, data in enumerate(gos_dataloader):
202
+
203
+ if ite_num >= max_ite:
204
+ print("Training Reached the Maximal Iteration Number ", max_ite)
205
+ exit()
206
+
207
+ # start_read = time.time()
208
+ ite_num = ite_num + 1
209
+ ite_num4val = ite_num4val + 1
210
+
211
+ # get the inputs
212
+ inputs, labels = data["image"], data["label"]
213
+
214
+ if hypar["model_digit"] == "full":
215
+ inputs = inputs.type(torch.FloatTensor)
216
+ labels = labels.type(torch.FloatTensor)
217
+ else:
218
+ inputs = inputs.type(torch.HalfTensor)
219
+ labels = labels.type(torch.HalfTensor)
220
+
221
+ # wrap them in Variable
222
+ if torch.cuda.is_available():
223
+ inputs_v, labels_v = Variable(
224
+ inputs.cuda(), requires_grad=False
225
+ ), Variable(labels.cuda(), requires_grad=False)
226
+ else:
227
+ inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(
228
+ labels, requires_grad=False
229
+ )
230
+
231
+ # y zero the parameter gradients
232
+ start_inf_loss_back = time.time()
233
+ optimizer.zero_grad()
234
+
235
+ ds, _ = net(inputs_v)
236
+ loss2, loss = net.compute_loss(ds, labels_v)
237
+
238
+ loss.backward()
239
+ optimizer.step()
240
+
241
+ # # print statistics
242
+ running_loss += loss.item()
243
+ running_tar_loss += loss2.item()
244
+
245
+ # del outputs, loss
246
+ del ds, loss2, loss
247
+ end_inf_loss_back = time.time() - start_inf_loss_back
248
+
249
+ print(
250
+ ">>>"
251
+ + model_path.split("/")[-1]
252
+ + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f"
253
+ % (
254
+ epoch + 1,
255
+ epoch_num,
256
+ (i + 1) * batch_size_train,
257
+ train_num,
258
+ ite_num,
259
+ running_loss / ite_num4val,
260
+ running_tar_loss / ite_num4val,
261
+ time.time() - start_last,
262
+ time.time() - start_last - end_inf_loss_back,
263
+ )
264
+ )
265
+ start_last = time.time()
266
+
267
+ if ite_num % model_save_fre == 0: # validate every 2000 iterations
268
+ notgood_cnt += 1
269
+ net.eval()
270
+ tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(
271
+ net, valid_dataloaders, valid_datasets, hypar, epoch
272
+ )
273
+ net.train() # resume train
274
+
275
+ tmp_out = 0
276
+ print("last_f1:", last_f1)
277
+ print("tmp_f1:", tmp_f1)
278
+ for fi in range(len(last_f1)):
279
+ if tmp_f1[fi] > last_f1[fi]:
280
+ tmp_out = 1
281
+ print("tmp_out:", tmp_out)
282
+ if tmp_out:
283
+ notgood_cnt = 0
284
+ last_f1 = tmp_f1
285
+ tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
286
+ tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
287
+ maxf1 = "_".join(tmp_f1_str)
288
+ meanM = "_".join(tmp_mae_str)
289
+ # .cpu().detach().numpy()
290
+ model_name = (
291
+ "/gpu_itr_"
292
+ + str(ite_num)
293
+ + "_traLoss_"
294
+ + str(np.round(running_loss / ite_num4val, 4))
295
+ + "_traTarLoss_"
296
+ + str(np.round(running_tar_loss / ite_num4val, 4))
297
+ + "_valLoss_"
298
+ + str(np.round(val_loss / (i_val + 1), 4))
299
+ + "_valTarLoss_"
300
+ + str(np.round(tar_loss / (i_val + 1), 4))
301
+ + "_maxF1_"
302
+ + maxf1
303
+ + "_mae_"
304
+ + meanM
305
+ + "_time_"
306
+ + str(
307
+ np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)
308
+ )
309
+ + ".pth"
310
+ )
311
+ torch.save(net.state_dict(), model_path + model_name)
312
+
313
+ running_loss = 0.0
314
+ running_tar_loss = 0.0
315
+ ite_num4val = 0
316
+
317
+ if notgood_cnt >= hypar["early_stop"]:
318
+ print(
319
+ "No improvements in the last "
320
+ + str(notgood_cnt)
321
+ + " validation periods, so training stopped !"
322
+ )
323
+ exit()
324
+
325
+ print("Training Reaches The Maximum Epoch Number")
326
+
327
+
328
+ def main(train_datasets, valid_datasets, hypar):
329
+
330
+ print("--- create training dataloader ---")
331
+
332
+ train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
333
+ ## build dataloader for training datasets
334
+ train_dataloaders, train_datasets = create_dataloaders(
335
+ train_nm_im_gt_list,
336
+ cache_size=hypar["cache_size"],
337
+ cache_boost=hypar["cache_boost_train"],
338
+ my_transforms=[GOSGridDropout(), GOSRandomHFlip()],
339
+ batch_size=hypar["batch_size_train"],
340
+ shuffle=True,
341
+ )
342
+
343
+ valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
344
+
345
+ valid_dataloaders, valid_datasets = create_dataloaders(
346
+ valid_nm_im_gt_list,
347
+ cache_size=hypar["cache_size"],
348
+ cache_boost=hypar["cache_boost_valid"],
349
+ my_transforms=[],
350
+ batch_size=hypar["batch_size_valid"],
351
+ shuffle=False,
352
+ )
353
+
354
+ net = hypar["model"]
355
+
356
+ if hypar["model_digit"] == "half":
357
+ net.half()
358
+ for layer in net.modules():
359
+ if isinstance(layer, nn.BatchNorm2d):
360
+ layer.float()
361
+
362
+ if torch.cuda.is_available():
363
+ net.cuda()
364
+
365
+ if hypar["restore_model"] != "":
366
+ print("restore model from:")
367
+ print(hypar["model_path"] + "/" + hypar["restore_model"])
368
+ if torch.cuda.is_available():
369
+ net.load_state_dict(
370
+ torch.load(hypar["model_path"] + "/" + hypar["restore_model"])
371
+ )
372
+ else:
373
+ net.load_state_dict(
374
+ torch.load(
375
+ hypar["model_path"] + "/" + hypar["restore_model"],
376
+ map_location="cpu",
377
+ )
378
+ )
379
+
380
+ optimizer = optim.Adam(
381
+ net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
382
+ )
383
+
384
+ train(
385
+ net,
386
+ optimizer,
387
+ train_dataloaders,
388
+ train_datasets,
389
+ valid_dataloaders,
390
+ valid_datasets,
391
+ hypar,
392
+ )
393
+
394
+
395
+ if __name__ == "__main__":
396
+
397
+ output_model_folder = "saved_models"
398
+ Path(output_model_folder).mkdir(parents=True, exist_ok=True)
399
+
400
+ train_datasets, valid_datasets = [], []
401
+ dataset_1, dataset_1 = {}, {}
402
+
403
+ dataset_training = {
404
+ "name": "ormbg-training",
405
+ "im_dir": str(Path("dataset", "training", "im")),
406
+ "gt_dir": str(Path("dataset", "training", "gt")),
407
+ "im_ext": ".png",
408
+ "gt_ext": ".png",
409
+ "cache_dir": str(Path("cache", "teacher", "training")),
410
+ }
411
+
412
+ dataset_validation = {
413
+ "name": "ormbg-training",
414
+ "im_dir": str(Path("dataset", "validation", "im")),
415
+ "gt_dir": str(Path("dataset", "validation", "gt")),
416
+ "im_ext": ".png",
417
+ "gt_ext": ".png",
418
+ "cache_dir": str(Path("cache", "teacher", "validation")),
419
+ }
420
+
421
+ train_datasets = [dataset_training]
422
+ valid_datasets = [dataset_validation]
423
+
424
+ ### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
425
+ hypar = {}
426
+
427
+ hypar["model"] = ORMBG()
428
+ hypar["seed"] = 0
429
+
430
+ ## model weights path
431
+ hypar["model_path"] = "saved_models"
432
+
433
+ ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
434
+ hypar["restore_model"] = ""
435
+
436
+ ## start iteration for the training, can be changed to match the restored training process
437
+ hypar["start_ite"] = 0
438
+
439
+ ## indicates "half" or "full" accuracy of float number
440
+ hypar["model_digit"] = "full"
441
+
442
+ ## To handle large size input images, which take a lot of time for loading in training,
443
+ # we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
444
+ hypar["cache_size"] = [
445
+ 1024,
446
+ 1024,
447
+ ]
448
+
449
+ ## cached input spatial resolution, can be configured into different size
450
+ ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
451
+ hypar["cache_boost_train"] = False
452
+
453
+ ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
454
+ hypar["cache_boost_valid"] = False
455
+
456
+ ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
457
+ hypar["early_stop"] = 20
458
+
459
+ ## valid and save model weights every 2000 iterations
460
+ hypar["model_save_fre"] = 2000
461
+
462
+ ## batch size for training
463
+ hypar["batch_size_train"] = 8
464
+
465
+ ## batch size for validation and inferencing
466
+ hypar["batch_size_valid"] = 1
467
+
468
+ ## if early stop couldn't stop the training process, stop it by the max_ite_num
469
+ hypar["max_ite"] = 10000000
470
+
471
+ ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
472
+ hypar["max_epoch_num"] = 1000000
473
+
474
+ main(train_datasets, valid_datasets, hypar=hypar)
utils/.DS_Store ADDED
Binary file (6.15 kB). View file
 
utils/architecture.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from ormbg.models.ormbg import ORMBG
2
+
3
+ if __name__ == "__main__":
4
+ print(ORMBG())
utils/loss_example.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import numpy as np
5
+ from skimage import io
6
+ from ormbg.models.ormbg import ORMBG
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser(
12
+ description="Remove background from images using ORMBG model."
13
+ )
14
+ parser.add_argument(
15
+ "--prediction",
16
+ type=list,
17
+ default=[
18
+ os.path.join("examples", "loss", "loss01.png"),
19
+ os.path.join("examples", "loss", "loss02.png"),
20
+ os.path.join("examples", "loss", "loss03.png"),
21
+ os.path.join("examples", "loss", "loss04.png"),
22
+ os.path.join("examples", "loss", "loss05.png"),
23
+ ],
24
+ help="Path to the input image file.",
25
+ )
26
+ parser.add_argument(
27
+ "--gt",
28
+ type=str,
29
+ default=os.path.join("examples", "loss", "gt.png"),
30
+ help="Ground truth mask",
31
+ )
32
+ return parser.parse_args()
33
+
34
+
35
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
36
+ if len(im.shape) < 3:
37
+ im = im[:, :, np.newaxis]
38
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
39
+ im_tensor = F.interpolate(
40
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
41
+ ).type(torch.uint8)
42
+ image = torch.divide(im_tensor, 255.0)
43
+ return image
44
+
45
+
46
+ def inference(args):
47
+ prediction_paths = args.prediction
48
+ gt_path = args.gt
49
+
50
+ net = ORMBG()
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+
53
+ for pred_path in prediction_paths:
54
+
55
+ model_input_size = [1024, 1024]
56
+ loss = io.imread(pred_path)
57
+ prediction = preprocess_image(loss, model_input_size).to(device)
58
+
59
+ model_input_size = [1024, 1024]
60
+ gt = io.imread(gt_path)
61
+ ground_truth = preprocess_image(gt, model_input_size).to(device)
62
+
63
+ _, loss = net.compute_loss([prediction], ground_truth)
64
+
65
+ print(f"Loss: {pred_path} {loss}")
66
+
67
+
68
+ if __name__ == "__main__":
69
+ inference(parse_args())
utils/pth_to_onnx.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import argparse
3
- from ormbg import ORMBG
4
 
5
 
6
  def export_to_onnx(model_path, onnx_path):
@@ -44,13 +44,13 @@ if __name__ == "__main__":
44
  parser.add_argument(
45
  "--model_path",
46
  type=str,
47
- default="./models/ormbg.pth",
48
  help="The path to the trained model file.",
49
  )
50
  parser.add_argument(
51
  "--onnx_path",
52
  type=str,
53
- default="./models/gpu_itr_28000_traLoss_0.102_traTarLoss_0.0105_valLoss_0.1293_valTarLoss_0.015_maxF1_0.9947_mae_0.0059_time_0.015454.pth",
54
  help="The path where the ONNX model will be saved.",
55
  )
56
 
 
1
  import torch
2
  import argparse
3
+ from ormbg.models.ormbg import ORMBG
4
 
5
 
6
  def export_to_onnx(model_path, onnx_path):
 
44
  parser.add_argument(
45
  "--model_path",
46
  type=str,
47
+ default="models/ormbg.pth",
48
  help="The path to the trained model file.",
49
  )
50
  parser.add_argument(
51
  "--onnx_path",
52
  type=str,
53
+ default="models/ormbg.pth",
54
  help="The path where the ONNX model will be saved.",
55
  )
56