schirrmacher committed
Commit 9a96838 · verified · 1 Parent(s): 29275af

Upload folder using huggingface_hub
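Note: the commit message above is the default written by huggingface_hub's upload_folder. A minimal sketch of the kind of call that produces such a commit — the local folder path and repo_id here are illustrative assumptions, not taken from this commit:

# Sketch: push a local folder to a Hub repo with huggingface_hub.
# folder_path and repo_id are assumptions for illustration.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_folder(
    folder_path="./ormbg",            # local folder to upload (assumed)
    repo_id="schirrmacher/ormbg",     # target repo on the Hub (assumed)
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)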
.gitignore DELETED
@@ -1,157 +0,0 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
- .pdm.toml
- .pdm-python
- .pdm-build/
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- models/*
IMG_1051.png DELETED

Git LFS Details

  • SHA256: fe4c6f3e70dfce78cbb26f24d83c1eab96b791972f8f1bffe6126eddc8edb78c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.62 MB
dataset/training/gt/p_00a4eda7.png DELETED
Binary file (38.7 kB)
 
dataset/training/gt/p_00a5b702.png DELETED
Binary file (96.9 kB)
 
dataset/training/im/p_00a4eda7.png DELETED

Git LFS Details

  • SHA256: e226a687b5d755056076e12d7f2c24704d101ad90918554c43028e8c1e53638f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
dataset/training/im/p_00a5b702.png DELETED

Git LFS Details

  • SHA256: 184b2d97ffdbffc9d0a5d3c3b84a848938df636855d59b81f3d109445a92b0ef
  • Pointer size: 132 Bytes
  • Size of remote file: 3.46 MB
dataset/validation/gt/p_00a7a27c.png DELETED
Binary file (81.8 kB)
 
dataset/validation/im/p_00a7a27c.png DELETED

Git LFS Details

  • SHA256: b87d59e4598ddc1078ebdc856e7101d92582315ecff2aecdadc17802e82bc8c1
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
environment.yaml DELETED
@@ -1,199 +0,0 @@
- name: ormbg
- channels:
-   - pytorch
-   - nvidia
-   - anaconda
-   - defaults
- dependencies:
-   - _libgcc_mutex=0.1=main
-   - _openmp_mutex=5.1=1_gnu
-   - aom=3.6.0=h6a678d5_0
-   - blas=1.0=mkl
-   - blosc=1.21.3=h6a678d5_0
-   - brotli=1.0.9=h5eee18b_7
-   - brotli-bin=1.0.9=h5eee18b_7
-   - brotli-python=1.0.9=py38h6a678d5_7
-   - brunsli=0.1=h2531618_0
-   - bzip2=1.0.8=h7b6447c_0
-   - c-ares=1.19.1=h5eee18b_0
-   - ca-certificates=2023.08.22=h06a4308_0
-   - certifi=2023.7.22=py38h06a4308_0
-   - cffi=1.15.0=py38h7f8727e_0
-   - cfitsio=3.470=h5893167_7
-   - charls=2.2.0=h2531618_0
-   - charset-normalizer=2.0.4=pyhd3eb1b0_0
-   - click=8.1.7=py38h06a4308_0
-   - cloudpickle=2.2.1=py38h06a4308_0
-   - contourpy=1.0.5=py38hdb19cb5_0
-   - cryptography=41.0.3=py38h130f0dd_0
-   - cuda-cudart=11.8.89=0
-   - cuda-cupti=11.8.87=0
-   - cuda-libraries=11.8.0=0
-   - cuda-nvrtc=11.8.89=0
-   - cuda-nvtx=11.8.86=0
-   - cuda-runtime=11.8.0=0
-   - cudatoolkit=11.8.0=h6a678d5_0
-   - cycler=0.11.0=pyhd3eb1b0_0
-   - cytoolz=0.12.0=py38h5eee18b_0
-   - dask-core=2023.4.1=py38h06a4308_0
-   - dav1d=1.2.1=h5eee18b_0
-   - dbus=1.13.18=hb2f20db_0
-   - expat=2.5.0=h6a678d5_0
-   - ffmpeg=4.3=hf484d3e_0
-   - fftw=3.3.9=h27cfd23_1
-   - filelock=3.9.0=py38h06a4308_0
-   - fontconfig=2.14.1=h52c9d5c_1
-   - fonttools=4.25.0=pyhd3eb1b0_0
-   - freetype=2.12.1=h4a9f257_0
-   - fsspec=2023.9.2=py38h06a4308_0
-   - giflib=5.2.1=h5eee18b_3
-   - glib=2.63.1=h5a9c865_0
-   - gmp=6.2.1=h295c915_3
-   - gmpy2=2.1.2=py38heeb90bb_0
-   - gnutls=3.6.15=he1e5248_0
-   - gst-plugins-base=1.14.0=hbbd80ab_1
-   - gstreamer=1.14.0=hb453b48_1
-   - icu=58.2=he6710b0_3
-   - idna=3.4=py38h06a4308_0
-   - imagecodecs=2023.1.23=py38hc4b7b5f_0
-   - imageio=2.31.4=py38h06a4308_0
-   - importlib-metadata=6.0.0=py38h06a4308_0
-   - importlib_resources=6.1.0=py38h06a4308_0
-   - intel-openmp=2021.4.0=h06a4308_3561
-   - jinja2=3.1.2=py38h06a4308_0
-   - jpeg=9e=h5eee18b_1
-   - jxrlib=1.1=h7b6447c_2
-   - kiwisolver=1.4.4=py38h6a678d5_0
-   - krb5=1.20.1=h568e23c_1
-   - lame=3.100=h7b6447c_0
-   - lazy_loader=0.3=py38h06a4308_0
-   - lcms2=2.12=h3be6417_0
-   - lerc=3.0=h295c915_0
-   - libaec=1.0.4=he6710b0_1
-   - libavif=0.11.1=h5eee18b_0
-   - libbrotlicommon=1.0.9=h5eee18b_7
-   - libbrotlidec=1.0.9=h5eee18b_7
-   - libbrotlienc=1.0.9=h5eee18b_7
-   - libcublas=11.11.3.6=0
-   - libcufft=10.9.0.58=0
-   - libcufile=1.8.1.2=0
-   - libcurand=10.3.4.101=0
-   - libcurl=7.88.1=h91b91d3_2
-   - libcusolver=11.4.1.48=0
-   - libcusparse=11.7.5.86=0
-   - libdeflate=1.17=h5eee18b_1
-   - libedit=3.1.20221030=h5eee18b_0
-   - libev=4.33=h7f8727e_1
-   - libffi=3.2.1=hf484d3e_1007
-   - libgcc-ng=11.2.0=h1234567_1
-   - libgfortran-ng=11.2.0=h00389a5_1
-   - libgfortran5=11.2.0=h1234567_1
-   - libgomp=11.2.0=h1234567_1
-   - libiconv=1.16=h7f8727e_2
-   - libidn2=2.3.4=h5eee18b_0
-   - libjpeg-turbo=2.0.0=h9bf148f_0
-   - libnghttp2=1.52.0=ha637b67_1
-   - libnpp=11.8.0.86=0
-   - libnvjpeg=11.9.0.86=0
-   - libpng=1.6.39=h5eee18b_0
-   - libssh2=1.10.0=h37d81fd_2
-   - libstdcxx-ng=11.2.0=h1234567_1
-   - libtasn1=4.19.0=h5eee18b_0
-   - libtiff=4.5.1=h6a678d5_0
-   - libunistring=0.9.10=h27cfd23_0
-   - libuuid=1.41.5=h5eee18b_0
-   - libwebp=1.3.2=h11a3e52_0
-   - libwebp-base=1.3.2=h5eee18b_0
-   - libxcb=1.15=h7f8727e_0
-   - libxml2=2.9.14=h74e7548_0
-   - libzopfli=1.0.3=he6710b0_0
-   - llvm-openmp=14.0.6=h9e868ea_0
-   - locket=1.0.0=py38h06a4308_0
-   - lz4-c=1.9.4=h6a678d5_0
-   - markupsafe=2.1.1=py38h7f8727e_0
-   - matplotlib=3.7.2=py38h06a4308_0
-   - matplotlib-base=3.7.2=py38h1128e8f_0
-   - mkl=2021.4.0=h06a4308_640
-   - mkl-service=2.4.0=py38h7f8727e_0
-   - mkl_fft=1.3.1=py38hd3c417c_0
-   - mkl_random=1.2.2=py38h51133e4_0
-   - mpc=1.1.0=h10f8cd9_1
-   - mpfr=4.0.2=hb69a4c5_1
-   - mpmath=1.3.0=py38h06a4308_0
-   - munkres=1.1.4=py_0
-   - ncurses=6.4=h6a678d5_0
-   - nettle=3.7.3=hbbd107a_1
-   - networkx=3.1=py38h06a4308_0
-   - openh264=2.1.1=h4ff587b_0
-   - openjpeg=2.4.0=h3ad879b_0
-   - openssl=1.1.1w=h7f8727e_0
-   - packaging=23.1=py38h06a4308_0
-   - partd=1.4.1=py38h06a4308_0
-   - pcre=8.45=h295c915_0
-   - pillow=10.0.1=py38ha6cbd5a_0
-   - pip=23.3=py38h06a4308_0
-   - pycparser=2.21=pyhd3eb1b0_0
-   - pyopenssl=23.2.0=py38h06a4308_0
-   - pyparsing=3.0.9=py38h06a4308_0
-   - pyqt=5.9.2=py38h05f1152_4
-   - pysocks=1.7.1=py38h06a4308_0
-   - python=3.8.0=h0371630_2
-   - python-dateutil=2.8.2=pyhd3eb1b0_0
-   - pytorch=2.1.1=py3.8_cuda11.8_cudnn8.7.0_0
-   - pytorch-cuda=11.8=h7e8668a_5
-   - pytorch-mutex=1.0=cuda
-   - pywavelets=1.4.1=py38h5eee18b_0
-   - pyyaml=6.0.1=py38h5eee18b_0
-   - qt=5.9.7=h5867ecd_1
-   - readline=7.0=h7b6447c_5
-   - requests=2.31.0=py38h06a4308_0
-   - setuptools=68.0.0=py38h06a4308_0
-   - sip=4.19.13=py38h295c915_0
-   - six=1.16.0=pyhd3eb1b0_1
-   - snappy=1.1.9=h295c915_0
-   - sqlite=3.33.0=h62c20be_0
-   - sympy=1.11.1=py38h06a4308_0
-   - tifffile=2023.4.12=py38h06a4308_0
-   - tk=8.6.12=h1ccaba5_0
-   - toolz=0.12.0=py38h06a4308_0
-   - torchaudio=2.1.1=py38_cu118
-   - torchtriton=2.1.0=py38
-   - torchvision=0.16.1=py38_cu118
-   - tornado=6.3.3=py38h5eee18b_0
-   - tqdm=4.65.0=py38hb070fc8_0
-   - urllib3=1.26.18=py38h06a4308_0
-   - wheel=0.41.2=py38h06a4308_0
-   - xz=5.4.2=h5eee18b_0
-   - yaml=0.2.5=h7b6447c_0
-   - zfp=1.0.0=h6a678d5_0
-   - zipp=3.11.0=py38h06a4308_0
-   - zlib=1.2.13=h5eee18b_0
-   - zstd=1.5.5=hc292b87_0
-   - pip:
-     - albucore==0.0.12
-     - albumentations==1.4.11
-     - annotated-types==0.7.0
-     - appdirs==1.4.4
-     - conda-pack==0.7.1
-     - docker-pycreds==0.4.0
-     - eval-type-backport==0.2.0
-     - gitdb==4.0.11
-     - gitpython==3.1.40
-     - joblib==1.4.2
-     - numpy==1.24.4
-     - opencv-python-headless==4.10.0.84
-     - protobuf==4.25.1
-     - psutil==5.9.6
-     - pydantic==2.8.2
-     - pydantic-core==2.20.1
-     - scikit-image==0.21.0
-     - scikit-learn==1.3.2
-     - scipy==1.10.1
-     - sentry-sdk==1.35.0
-     - setproctitle==1.3.3
-     - smmap==5.0.1
-     - threadpoolctl==3.5.0
-     - tomli==2.0.1
-     - typing-extensions==4.12.2
-     - wandb==0.16.0
- prefix: /home/macher/miniconda3/envs/ormbg
example 2.png DELETED

Git LFS Details

  • SHA256: 5c7d7c861bd738768dcbefb98076c8423978d3108e6f0bcd482bb180a775a8af
  • Pointer size: 132 Bytes
  • Size of remote file: 4.77 MB
example1.jpeg DELETED

Git LFS Details

  • SHA256: 436f546cc1d7b2fd7021180299b028c0d379e48a9e9f05214a694b9c4eb8a7e3
  • Pointer size: 132 Bytes
  • Size of remote file: 7.63 MB
example1.png DELETED

Git LFS Details

  • SHA256: 42c8627c1ada7b69ef8561fcb5611cd8aa08af5eed211379a2619960524639c5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.83 MB
example2.jpeg DELETED

Git LFS Details

  • SHA256: 1dad92b56723fd8ac1c3832844873ad297300d0e85f6e14764334687a70c8abc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB
example2.png DELETED

Git LFS Details

  • SHA256: 5c7d7c861bd738768dcbefb98076c8423978d3108e6f0bcd482bb180a775a8af
  • Pointer size: 132 Bytes
  • Size of remote file: 4.77 MB
example3.jpeg DELETED

Git LFS Details

  • SHA256: f392dc4716469f5367ce0e2ac788f284d1b8d70c39be109db7038c3306a1da16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
example3.png DELETED

Git LFS Details

  • SHA256: e024065f9c0f6c981c107c5a403b1cc8fd3dfd20ac37fa212b43e0e69ec1b8ae
  • Pointer size: 132 Bytes
  • Size of remote file: 4.81 MB
examples.jpg DELETED

Git LFS Details

  • SHA256: ca9ab255b054e237cb51072bf687bf5f044a902d494ab7aa14b931e450519358
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
examples/.DS_Store DELETED
Binary file (6.15 kB)
 
examples/image/example01.jpeg DELETED

Git LFS Details

  • SHA256: 436f546cc1d7b2fd7021180299b028c0d379e48a9e9f05214a694b9c4eb8a7e3
  • Pointer size: 132 Bytes
  • Size of remote file: 7.63 MB
examples/image/example02.jpeg DELETED

Git LFS Details

  • SHA256: 1dad92b56723fd8ac1c3832844873ad297300d0e85f6e14764334687a70c8abc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB
examples/image/example03.jpeg DELETED

Git LFS Details

  • SHA256: f392dc4716469f5367ce0e2ac788f284d1b8d70c39be109db7038c3306a1da16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
examples/image/image01.png DELETED

Git LFS Details

  • SHA256: 1c6d54789fc0d8816231ca9f061b19af50bdbfb59a4fed7fa6c7bd3168591b0e
  • Pointer size: 133 Bytes
  • Size of remote file: 16.7 MB
examples/image/image01_no_background.png DELETED

Git LFS Details

  • SHA256: 9290ced416914386458bded92614b3b620bf82fc9dc7b06b4015fc6791d34cc3
  • Pointer size: 133 Bytes
  • Size of remote file: 21.4 MB
examples/loss/gt.png DELETED
Binary file (258 kB)
 
examples/loss/loss01.png DELETED
Binary file (291 kB)
 
examples/loss/loss02.png DELETED
Binary file (417 kB)
 
examples/loss/loss03.png DELETED
Binary file (645 kB)
 
examples/loss/loss04.png DELETED
Binary file (794 kB)
 
examples/loss/loss05.png DELETED
Binary file (983 kB)
 
examples/loss/orginal.jpg DELETED
Binary file (366 kB)
 
explanation.jpg DELETED
Binary file (713 kB)
 
hf_space.py DELETED
@@ -1,88 +0,0 @@
- import spaces
- import numpy as np
- import torch
- import torch.nn.functional as F
- import gradio as gr
- from ormbg.models.ormbg import ORMBG
- from PIL import Image
-
- model_path = "models/ormbg.pth"
-
- # Load the model globally but don't send to device yet
- net = ORMBG()
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
- net.eval()
-
-
- def resize_image(image):
-     image = image.convert("RGB")
-     model_input_size = (1024, 1024)
-     image = image.resize(model_input_size, Image.BILINEAR)
-     return image
-
-
- @spaces.GPU
- @torch.inference_mode()
- def inference(image):
-     # Check for CUDA and set the device inside inference
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     net.to(device)
-
-     # Prepare input
-     orig_image = Image.fromarray(image)
-     w, h = orig_image.size
-     image = resize_image(orig_image)
-     im_np = np.array(image)
-     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
-     im_tensor = torch.unsqueeze(im_tensor, 0)
-     im_tensor = torch.divide(im_tensor, 255.0)
-
-     if torch.cuda.is_available():
-         im_tensor = im_tensor.to(device)
-
-     # Inference
-     result = net(im_tensor)
-     # Post process
-     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
-     ma = torch.max(result)
-     mi = torch.min(result)
-     result = (result - mi) / (ma - mi)
-     # Image to PIL
-     im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
-     pil_im = Image.fromarray(np.squeeze(im_array))
-     # Paste the mask on the original image
-     new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
-     new_im.paste(orig_image, mask=pil_im)
-
-     return new_im
-
-
- # Gradio interface setup
- title = "Open Remove Background Model (ormbg)"
- description = r"""
- This model is a <strong>fully open-source background remover</strong> optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic <a href="https://huggingface.co/datasets/schirrmacher/humans">Human Segmentation Dataset</a>, <a href="https://paperswithcode.com/dataset/p3m-10k">P3M-10k</a> and <a href="https://paperswithcode.com/dataset/aim-500">AIM-500</a>.
-
- If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
-
- - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
- - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
- - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
- """
-
- examples = [
-     "./examples/image/example1.jpeg",
-     "./examples/image/example2.jpeg",
-     "./examples/image/example3.jpeg",
- ]
-
- demo = gr.Interface(
-     fn=inference,
-     inputs="image",
-     outputs="image",
-     examples=examples,
-     title=title,
-     description=description,
- )
-
- if __name__ == "__main__":
-     demo.launch(share=False, allowed_paths=["ormbg", "models", "examples"])
hf_space/app.py DELETED
@@ -1,90 +0,0 @@
- import spaces
- import numpy as np
- import torch
- import torch.nn.functional as F
- import gradio as gr
- from ormbg import ORMBG
- from PIL import Image
-
- model_path = "../models/ormbg.pth"
-
- # Load the model globally but don't send to device yet
- net = ORMBG()
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
- net.eval()
-
-
- def resize_image(image):
-     image = image.convert("RGB")
-     model_input_size = (1024, 1024)
-     image = image.resize(model_input_size, Image.BILINEAR)
-     return image
-
-
- @spaces.GPU
- @torch.inference_mode()
- def inference(image):
-     # Check for CUDA and set the device inside inference
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     net.to(device)
-
-     # Prepare input
-     orig_image = Image.fromarray(image)
-     w, h = orig_image.size
-     image = resize_image(orig_image)
-     im_np = np.array(image)
-     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
-     im_tensor = torch.unsqueeze(im_tensor, 0)
-     im_tensor = torch.divide(im_tensor, 255.0)
-
-     if torch.cuda.is_available():
-         im_tensor = im_tensor.to(device)
-
-     # Inference
-     result = net(im_tensor)
-     # Post process
-     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
-     ma = torch.max(result)
-     mi = torch.min(result)
-     result = (result - mi) / (ma - mi)
-     # Image to PIL
-     im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
-     pil_im = Image.fromarray(np.squeeze(im_array))
-     # Paste the mask on the original image
-     new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
-     new_im.paste(orig_image, mask=pil_im)
-
-     return new_im
-
-
- # Gradio interface setup
- title = "Open Remove Background Model (ormbg)"
- description = r"""
- This model is a <strong>fully open-source background remover</strong> optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic <a href="https://huggingface.co/datasets/schirrmacher/humans">Human Segmentation Dataset</a>, <a href="https://paperswithcode.com/dataset/p3m-10k">P3M-10k</a> and <a href="https://paperswithcode.com/dataset/aim-500">AIM-500</a>.
-
- If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
-
- - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
- - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
- - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
- """
-
- examples = [
-     "example1.jpeg",
-     "example2.jpeg",
-     "example3.jpeg",
- ]
-
- demo = gr.Interface(
-     fn=inference,
-     inputs="image",
-     outputs="image",
-     examples=examples,
-     title=title,
-     description=description,
- )
-
- if __name__ == "__main__":
-     demo.launch(
-         share=False, root_path="../", allowed_paths=["../hf_space", "../models"]
-     )
hf_space/example01.jpeg DELETED

Git LFS Details

  • SHA256: 436f546cc1d7b2fd7021180299b028c0d379e48a9e9f05214a694b9c4eb8a7e3
  • Pointer size: 132 Bytes
  • Size of remote file: 7.63 MB
hf_space/example02.jpeg DELETED

Git LFS Details

  • SHA256: 1dad92b56723fd8ac1c3832844873ad297300d0e85f6e14764334687a70c8abc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB
hf_space/example03.jpeg DELETED

Git LFS Details

  • SHA256: f392dc4716469f5367ce0e2ac788f284d1b8d70c39be109db7038c3306a1da16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
hf_space/ormbg.py DELETED
@@ -1,484 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
-
-
- class REBNCONV(nn.Module):
-     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
-         super(REBNCONV, self).__init__()
-
-         self.conv_s1 = nn.Conv2d(
-             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
-         )
-         self.bn_s1 = nn.BatchNorm2d(out_ch)
-         self.relu_s1 = nn.ReLU(inplace=True)
-
-     def forward(self, x):
-
-         hx = x
-         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
-
-         return xout
-
-
- ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
- def _upsample_like(src, tar):
-
-     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
-
-     return src
-
-
- ### RSU-7 ###
- class RSU7(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
-         super(RSU7, self).__init__()
-
-         self.in_ch = in_ch
-         self.mid_ch = mid_ch
-         self.out_ch = out_ch
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         b, c, h, w = x.shape
-
-         hx = x
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-         hx = self.pool4(hx4)
-
-         hx5 = self.rebnconv5(hx)
-         hx = self.pool5(hx5)
-
-         hx6 = self.rebnconv6(hx)
-
-         hx7 = self.rebnconv7(hx6)
-
-         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
-         hx6dup = _upsample_like(hx6d, hx5)
-
-         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-6 ###
- class RSU6(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU6, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-         hx = self.pool4(hx4)
-
-         hx5 = self.rebnconv5(hx)
-
-         hx6 = self.rebnconv6(hx5)
-
-         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-5 ###
- class RSU5(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU5, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-
-         hx5 = self.rebnconv5(hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-4 ###
- class RSU4(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU4, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-
-         hx4 = self.rebnconv4(hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-4F ###
- class RSU4F(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU4F, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
-
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx2 = self.rebnconv2(hx1)
-         hx3 = self.rebnconv3(hx2)
-
-         hx4 = self.rebnconv4(hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
-         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
-
-         return hx1d + hxin
-
-
- class myrebnconv(nn.Module):
-     def __init__(
-         self,
-         in_ch=3,
-         out_ch=1,
-         kernel_size=3,
-         stride=1,
-         padding=1,
-         dilation=1,
-         groups=1,
-     ):
-         super(myrebnconv, self).__init__()
-
-         self.conv = nn.Conv2d(
-             in_ch,
-             out_ch,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             dilation=dilation,
-             groups=groups,
-         )
-         self.bn = nn.BatchNorm2d(out_ch)
-         self.rl = nn.ReLU(inplace=True)
-
-     def forward(self, x):
-         return self.rl(self.bn(self.conv(x)))
-
-
- bce_loss = nn.BCELoss(size_average=True)
-
-
- class ORMBG(nn.Module):
-
-     def __init__(self, in_ch=3, out_ch=1):
-         super(ORMBG, self).__init__()
-
-         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
-         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage1 = RSU7(64, 32, 64)
-         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage2 = RSU6(64, 32, 128)
-         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage3 = RSU5(128, 64, 256)
-         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage4 = RSU4(256, 128, 512)
-         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage5 = RSU4F(512, 256, 512)
-         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage6 = RSU4F(512, 256, 512)
-
-         # decoder
-         self.stage5d = RSU4F(1024, 256, 512)
-         self.stage4d = RSU4(1024, 128, 256)
-         self.stage3d = RSU5(512, 64, 128)
-         self.stage2d = RSU6(256, 32, 64)
-         self.stage1d = RSU7(128, 16, 64)
-
-         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
-         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
-         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
-         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
-         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
-         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
-
-         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
-
-     def compute_loss(self, predictions, ground_truth):
-         loss0, loss = 0.0, 0.0
-         for i in range(0, len(predictions)):
-             loss = loss + bce_loss(predictions[i], ground_truth)
-             if i == 0:
-                 loss0 = loss
-         return loss0, loss
-
-     def forward(self, x):
-
-         hx = x
-
-         hxin = self.conv_in(hx)
-         # hx = self.pool_in(hxin)
-
-         # stage 1
-         hx1 = self.stage1(hxin)
-         hx = self.pool12(hx1)
-
-         # stage 2
-         hx2 = self.stage2(hx)
-         hx = self.pool23(hx2)
-
-         # stage 3
-         hx3 = self.stage3(hx)
-         hx = self.pool34(hx3)
-
-         # stage 4
-         hx4 = self.stage4(hx)
-         hx = self.pool45(hx4)
-
-         # stage 5
-         hx5 = self.stage5(hx)
-         hx = self.pool56(hx5)
-
-         # stage 6
-         hx6 = self.stage6(hx)
-         hx6up = _upsample_like(hx6, hx5)
-
-         # -------------------- decoder --------------------
-         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
-
-         # side output
-         d1 = self.side1(hx1d)
-         d1 = _upsample_like(d1, x)
-
-         d2 = self.side2(hx2d)
-         d2 = _upsample_like(d2, x)
-
-         d3 = self.side3(hx3d)
-         d3 = _upsample_like(d3, x)
-
-         d4 = self.side4(hx4d)
-         d4 = _upsample_like(d4, x)
-
-         d5 = self.side5(hx5d)
-         d5 = _upsample_like(d5, x)
-
-         d6 = self.side6(hx6)
-         d6 = _upsample_like(d6, x)
-
-         return [
-             F.sigmoid(d1),
-             F.sigmoid(d2),
-             F.sigmoid(d3),
-             F.sigmoid(d4),
-             F.sigmoid(d5),
-             F.sigmoid(d6),
-         ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
input.png DELETED

Git LFS Details

  • SHA256: 42c8627c1ada7b69ef8561fcb5611cd8aa08af5eed211379a2619960524639c5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.83 MB
ormbg/.DS_Store DELETED
Binary file (6.15 kB)
 
ormbg/basics.py DELETED
@@ -1,79 +0,0 @@
- import os
-
- # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
- from skimage import io, transform
- import torch
- import torchvision
- from torch.autograd import Variable
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, utils
- import torch.optim as optim
-
- import matplotlib.pyplot as plt
- import numpy as np
- from PIL import Image
- import glob
-
-
- def mae_torch(pred, gt):
-
-     h, w = gt.shape[0:2]
-     sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
-     maeError = torch.divide(sumError, float(h) * float(w) * 255.0 + 1e-4)
-
-     return maeError
-
-
- def f1score_torch(pd, gt):
-
-     # print(gt.shape)
-     gtNum = torch.sum((gt > 128).float() * 1) ## number of ground truth pixels
-
-     pp = pd[gt > 128]
-     nn = pd[gt <= 128]
-
-     pp_hist = torch.histc(pp, bins=255, min=0, max=255)
-     nn_hist = torch.histc(nn, bins=255, min=0, max=255)
-
-     pp_hist_flip = torch.flipud(pp_hist)
-     nn_hist_flip = torch.flipud(nn_hist)
-
-     pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
-     nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
-
-     precision = (pp_hist_flip_cum) / (
-         pp_hist_flip_cum + nn_hist_flip_cum + 1e-4
-     )  # torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
-     recall = (pp_hist_flip_cum) / (gtNum + 1e-4)
-     f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 1e-4)
-
-     return (
-         torch.reshape(precision, (1, precision.shape[0])),
-         torch.reshape(recall, (1, recall.shape[0])),
-         torch.reshape(f1, (1, f1.shape[0])),
-     )
-
-
- def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
-
-     import time
-
-     tic = time.time()
-
-     if len(gt.shape) > 2:
-         gt = gt[:, :, 0]
-
-     pre, rec, f1 = f1score_torch(pred, gt)
-     mae = mae_torch(pred, gt)
-
-     print(valid_dataset.dataset["im_name"][idx] + ".png")
-     print("time for evaluation : ", time.time() - tic)
-
-     return (
-         pre.cpu().data.numpy(),
-         rec.cpu().data.numpy(),
-         f1.cpu().data.numpy(),
-         mae.cpu().data.numpy(),
-     )
ormbg/data_loader_cache.py DELETED
@@ -1,489 +0,0 @@
- ## data loader
- ## Ackownledgement:
- ## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
- ## for his helps in implementing cache machanism of our DIS dataloader.
- from __future__ import print_function, division
-
- import albumentations as A
- import numpy as np
- import random
- from copy import deepcopy
- import json
- from tqdm import tqdm
- from skimage import io
- import os
- from glob import glob
-
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from torchvision.transforms.functional import normalize
- import torch.nn.functional as F
-
- #### --------------------- DIS dataloader cache ---------------------####
-
-
- def get_im_gt_name_dict(datasets, flag="valid"):
-     print("------------------------------", flag, "--------------------------------")
-     name_im_gt_list = []
-     for i in range(len(datasets)):
-         print(
-             "--->>>",
-             flag,
-             " dataset ",
-             i,
-             "/",
-             len(datasets),
-             " ",
-             datasets[i]["name"],
-             "<<<---",
-         )
-         tmp_im_list, tmp_gt_list = [], []
-         im_dir = datasets[i]["im_dir"]
-         gt_dir = datasets[i]["gt_dir"]
-         tmp_im_list = glob(os.path.join(im_dir, "*" + "*.[jp][pn]g"))
-         tmp_gt_list = glob(os.path.join(gt_dir, "*" + "*.[jp][pn]g"))
-
-         print(
-             "-im-", datasets[i]["name"], datasets[i]["im_dir"], ": ", len(tmp_im_list)
-         )
-
-         print(
-             "-gt-",
-             datasets[i]["name"],
-             datasets[i]["gt_dir"],
-             ": ",
-             len(tmp_gt_list),
-         )
-
-         if flag == "train": ## combine multiple training sets into one dataset
-             if len(name_im_gt_list) == 0:
-                 name_im_gt_list.append(
-                     {
-                         "dataset_name": datasets[i]["name"],
-                         "im_path": tmp_im_list,
-                         "gt_path": tmp_gt_list,
-                         "im_ext": datasets[i]["im_ext"],
-                         "gt_ext": datasets[i]["gt_ext"],
-                         "cache_dir": datasets[i]["cache_dir"],
-                     }
-                 )
-             else:
-                 name_im_gt_list[0]["dataset_name"] = (
-                     name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
-                 )
-                 name_im_gt_list[0]["im_path"] = (
-                     name_im_gt_list[0]["im_path"] + tmp_im_list
-                 )
-                 name_im_gt_list[0]["gt_path"] = (
-                     name_im_gt_list[0]["gt_path"] + tmp_gt_list
-                 )
-                 if datasets[i]["im_ext"] != ".jpg" or datasets[i]["gt_ext"] != ".png":
-                     print(
-                         "Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!"
-                     )
-                     exit()
-                 name_im_gt_list[0]["im_ext"] = ".jpg"
-                 name_im_gt_list[0]["gt_ext"] = ".png"
-                 name_im_gt_list[0]["cache_dir"] = (
-                     os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])
-                     + os.sep
-                     + name_im_gt_list[0]["dataset_name"]
-                 )
-         else: ## keep different validation or inference datasets as separate ones
-             name_im_gt_list.append(
-                 {
-                     "dataset_name": datasets[i]["name"],
-                     "im_path": tmp_im_list,
-                     "gt_path": tmp_gt_list,
-                     "im_ext": datasets[i]["im_ext"],
-                     "gt_ext": datasets[i]["gt_ext"],
-                     "cache_dir": datasets[i]["cache_dir"],
-                 }
-             )
-
-     return name_im_gt_list
-
-
- def create_dataloaders(
-     name_im_gt_list,
-     cache_size=[],
-     cache_boost=True,
-     my_transforms=[],
-     batch_size=1,
-     shuffle=False,
- ):
-     ## model="train": return one dataloader for training
-     ## model="valid": return a list of dataloaders for validation or testing
-
-     gos_dataloaders = []
-     gos_datasets = []
-
-     if len(name_im_gt_list) == 0:
-         return gos_dataloaders, gos_datasets
-
-     num_workers_ = 1
-     if batch_size > 1:
-         num_workers_ = 2
-     if batch_size > 4:
-         num_workers_ = 4
-     if batch_size > 8:
-         num_workers_ = 8
-
-     for i in range(0, len(name_im_gt_list)):
-         gos_dataset = GOSDatasetCache(
-             [name_im_gt_list[i]],
-             cache_size=cache_size,
-             cache_path=name_im_gt_list[i]["cache_dir"],
-             cache_boost=cache_boost,
-             transform=transforms.Compose(my_transforms),
-         )
-         gos_dataloaders.append(
-             DataLoader(
-                 gos_dataset,
-                 batch_size=batch_size,
-                 shuffle=shuffle,
-                 num_workers=num_workers_,
-             )
-         )
-         gos_datasets.append(gos_dataset)
-
-     return gos_dataloaders, gos_datasets
-
-
- def im_reader(im_path):
-     return io.imread(im_path)
-
-
- def im_preprocess(im, size):
-     if len(im.shape) < 3:
-         im = im[:, :, np.newaxis]
-     if im.shape[2] == 1:
-         im = np.repeat(im, 3, axis=2)
-     im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
-     im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1)
-     if len(size) < 2:
-         return im_tensor, im.shape[0:2]
-     else:
-         im_tensor = torch.unsqueeze(im_tensor, 0)
-         im_tensor = F.upsample(im_tensor, size, mode="bilinear")
-         im_tensor = torch.squeeze(im_tensor, 0)
-
-     return im_tensor.type(torch.uint8), im.shape[0:2]
-
-
- def gt_preprocess(gt, size):
-     if len(gt.shape) > 2:
-         gt = gt[:, :, 0]
-
-     gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8), 0)
-
-     if len(size) < 2:
-         return gt_tensor.type(torch.uint8), gt.shape[0:2]
-     else:
-         gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32), 0)
-         gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
-         gt_tensor = torch.squeeze(gt_tensor, 0)
-
-     return gt_tensor.type(torch.uint8), gt.shape[0:2]
-     # return gt_tensor, gt.shape[0:2]
-
-
- class GOSGridDropout(object):
-     def __init__(
-         self,
-         ratio=0.5,
-         unit_size_min=100,
-         unit_size_max=100,
-         holes_number_x=None,
-         holes_number_y=None,
-         shift_x=0,
-         shift_y=0,
-         random_offset=True,
-         fill_value=0,
-         mask_fill_value=None,
-         always_apply=None,
-         p=1.0,
-     ):
-         self.transform = A.GridDropout(
-             ratio=ratio,
-             unit_size_min=unit_size_min,
-             unit_size_max=unit_size_max,
-             holes_number_x=holes_number_x,
-             holes_number_y=holes_number_y,
-             shift_x=shift_x,
-             shift_y=shift_y,
-             random_offset=random_offset,
-             fill_value=fill_value,
-             mask_fill_value=mask_fill_value,
-             always_apply=always_apply,
-             p=p,
-         )
-
-     def __call__(self, sample):
-         imidx, image, label, shape = (
-             sample["imidx"],
-             sample["image"],
-             sample["label"],
-             sample["shape"],
-         )
-
-         # Convert the torch tensors to numpy arrays
-         image_np = image.permute(1, 2, 0).numpy()
-
-         augmented = self.transform(image=image_np)
-
-         # Convert the numpy arrays back to torch tensors
-         image = torch.tensor(augmented["image"]).permute(2, 0, 1)
-
-         return {"imidx": imidx, "image": image, "label": label, "shape": shape}
-
-
- class GOSRandomHFlip(object):
-     def __init__(self, prob=0.5):
-         self.prob = prob
-
-     def __call__(self, sample):
-         imidx, image, label, shape = (
-             sample["imidx"],
-             sample["image"],
-             sample["label"],
-             sample["shape"],
-         )
-
-         # random horizontal flip
-         if random.random() >= self.prob:
-             image = torch.flip(image, dims=[2])
-             label = torch.flip(label, dims=[2])
-
-         return {"imidx": imidx, "image": image, "label": label, "shape": shape}
-
-
- class GOSDatasetCache(Dataset):
-
-     def __init__(
-         self,
-         name_im_gt_list,
-         cache_size=[],
-         cache_path="./cache",
-         cache_file_name="dataset.json",
-         cache_boost=False,
-         transform=None,
-     ):
-
-         self.cache_size = cache_size
-         self.cache_path = cache_path
-         self.cache_file_name = cache_file_name
-         self.cache_boost_name = ""
-
-         self.cache_boost = cache_boost
-         # self.ims_npy = None
-         # self.gts_npy = None
-
-         ## cache all the images and ground truth into a single pytorch tensor
-         self.ims_pt = None
-         self.gts_pt = None
-
-         ## we will cache the npy as well regardless of the cache_boost
-         # if(self.cache_boost):
-         self.cache_boost_name = cache_file_name.split(".json")[0]
-
-         self.transform = transform
-
-         self.dataset = {}
-
-         ## combine different datasets into one
-         dataset_names = []
-         dt_name_list = [] # dataset name per image
-         im_name_list = [] # image name
-         im_path_list = [] # im path
-         gt_path_list = [] # gt path
-         im_ext_list = [] # im ext
-         gt_ext_list = [] # gt ext
-         for i in range(0, len(name_im_gt_list)):
-             dataset_names.append(name_im_gt_list[i]["dataset_name"])
-             # dataset name repeated based on the number of images in this dataset
-             dt_name_list.extend(
-                 [
-                     name_im_gt_list[i]["dataset_name"]
-                     for x in name_im_gt_list[i]["im_path"]
-                 ]
-             )
-             im_name_list.extend(
-                 [
-                     x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0]
-                     for x in name_im_gt_list[i]["im_path"]
-                 ]
-             )
-             im_path_list.extend(name_im_gt_list[i]["im_path"])
-             gt_path_list.extend(name_im_gt_list[i]["gt_path"])
-             im_ext_list.extend(
-                 [name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]
-             )
-             gt_ext_list.extend(
-                 [name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]
-             )
-
-         self.dataset["data_name"] = dt_name_list
-         self.dataset["im_name"] = im_name_list
-         self.dataset["im_path"] = im_path_list
-         self.dataset["ori_im_path"] = deepcopy(im_path_list)
-         self.dataset["gt_path"] = gt_path_list
-         self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
-         self.dataset["im_shp"] = []
-         self.dataset["gt_shp"] = []
-         self.dataset["im_ext"] = im_ext_list
-         self.dataset["gt_ext"] = gt_ext_list
-
-         self.dataset["ims_pt_dir"] = ""
-         self.dataset["gts_pt_dir"] = ""
-
-         self.dataset = self.manage_cache(dataset_names)
-
-     def manage_cache(self, dataset_names):
-         if not os.path.exists(self.cache_path): # create the folder for cache
-             os.makedirs(self.cache_path)
-         cache_folder = os.path.join(
-             self.cache_path,
-             "_".join(dataset_names) + "_" + "x".join([str(x) for x in self.cache_size]),
-         )
-         if not os.path.exists(
-             cache_folder
-         ): # check if the cache files are there, if not then cache
-             return self.cache(cache_folder)
-         return self.load_cache(cache_folder)
-
-     def cache(self, cache_folder):
-         os.mkdir(cache_folder)
-         cached_dataset = deepcopy(self.dataset)
-
-         # ims_list = []
-         # gts_list = []
-         ims_pt_list = []
-         gts_pt_list = []
-         for i, im_path in tqdm(
-             enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])
-         ):
-
-             im_id = cached_dataset["im_name"][i]
-             print("im_path: ", im_path)
-             im = im_reader(im_path)
-             im, im_shp = im_preprocess(im, self.cache_size)
-             im_cache_file = os.path.join(
-                 cache_folder, self.dataset["data_name"][i] + "_" + im_id + "_im.pt"
-             )
-             torch.save(im, im_cache_file)
-
-             cached_dataset["im_path"][i] = im_cache_file
-             if self.cache_boost:
-                 ims_pt_list.append(torch.unsqueeze(im, 0))
-             # ims_list.append(im.cpu().data.numpy().astype(np.uint8))
-
-             gt = np.zeros(im.shape[0:2])
-             if len(self.dataset["gt_path"]) != 0:
-                 gt = im_reader(self.dataset["gt_path"][i])
-             gt, gt_shp = gt_preprocess(gt, self.cache_size)
-             gt_cache_file = os.path.join(
-                 cache_folder, self.dataset["data_name"][i] + "_" + im_id + "_gt.pt"
-             )
-             torch.save(gt, gt_cache_file)
-             if len(self.dataset["gt_path"]) > 0:
-                 cached_dataset["gt_path"][i] = gt_cache_file
-             else:
-                 cached_dataset["gt_path"].append(gt_cache_file)
-             if self.cache_boost:
-                 gts_pt_list.append(torch.unsqueeze(gt, 0))
-             # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
-
-             # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
-             # torch.save(gt_shp, shp_cache_file)
-             cached_dataset["im_shp"].append(im_shp)
-             # self.dataset["im_shp"].append(im_shp)
-
-             # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
-             # torch.save(gt_shp, shp_cache_file)
-             cached_dataset["gt_shp"].append(gt_shp)
-             # self.dataset["gt_shp"].append(gt_shp)
-
-         if self.cache_boost:
-             cached_dataset["ims_pt_dir"] = os.path.join(
-                 cache_folder, self.cache_boost_name + "_ims.pt"
-             )
-             cached_dataset["gts_pt_dir"] = os.path.join(
-                 cache_folder, self.cache_boost_name + "_gts.pt"
-             )
-             self.ims_pt = torch.cat(ims_pt_list, dim=0)
-             self.gts_pt = torch.cat(gts_pt_list, dim=0)
-             torch.save(torch.cat(ims_pt_list, dim=0), cached_dataset["ims_pt_dir"])
-             torch.save(torch.cat(gts_pt_list, dim=0), cached_dataset["gts_pt_dir"])
-
-         try:
-             json_file = open(os.path.join(cache_folder, self.cache_file_name), "w")
-             json.dump(cached_dataset, json_file)
-             json_file.close()
-         except Exception:
-             raise FileNotFoundError("Cannot create JSON")
-         return cached_dataset
-
-     def load_cache(self, cache_folder):
-         json_file = open(os.path.join(cache_folder, self.cache_file_name), "r")
-         dataset = json.load(json_file)
-         json_file.close()
-         ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
-         ## otherwise the pytorch tensor will be loaded
-         if self.cache_boost:
-             # self.ims_npy = np.load(dataset["ims_npy_dir"])
-             # self.gts_npy = np.load(dataset["gts_npy_dir"])
-             self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location="cpu")
-             self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location="cpu")
-         return dataset
-
-     def __len__(self):
-         return len(self.dataset["im_path"])
-
-     def __getitem__(self, idx):
-
-         im = None
-         gt = None
-         if self.cache_boost and self.ims_pt is not None:
-
-             # start = time.time()
-             im = self.ims_pt[idx] # .type(torch.float32)
-             gt = self.gts_pt[idx] # .type(torch.float32)
-             # print(idx, 'time for pt loading: ', time.time()-start)
-
-         else:
-             # import time
-             # start = time.time()
-             # print("tensor***")
-             im_pt_path = os.path.join(
-                 self.cache_path,
-                 os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]),
-             )
-             im = torch.load(im_pt_path) # (self.dataset["im_path"][idx])
-             gt_pt_path = os.path.join(
-                 self.cache_path,
-                 os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]),
-             )
-             gt = torch.load(gt_pt_path) # (self.dataset["gt_path"][idx])
-             # print(idx,'time for tensor loading: ', time.time()-start)
-
-         im_shp = self.dataset["im_shp"][idx]
-         # print("time for loading im and gt: ", time.time()-start)
-
-         # start_time = time.time()
-         im = torch.divide(im, 255.0)
-         gt = torch.divide(gt, 255.0)
-         # print(idx, 'time for normalize torch divide: ', time.time()-start_time)
-
-         sample = {
-             "imidx": torch.from_numpy(np.array(idx)),
-             "image": im,
-             "label": gt,
-             "shape": torch.from_numpy(np.array(im_shp)),
-         }
-
-         if self.transform:
-             sample = self.transform(sample)
-
-         return sample
ormbg/inference.py DELETED
@@ -1,110 +0,0 @@
- import os
- import torch
- import argparse
- import numpy as np
- from PIL import Image
- from skimage import io
- from models.ormbg import ORMBG
- import torch.nn.functional as F
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         description="Remove background from images using ORMBG model."
-     )
-     parser.add_argument(
-         "--image",
-         type=str,
-         default=os.path.join("examples", "image", "example01.jpeg"),
-         help="Path to the input image file.",
-     )
-     parser.add_argument(
-         "--output",
-         type=str,
-         default="example01_no_background.png",
-         help="Path to the output image file.",
-     )
-     parser.add_argument(
-         "--model-path",
-         type=str,
-         default=os.path.join("models", "ormbg.pth"),
-         help="Path to the model file.",
-     )
-     parser.add_argument(
-         "--compare",
-         action="store_true",
-         help="Flag to save the original and processed images side by side.",
-     )
-     return parser.parse_args()
-
-
- def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
-     if len(im.shape) < 3:
-         im = im[:, :, np.newaxis]
-     im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
-     im_tensor = F.interpolate(
-         torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
-     ).type(torch.uint8)
-     image = torch.divide(im_tensor, 255.0)
-     return image
-
-
- def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
-     result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
-     ma = torch.max(result)
-     mi = torch.min(result)
-     result = (result - mi) / (ma - mi)
-     im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
-     im_array = np.squeeze(im_array)
-     return im_array
-
-
- def inference(args):
-     image_path = args.image
-     result_name = args.output
-     model_path = args.model_path
-     compare = args.compare
-
-     net = ORMBG()
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     if torch.cuda.is_available():
-         net.load_state_dict(torch.load(model_path))
-         net = net.cuda()
-     else:
-         net.load_state_dict(torch.load(model_path, map_location="cpu"))
-     net.eval()
-
-     model_input_size = [1024, 1024]
-     orig_im = io.imread(image_path)
-     orig_im_size = orig_im.shape[0:2]
-     image = preprocess_image(orig_im, model_input_size).to(device)
-
-     with torch.no_grad():
-         result = net(image)
-
-     # post process
-     result_image = postprocess_image(result[0][0], orig_im_size)
-
-     # save result
-     pil_im = Image.fromarray(result_image)
-
-     if pil_im.mode == "RGBA":
-         pil_im = pil_im.convert("RGB")
-
-     no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
-     orig_image = Image.open(image_path)
-     no_bg_image.paste(orig_image, mask=pil_im)
-
-     if compare:
-         combined_width = orig_image.width + no_bg_image.width
-         combined_image = Image.new("RGBA", (combined_width, orig_image.height))
-         combined_image.paste(orig_image, (0, 0))
-         combined_image.paste(no_bg_image, (orig_image.width, 0))
-         stacked_output_path = os.path.splitext(result_name)[0] + ".png"
-         combined_image.save(stacked_output_path)
-     else:
-         no_bg_image.save(result_name)
-
-
- if __name__ == "__main__":
-     inference(parse_args())
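
A sketch of how this removed script was presumably invoked, and a shape round-trip for its pre/post-processing pair. The working directory and checkpoint path are assumptions; the last block assumes preprocess_image and postprocess_image are in scope (e.g., pasted into a REPL), since the script's own `from models.ormbg import ORMBG` only resolves from inside the ormbg/ directory:

# Assumed CLI usage, run from inside ormbg/ with a checkpoint at models/ormbg.pth:
#
#   python inference.py --image ../examples/image/example01.jpeg \
#       --output example01_no_background.png --compare
#
import numpy as np
import torch

rgb = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
x = preprocess_image(rgb, [1024, 1024])  # -> torch.Size([1, 3, 1024, 1024]), values in [0, 1]
mask = postprocess_image(torch.rand(1, 1, 1024, 1024), rgb.shape[:2])  # -> (480, 640) uint8
print(x.shape, mask.shape, mask.dtype)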
 
ormbg/models/ormbg.py DELETED
@@ -1,484 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
-
-
- class REBNCONV(nn.Module):
-     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
-         super(REBNCONV, self).__init__()
-
-         self.conv_s1 = nn.Conv2d(
-             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
-         )
-         self.bn_s1 = nn.BatchNorm2d(out_ch)
-         self.relu_s1 = nn.ReLU(inplace=True)
-
-     def forward(self, x):
-         hx = x
-         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
-         return xout
-
-
- ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
- def _upsample_like(src, tar):
-     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
-     return src
-
-
- ### RSU-7 ###
- class RSU7(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
-         super(RSU7, self).__init__()
-
-         self.in_ch = in_ch
-         self.mid_ch = mid_ch
-         self.out_ch = out_ch
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         b, c, h, w = x.shape
-
-         hx = x
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-         hx = self.pool4(hx4)
-
-         hx5 = self.rebnconv5(hx)
-         hx = self.pool5(hx5)
-
-         hx6 = self.rebnconv6(hx)
-
-         hx7 = self.rebnconv7(hx6)
-
-         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
-         hx6dup = _upsample_like(hx6d, hx5)
-
-         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-6 ###
- class RSU6(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU6, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-         hx = self.pool4(hx4)
-
-         hx5 = self.rebnconv5(hx)
-
-         hx6 = self.rebnconv6(hx5)
-
-         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-5 ###
- class RSU5(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU5, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-         hx = self.pool3(hx3)
-
-         hx4 = self.rebnconv4(hx)
-
-         hx5 = self.rebnconv5(hx4)
-
-         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-4 ###
- class RSU4(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU4, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx = self.pool1(hx1)
-
-         hx2 = self.rebnconv2(hx)
-         hx = self.pool2(hx2)
-
-         hx3 = self.rebnconv3(hx)
-
-         hx4 = self.rebnconv4(hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-         return hx1d + hxin
-
-
- ### RSU-4F ###
- class RSU4F(nn.Module):
-
-     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-         super(RSU4F, self).__init__()
-
-         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
-         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
-
-         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
-
-         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
-         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
-         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-     def forward(self, x):
-         hx = x
-
-         hxin = self.rebnconvin(hx)
-
-         hx1 = self.rebnconv1(hxin)
-         hx2 = self.rebnconv2(hx1)
-         hx3 = self.rebnconv3(hx2)
-
-         hx4 = self.rebnconv4(hx3)
-
-         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
-         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
-
-         return hx1d + hxin
-
-
- class myrebnconv(nn.Module):
-     def __init__(
-         self,
-         in_ch=3,
-         out_ch=1,
-         kernel_size=3,
-         stride=1,
-         padding=1,
-         dilation=1,
-         groups=1,
-     ):
-         super(myrebnconv, self).__init__()
-
-         self.conv = nn.Conv2d(
-             in_ch,
-             out_ch,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             dilation=dilation,
-             groups=groups,
-         )
-         self.bn = nn.BatchNorm2d(out_ch)
-         self.rl = nn.ReLU(inplace=True)
-
-     def forward(self, x):
-         return self.rl(self.bn(self.conv(x)))
-
-
- bce_loss = nn.BCELoss(reduction="mean")
-
-
- class ORMBG(nn.Module):
-
-     def __init__(self, in_ch=3, out_ch=1):
-         super(ORMBG, self).__init__()
-
-         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
-         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage1 = RSU7(64, 32, 64)
-         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage2 = RSU6(64, 32, 128)
-         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage3 = RSU5(128, 64, 256)
-         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage4 = RSU4(256, 128, 512)
-         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage5 = RSU4F(512, 256, 512)
-         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-         self.stage6 = RSU4F(512, 256, 512)
-
-         # decoder
-         self.stage5d = RSU4F(1024, 256, 512)
-         self.stage4d = RSU4(1024, 128, 256)
-         self.stage3d = RSU5(512, 64, 128)
-         self.stage2d = RSU6(256, 32, 64)
-         self.stage1d = RSU7(128, 16, 64)
-
-         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
-         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
-         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
-         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
-         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
-         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
-
-         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
-
-     def compute_loss(self, predictions, ground_truth):
-         loss0, loss = 0.0, 0.0
-         for i in range(0, len(predictions)):
-             loss = loss + bce_loss(predictions[i], ground_truth)
-             if i == 0:
-                 loss0 = loss
-         return loss0, loss
-
-     def forward(self, x):
-         hx = x
-
-         hxin = self.conv_in(hx)
-         # hx = self.pool_in(hxin)
-
-         # stage 1
-         hx1 = self.stage1(hxin)
-         hx = self.pool12(hx1)
-
-         # stage 2
-         hx2 = self.stage2(hx)
-         hx = self.pool23(hx2)
-
-         # stage 3
-         hx3 = self.stage3(hx)
-         hx = self.pool34(hx3)
-
-         # stage 4
-         hx4 = self.stage4(hx)
-         hx = self.pool45(hx4)
-
-         # stage 5
-         hx5 = self.stage5(hx)
-         hx = self.pool56(hx5)
-
-         # stage 6
-         hx6 = self.stage6(hx)
-         hx6up = _upsample_like(hx6, hx5)
-
-         # -------------------- decoder --------------------
-         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
-         hx5dup = _upsample_like(hx5d, hx4)
-
-         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
-         hx4dup = _upsample_like(hx4d, hx3)
-
-         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
-         hx3dup = _upsample_like(hx3d, hx2)
-
-         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
-         hx2dup = _upsample_like(hx2d, hx1)
-
-         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
-
-         # side output
-         d1 = self.side1(hx1d)
-         d1 = _upsample_like(d1, x)
-
-         d2 = self.side2(hx2d)
-         d2 = _upsample_like(d2, x)
-
-         d3 = self.side3(hx3d)
-         d3 = _upsample_like(d3, x)
-
-         d4 = self.side4(hx4d)
-         d4 = _upsample_like(d4, x)
-
-         d5 = self.side5(hx5d)
-         d5 = _upsample_like(d5, x)
-
-         d6 = self.side6(hx6)
-         d6 = _upsample_like(d6, x)
-
-         return [
-             torch.sigmoid(d1),
-             torch.sigmoid(d2),
-             torch.sigmoid(d3),
-             torch.sigmoid(d4),
-             torch.sigmoid(d5),
-             torch.sigmoid(d6),
-         ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
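
A quick shape check for the architecture above, using the same import path as utils/architecture.py later in this diff. The 1024x1024 input matches the training cache size; smaller inputs also run, since every decoder upsample is matched to its skip connection's size:

import torch
from ormbg.models.ormbg import ORMBG

net = ORMBG().eval()
with torch.no_grad():
    side_outputs, decoder_features = net(torch.rand(1, 3, 1024, 1024))
print(len(side_outputs))      # 6 sigmoid maps, finest (side1) first
print(side_outputs[0].shape)  # torch.Size([1, 1, 1024, 1024]), upsampled to the input size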
 
ormbg/train_model.py DELETED
@@ -1,474 +0,0 @@
- import os
- import time
-
- import torch, gc
- import torch.nn as nn
- import torch.optim as optim
- from torch.autograd import Variable
- import torch.nn.functional as F
-
- import numpy as np
-
- from pathlib import Path
-
- from models.ormbg import ORMBG
-
- from skimage import io
-
- from basics import f1_mae_torch
-
- from data_loader_cache import (
-     get_im_gt_name_dict,
-     create_dataloaders,
-     GOSGridDropout,
-     GOSRandomHFlip,
- )
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
- def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
-     net.eval()
-     print("Validating...")
-     epoch_num = hypar["max_epoch_num"]
-
-     val_loss = 0.0
-     tar_loss = 0.0
-     val_cnt = 0.0
-
-     tmp_f1 = []
-     tmp_mae = []
-     tmp_time = []
-
-     start_valid = time.time()
-
-     for k in range(len(valid_dataloaders)):
-         valid_dataloader = valid_dataloaders[k]
-         valid_dataset = valid_datasets[k]
-
-         val_num = valid_dataset.__len__()
-         mybins = np.arange(0, 256)
-         PRE = np.zeros((val_num, len(mybins) - 1))
-         REC = np.zeros((val_num, len(mybins) - 1))
-         F1 = np.zeros((val_num, len(mybins) - 1))
-         MAE = np.zeros((val_num))
-
-         for i_val, data_val in enumerate(valid_dataloader):
-             val_cnt = val_cnt + 1.0
-             imidx_val, inputs_val, labels_val, shapes_val = (
-                 data_val["imidx"],
-                 data_val["image"],
-                 data_val["label"],
-                 data_val["shape"],
-             )
-
-             if hypar["model_digit"] == "full":
-                 inputs_val = inputs_val.type(torch.FloatTensor)
-                 labels_val = labels_val.type(torch.FloatTensor)
-             else:
-                 inputs_val = inputs_val.type(torch.HalfTensor)
-                 labels_val = labels_val.type(torch.HalfTensor)
-
-             # wrap them in Variable
-             if torch.cuda.is_available():
-                 inputs_val_v, labels_val_v = Variable(
-                     inputs_val.cuda(), requires_grad=False
-                 ), Variable(labels_val.cuda(), requires_grad=False)
-             else:
-                 inputs_val_v, labels_val_v = Variable(
-                     inputs_val, requires_grad=False
-                 ), Variable(labels_val, requires_grad=False)
-
-             t_start = time.time()
-             ds_val = net(inputs_val_v)[0]
-             t_end = time.time() - t_start
-             tmp_time.append(t_end)
-
-             # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
-             loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
-
-             # compute F measure
-             # note: the loop below deletes ds_val on its first pass, so it assumes batch_size_valid == 1
-             for t in range(hypar["batch_size_valid"]):
-                 i_test = imidx_val[t].data.numpy()
-
-                 pred_val = ds_val[0][t, :, :, :]  # B x 1 x H x W
-
-                 ## recover the prediction spatial size to the original image size
-                 pred_val = torch.squeeze(
-                     F.interpolate(
-                         torch.unsqueeze(pred_val, 0),
-                         (shapes_val[t][0], shapes_val[t][1]),
-                         mode="bilinear",
-                     )
-                 )
-
-                 # pred_val = normPRED(pred_val)
-                 ma = torch.max(pred_val)
-                 mi = torch.min(pred_val)
-                 pred_val = (pred_val - mi) / (ma - mi)  # max = 1
-
-                 if len(valid_dataset.dataset["ori_gt_path"]) != 0:
-                     gt = np.squeeze(
-                         io.imread(valid_dataset.dataset["ori_gt_path"][i_test])
-                     )  # max = 255
-                     if gt.max() == 1:
-                         gt = gt * 255
-                 else:
-                     gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))
-                 with torch.no_grad():
-                     gt = torch.tensor(gt).to(device)
-
-                 pre, rec, f1, mae = f1_mae_torch(
-                     pred_val * 255, gt, valid_dataset, i_test, mybins, hypar
-                 )
-
-                 PRE[i_test, :] = pre
-                 REC[i_test, :] = rec
-                 F1[i_test, :] = f1
-                 MAE[i_test] = mae
-
-                 del ds_val, gt
-                 gc.collect()
-                 torch.cuda.empty_cache()
-
-             # if(loss_val.data[0]>1):
-             val_loss += loss_val.item()  # data[0]
-             tar_loss += loss2_val.item()  # data[0]
-
-             print(
-                 "[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"
-                 % (
-                     i_val,
-                     val_num,
-                     val_loss / (i_val + 1),
-                     tar_loss / (i_val + 1),
-                     np.amax(F1[i_test, :]),
-                     MAE[i_test],
-                     t_end,
-                 )
-             )
-
-             del loss2_val, loss_val
-
-         print("============================")
-         PRE_m = np.mean(PRE, 0)
-         REC_m = np.mean(REC, 0)
-         f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)
-
-         tmp_f1.append(np.amax(f1_m))
-         tmp_mae.append(np.mean(MAE))
-
-     return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
-
-
- def train(
-     net,
-     optimizer,
-     train_dataloaders,
-     train_datasets,
-     valid_dataloaders,
-     valid_datasets,
-     hypar,
- ):
-
-     model_path = hypar["model_path"]
-     model_save_fre = hypar["model_save_fre"]
-     max_ite = hypar["max_ite"]
-     batch_size_train = hypar["batch_size_train"]
-     batch_size_valid = hypar["batch_size_valid"]
-
-     if not os.path.exists(model_path):
-         os.mkdir(model_path)
-
-     ite_num = hypar["start_ite"]  # count the total iteration number
-     ite_num4val = 0
-     running_loss = 0.0  # count the total loss
-     running_tar_loss = 0.0  # count the target output loss
-     last_f1 = [0 for x in range(len(valid_dataloaders))]
-
-     train_num = train_datasets[0].__len__()
-
-     net.train()
-
-     start_last = time.time()
-     gos_dataloader = train_dataloaders[0]
-     epoch_num = hypar["max_epoch_num"]
-     notgood_cnt = 0
-
-     for epoch in range(epoch_num):
-
-         for i, data in enumerate(gos_dataloader):
-
-             if ite_num >= max_ite:
-                 print("Training Reached the Maximal Iteration Number ", max_ite)
-                 exit()
-
-             # start_read = time.time()
-             ite_num = ite_num + 1
-             ite_num4val = ite_num4val + 1
-
-             # get the inputs
-             inputs, labels = data["image"], data["label"]
-
-             if hypar["model_digit"] == "full":
-                 inputs = inputs.type(torch.FloatTensor)
-                 labels = labels.type(torch.FloatTensor)
-             else:
-                 inputs = inputs.type(torch.HalfTensor)
-                 labels = labels.type(torch.HalfTensor)
-
-             # wrap them in Variable
-             if torch.cuda.is_available():
-                 inputs_v, labels_v = Variable(
-                     inputs.cuda(), requires_grad=False
-                 ), Variable(labels.cuda(), requires_grad=False)
-             else:
-                 inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(
-                     labels, requires_grad=False
-                 )
-
-             # zero the parameter gradients
-             start_inf_loss_back = time.time()
-             optimizer.zero_grad()
-
-             ds, _ = net(inputs_v)
-             loss2, loss = net.compute_loss(ds, labels_v)
-
-             loss.backward()
-             optimizer.step()
-
-             # print statistics
-             running_loss += loss.item()
-             running_tar_loss += loss2.item()
-
-             # del outputs, loss
-             del ds, loss2, loss
-             end_inf_loss_back = time.time() - start_inf_loss_back
-
-             print(
-                 ">>>"
-                 + model_path.split("/")[-1]
-                 + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f"
-                 % (
-                     epoch + 1,
-                     epoch_num,
-                     (i + 1) * batch_size_train,
-                     train_num,
-                     ite_num,
-                     running_loss / ite_num4val,
-                     running_tar_loss / ite_num4val,
-                     time.time() - start_last,
-                     time.time() - start_last - end_inf_loss_back,
-                 )
-             )
-             start_last = time.time()
-
-             if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
-                 notgood_cnt += 1
-                 net.eval()
-                 tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(
-                     net, valid_dataloaders, valid_datasets, hypar, epoch
-                 )
-                 net.train()  # resume training
-
-                 tmp_out = 0
-                 print("last_f1:", last_f1)
-                 print("tmp_f1:", tmp_f1)
-                 for fi in range(len(last_f1)):
-                     if tmp_f1[fi] > last_f1[fi]:
-                         tmp_out = 1
-                 print("tmp_out:", tmp_out)
-                 if tmp_out:
-                     notgood_cnt = 0
-                     last_f1 = tmp_f1
-                     tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
-                     tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
-                     maxf1 = "_".join(tmp_f1_str)
-                     meanM = "_".join(tmp_mae_str)
-                     # .cpu().detach().numpy()
-                     model_name = (
-                         "/gpu_itr_"
-                         + str(ite_num)
-                         + "_traLoss_"
-                         + str(np.round(running_loss / ite_num4val, 4))
-                         + "_traTarLoss_"
-                         + str(np.round(running_tar_loss / ite_num4val, 4))
-                         + "_valLoss_"
-                         + str(np.round(val_loss / (i_val + 1), 4))
-                         + "_valTarLoss_"
-                         + str(np.round(tar_loss / (i_val + 1), 4))
-                         + "_maxF1_"
-                         + maxf1
-                         + "_mae_"
-                         + meanM
-                         + "_time_"
-                         + str(
-                             np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)
-                         )
-                         + ".pth"
-                     )
-                     torch.save(net.state_dict(), model_path + model_name)
-
-                 running_loss = 0.0
-                 running_tar_loss = 0.0
-                 ite_num4val = 0
-
-                 if notgood_cnt >= hypar["early_stop"]:
-                     print(
-                         "No improvements in the last "
-                         + str(notgood_cnt)
-                         + " validation periods, so training stopped!"
-                     )
-                     exit()
-
-     print("Training Reaches The Maximum Epoch Number")
-
-
- def main(train_datasets, valid_datasets, hypar):
-
-     print("--- create training dataloader ---")
-
-     train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
-     ## build dataloader for training datasets
-     train_dataloaders, train_datasets = create_dataloaders(
-         train_nm_im_gt_list,
-         cache_size=hypar["cache_size"],
-         cache_boost=hypar["cache_boost_train"],
-         my_transforms=[GOSGridDropout(), GOSRandomHFlip()],
-         batch_size=hypar["batch_size_train"],
-         shuffle=True,
-     )
-
-     valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
-
-     valid_dataloaders, valid_datasets = create_dataloaders(
-         valid_nm_im_gt_list,
-         cache_size=hypar["cache_size"],
-         cache_boost=hypar["cache_boost_valid"],
-         my_transforms=[],
-         batch_size=hypar["batch_size_valid"],
-         shuffle=False,
-     )
-
-     net = hypar["model"]
-
-     if hypar["model_digit"] == "half":
-         net.half()
-         for layer in net.modules():
-             if isinstance(layer, nn.BatchNorm2d):
-                 layer.float()
-
-     if torch.cuda.is_available():
-         net.cuda()
-
-     if hypar["restore_model"] != "":
-         print("restore model from:")
-         print(hypar["model_path"] + "/" + hypar["restore_model"])
-         if torch.cuda.is_available():
-             net.load_state_dict(
-                 torch.load(hypar["model_path"] + "/" + hypar["restore_model"])
-             )
-         else:
-             net.load_state_dict(
-                 torch.load(
-                     hypar["model_path"] + "/" + hypar["restore_model"],
-                     map_location="cpu",
-                 )
-             )
-
-     optimizer = optim.Adam(
-         net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
-     )
-
-     train(
-         net,
-         optimizer,
-         train_dataloaders,
-         train_datasets,
-         valid_dataloaders,
-         valid_datasets,
-         hypar,
-     )
-
-
- if __name__ == "__main__":
-
-     output_model_folder = "saved_models"
-     Path(output_model_folder).mkdir(parents=True, exist_ok=True)
-
-     train_datasets, valid_datasets = [], []
-
-     dataset_training = {
-         "name": "ormbg-training",
-         "im_dir": str(Path("dataset", "training", "im")),
-         "gt_dir": str(Path("dataset", "training", "gt")),
-         "im_ext": ".png",
-         "gt_ext": ".png",
-         "cache_dir": str(Path("cache", "teacher", "training")),
-     }
-
-     dataset_validation = {
-         "name": "ormbg-validation",
-         "im_dir": str(Path("dataset", "validation", "im")),
-         "gt_dir": str(Path("dataset", "validation", "gt")),
-         "im_ext": ".png",
-         "gt_ext": ".png",
-         "cache_dir": str(Path("cache", "teacher", "validation")),
-     }
-
-     train_datasets = [dataset_training]
-     valid_datasets = [dataset_validation]
-
-     ### --------------- Configure the hyperparameters for training, validation and inference ---------------
-     hypar = {}
-
-     hypar["model"] = ORMBG()
-     hypar["seed"] = 0
-
-     ## model weights path
-     hypar["model_path"] = "saved_models"
-
-     ## name of the segmentation model weights .pth for resuming the training process from the last stop, or for inference
-     hypar["restore_model"] = ""
-
-     ## start iteration for the training; can be changed to match a restored training process
-     hypar["start_ite"] = 0
-
-     ## indicates "half" or "full" accuracy of float number
-     hypar["model_digit"] = "full"
-
-     ## To handle large input images, which take a long time to load during training,
-     ## we introduce a cache mechanism for pre-converting and resizing the jpg and png images into .pt files.
-     ## Cached input spatial resolution; can be configured to a different size.
-     hypar["cache_size"] = [
-         1024,
-         1024,
-     ]
-
-     ## "True" or "False", indicates whether to load all the training datasets into RAM; True greatly speeds up training but requires more RAM
-     hypar["cache_boost_train"] = False
-
-     ## "True" or "False", indicates whether to load all the validation datasets into RAM; True greatly speeds up training but requires more RAM
-     hypar["cache_boost_valid"] = False
-
-     ## stop the training when there is no improvement in the past 20 validation periods; smaller numbers can be used here, e.g., 5 or 10
-     hypar["early_stop"] = 20
-
-     ## validate and save model weights every 2000 iterations
-     hypar["model_save_fre"] = 2000
-
-     ## batch size for training
-     hypar["batch_size_train"] = 8
-
-     ## batch size for validation and inference
-     hypar["batch_size_valid"] = 1
-
-     ## if early stop cannot stop the training process, stop it at max_ite iterations
-     hypar["max_ite"] = 10000000
-
-     ## if early stop and max_ite cannot stop the training process, stop it at max_epoch_num epochs
-     hypar["max_epoch_num"] = 1000000
-
-     main(train_datasets, valid_datasets, hypar=hypar)
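
A sketch of a reduced "smoke test" configuration for this training entry point. The keys mirror the hypar dict above; every value here is an assumption chosen only to make a short dry run cheap, and the import resolves when run from inside the ormbg/ directory, as with the script itself:

from models.ormbg import ORMBG

smoke_hypar = {
    "model": ORMBG(),
    "seed": 0,
    "model_path": "saved_models",
    "restore_model": "",
    "start_ite": 0,
    "model_digit": "full",
    "cache_size": [512, 512],   # smaller cache -> faster pre-conversion
    "cache_boost_train": False,
    "cache_boost_valid": False,
    "early_stop": 2,
    "model_save_fre": 50,       # validate and checkpoint every 50 iterations
    "batch_size_train": 2,
    "batch_size_valid": 1,
    "max_ite": 200,             # hard stop after 200 iterations
    "max_epoch_num": 1,
}
# main(train_datasets, valid_datasets, hypar=smoke_hypar)  # datasets as defined above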
 
stack.py DELETED
@@ -1,37 +0,0 @@
- from PIL import Image
-
-
- def stack_images(image_paths, output_path):
-     # Load all images from the provided paths
-     images = [Image.open(path) for path in image_paths]
-
-     # Determine the size of individual images (assuming all are the same size)
-     width, height = images[0].size
-
-     # Create a new image with appropriate size (2 columns and 3 rows)
-     total_width = width * 2
-     total_height = height * 3
-     new_image = Image.new("RGB", (total_width, total_height))
-
-     # Paste each image into the new image
-     for i, image in enumerate(images):
-         # Calculate the position for each image
-         x_offset = (i % 2) * width
-         y_offset = (i // 2) * height
-         new_image.paste(image, (x_offset, y_offset))
-
-     # Save the new image
-     new_image.save(output_path)
-
-
- # Example usage
- image_paths = [
-     "/Users/mav/Desktop/example1.png",
-     "/Users/mav/Desktop/image-1.webp",
-     "/Users/mav/Desktop/example2.png",
-     "/Users/mav/Desktop/image-2.webp",
-     "/Users/mav/Desktop/example3.png",
-     "/Users/mav/Desktop/image-3.webp",
- ]
- output_path = "stacked_images.jpg"
- stack_images(image_paths, output_path)
 
utils/.DS_Store DELETED
Binary file (6.15 kB)
 
utils/architecture.py DELETED
@@ -1,4 +0,0 @@
- from ormbg.models.ormbg import ORMBG
-
- if __name__ == "__main__":
-     print(ORMBG())
 
utils/loss_example.py DELETED
@@ -1,69 +0,0 @@
- import os
- import torch
- import argparse
- import numpy as np
- from skimage import io
- from ormbg.models.ormbg import ORMBG
- import torch.nn.functional as F
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         description="Compute example losses of predicted masks against a ground truth mask."
-     )
-     parser.add_argument(
-         "--prediction",
-         type=str,
-         nargs="+",
-         default=[
-             os.path.join("examples", "loss", "loss01.png"),
-             os.path.join("examples", "loss", "loss02.png"),
-             os.path.join("examples", "loss", "loss03.png"),
-             os.path.join("examples", "loss", "loss04.png"),
-             os.path.join("examples", "loss", "loss05.png"),
-         ],
-         help="Paths to the predicted mask image files.",
-     )
-     parser.add_argument(
-         "--gt",
-         type=str,
-         default=os.path.join("examples", "loss", "gt.png"),
-         help="Ground truth mask",
-     )
-     return parser.parse_args()
-
-
- def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
-     if len(im.shape) < 3:
-         im = im[:, :, np.newaxis]
-     im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
-     im_tensor = F.interpolate(
-         torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
-     ).type(torch.uint8)
-     image = torch.divide(im_tensor, 255.0)
-     return image
-
-
- def inference(args):
-     prediction_paths = args.prediction
-     gt_path = args.gt
-
-     net = ORMBG()
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     model_input_size = [1024, 1024]
-     gt = io.imread(gt_path)
-     ground_truth = preprocess_image(gt, model_input_size).to(device)
-
-     for pred_path in prediction_paths:
-         pred_im = io.imread(pred_path)
-         prediction = preprocess_image(pred_im, model_input_size).to(device)
-
-         _, loss = net.compute_loss([prediction], ground_truth)
-
-         print(f"Loss: {pred_path} {loss}")
-
-
- if __name__ == "__main__":
-     inference(parse_args())
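
compute_loss sums the BCE of every side output against the same ground truth and also returns the first output's loss separately; a self-contained sketch with synthetic masks, no image files needed (the clamp only keeps BCE away from log(0)):

import torch
from ormbg.models.ormbg import ORMBG  # same import path as the script above

net = ORMBG()
gt = torch.rand(1, 1, 64, 64)  # synthetic ground-truth mask in [0, 1]
preds = [torch.rand(1, 1, 64, 64).clamp(1e-4, 1 - 1e-4) for _ in range(6)]  # stand-ins for the 6 side outputs
loss_first, loss_total = net.compute_loss(preds, gt)
print(loss_first.item(), loss_total.item())  # first-output loss and summed loss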
 
utils/pth_to_onnx.py DELETED
@@ -1,59 +0,0 @@
- import torch
- import argparse
- from ormbg.models.ormbg import ORMBG
-
-
- def export_to_onnx(model_path, onnx_path):
-
-     net = ORMBG()
-
-     if torch.cuda.is_available():
-         net.load_state_dict(torch.load(model_path))
-         net = net.cuda()
-     else:
-         net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
-
-     net.eval()
-
-     # Create a dummy input tensor. The size should match the model's input size.
-     # Adjust the dimensions as necessary; here it is assumed the input is a 3-channel image.
-     dummy_input = torch.randn(
-         1,
-         3,
-         1024,
-         1024,
-         device="cuda" if torch.cuda.is_available() else "cpu",
-     )
-
-     torch.onnx.export(
-         net,
-         dummy_input,
-         onnx_path,
-         export_params=True,
-         opset_version=11,
-         do_constant_folding=True,
-         input_names=["input"],
-         output_names=["output"],
-     )
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         description="Export a trained model to ONNX format."
-     )
-     parser.add_argument(
-         "--model_path",
-         type=str,
-         default="models/ormbg.pth",
-         help="The path to the trained model file.",
-     )
-     parser.add_argument(
-         "--onnx_path",
-         type=str,
-         default="models/ormbg.onnx",
-         help="The path where the ONNX model will be saved.",
-     )
-
-     args = parser.parse_args()
-
-     export_to_onnx(args.model_path, args.onnx_path)
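
A sketch for sanity-checking the exported graph with onnxruntime, an extra dependency not used elsewhere in this repository; the output path matches the script's default, and the "input" key matches the input_names passed to torch.onnx.export above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("models/ormbg.onnx")
x = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
outputs = sess.run(None, {"input": x})  # None -> fetch all exported outputs
print(outputs[0].shape)  # first side output, expected (1, 1, 1024, 1024)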