hyliu commited on
Commit
8ec10cf
·
verified ·
1 Parent(s): 86fb2db

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +68 -0
  2. backup/bin/psrun2 +9 -0
  3. backup/bin/psrun4 +9 -0
  4. backup/deblur/.gitignore +109 -0
  5. backup/deblur/DeepDeblur-PyTorch/.gitignore +109 -0
  6. backup/deblur/DeepDeblur-PyTorch/LICENSE +21 -0
  7. backup/deblur/DeepDeblur-PyTorch/README.md +216 -0
  8. backup/deblur/DeepDeblur-PyTorch/experiment/.gitignore +1 -0
  9. backup/deblur/DeepDeblur-PyTorch/src/__pycache__/option.cpython-37.pyc +0 -0
  10. backup/deblur/DeepDeblur-PyTorch/src/__pycache__/template.cpython-37.pyc +0 -0
  11. backup/deblur/DeepDeblur-PyTorch/src/__pycache__/train.cpython-37.pyc +0 -0
  12. backup/deblur/DeepDeblur-PyTorch/src/__pycache__/utils.cpython-37.pyc +0 -0
  13. backup/deblur/DeepDeblur-PyTorch/src/data/__init__.py +79 -0
  14. backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/__init__.cpython-37.pyc +0 -0
  15. backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/common.cpython-37.pyc +0 -0
  16. backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/dataset.cpython-37.pyc +0 -0
  17. backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/gopro_large.cpython-37.pyc +0 -0
  18. backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/sampler.cpython-37.pyc +0 -0
  19. backup/deblur/DeepDeblur-PyTorch/src/data/common.py +163 -0
  20. backup/deblur/DeepDeblur-PyTorch/src/data/dataset.py +154 -0
  21. backup/deblur/DeepDeblur-PyTorch/src/data/demo.py +22 -0
  22. backup/deblur/DeepDeblur-PyTorch/src/data/gopro_large.py +23 -0
  23. backup/deblur/DeepDeblur-PyTorch/src/data/reds.py +28 -0
  24. backup/deblur/DeepDeblur-PyTorch/src/data/sampler.py +115 -0
  25. backup/deblur/DeepDeblur-PyTorch/src/launch.py +55 -0
  26. backup/deblur/DeepDeblur-PyTorch/src/loss/__init__.py +464 -0
  27. backup/deblur/DeepDeblur-PyTorch/src/loss/__pycache__/__init__.cpython-37.pyc +0 -0
  28. backup/deblur/DeepDeblur-PyTorch/src/loss/__pycache__/metric.cpython-37.pyc +0 -0
  29. backup/deblur/DeepDeblur-PyTorch/src/loss/adversarial.py +52 -0
  30. backup/deblur/DeepDeblur-PyTorch/src/loss/metric.py +112 -0
  31. backup/deblur/DeepDeblur-PyTorch/src/main.py +67 -0
  32. backup/deblur/DeepDeblur-PyTorch/src/model/MSResNet.py +67 -0
  33. backup/deblur/DeepDeblur-PyTorch/src/model/ResNet.py +41 -0
  34. backup/deblur/DeepDeblur-PyTorch/src/model/__init__.py +136 -0
  35. backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/MSResNet.cpython-37.pyc +0 -0
  36. backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/ResNet.cpython-37.pyc +0 -0
  37. backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/__init__.cpython-37.pyc +0 -0
  38. backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/common.cpython-37.pyc +0 -0
  39. backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/discriminator.cpython-37.pyc +0 -0
  40. backup/deblur/DeepDeblur-PyTorch/src/model/common.py +161 -0
  41. backup/deblur/DeepDeblur-PyTorch/src/model/discriminator.py +41 -0
  42. backup/deblur/DeepDeblur-PyTorch/src/model/structure.py +56 -0
  43. backup/deblur/DeepDeblur-PyTorch/src/optim/__init__.py +206 -0
  44. backup/deblur/DeepDeblur-PyTorch/src/optim/__pycache__/__init__.cpython-37.pyc +0 -0
  45. backup/deblur/DeepDeblur-PyTorch/src/optim/warm_multi_step_lr.py +32 -0
  46. backup/deblur/DeepDeblur-PyTorch/src/option.py +274 -0
  47. backup/deblur/DeepDeblur-PyTorch/src/prepare.sh +16 -0
  48. backup/deblur/DeepDeblur-PyTorch/src/template.py +9 -0
  49. backup/deblur/DeepDeblur-PyTorch/src/train.py +225 -0
  50. backup/deblur/DeepDeblur-PyTorch/src/utils.py +200 -0
.gitattributes CHANGED
@@ -33,3 +33,71 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_00/blur_gamma/000031.png filter=lfs diff=lfs merge=lfs -text
37
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_00/blur_gamma/000041.png filter=lfs diff=lfs merge=lfs -text
38
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_00/blur_gamma/000051.png filter=lfs diff=lfs merge=lfs -text
39
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_00/blur_gamma/000071.png filter=lfs diff=lfs merge=lfs -text
40
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_00/blur_gamma/000091.png filter=lfs diff=lfs merge=lfs -text
41
+ backup/deblur/experiment/LamRes_L1/result/GOPR0384_11_05/blur_gamma/004091.png filter=lfs diff=lfs merge=lfs -text
42
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000001.png filter=lfs diff=lfs merge=lfs -text
43
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000011.png filter=lfs diff=lfs merge=lfs -text
44
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000021.png filter=lfs diff=lfs merge=lfs -text
45
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000031.png filter=lfs diff=lfs merge=lfs -text
46
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000041.png filter=lfs diff=lfs merge=lfs -text
47
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000051.png filter=lfs diff=lfs merge=lfs -text
48
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000061.png filter=lfs diff=lfs merge=lfs -text
49
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000071.png filter=lfs diff=lfs merge=lfs -text
50
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000081.png filter=lfs diff=lfs merge=lfs -text
51
+ backup/deblur/experiment/LamRes_L1/result/GOPR0396_11_00/blur_gamma/000091.png filter=lfs diff=lfs merge=lfs -text
52
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000007.png filter=lfs diff=lfs merge=lfs -text
53
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000017.png filter=lfs diff=lfs merge=lfs -text
54
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000037.png filter=lfs diff=lfs merge=lfs -text
55
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000047.png filter=lfs diff=lfs merge=lfs -text
56
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000057.png filter=lfs diff=lfs merge=lfs -text
57
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000067.png filter=lfs diff=lfs merge=lfs -text
58
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000077.png filter=lfs diff=lfs merge=lfs -text
59
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000087.png filter=lfs diff=lfs merge=lfs -text
60
+ backup/deblur/experiment/LamRes_L1/result/GOPR0854_11_00/blur_gamma/000097.png filter=lfs diff=lfs merge=lfs -text
61
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000007.png filter=lfs diff=lfs merge=lfs -text
62
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000017.png filter=lfs diff=lfs merge=lfs -text
63
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000027.png filter=lfs diff=lfs merge=lfs -text
64
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000037.png filter=lfs diff=lfs merge=lfs -text
65
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000047.png filter=lfs diff=lfs merge=lfs -text
66
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000057.png filter=lfs diff=lfs merge=lfs -text
67
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000067.png filter=lfs diff=lfs merge=lfs -text
68
+ backup/deblur/experiment/LamRes_L1/result/GOPR0862_11_00/blur_gamma/000077.png filter=lfs diff=lfs merge=lfs -text
69
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000010.png filter=lfs diff=lfs merge=lfs -text
70
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000020.png filter=lfs diff=lfs merge=lfs -text
71
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000040.png filter=lfs diff=lfs merge=lfs -text
72
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000050.png filter=lfs diff=lfs merge=lfs -text
73
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000060.png filter=lfs diff=lfs merge=lfs -text
74
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000070.png filter=lfs diff=lfs merge=lfs -text
75
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000080.png filter=lfs diff=lfs merge=lfs -text
76
+ backup/deblur/experiment/LamRes_L1/result/GOPR0868_11_00/blur_gamma/000100.png filter=lfs diff=lfs merge=lfs -text
77
+ backup/deblur/experiment/LamRes_L1/result/GOPR0869_11_00/blur_gamma/000040.png filter=lfs diff=lfs merge=lfs -text
78
+ backup/deblur/experiment/LamRes_L1/result/GOPR0869_11_00/blur_gamma/000050.png filter=lfs diff=lfs merge=lfs -text
79
+ backup/deblur/experiment/LamRes_L1/result/GOPR0869_11_00/blur_gamma/000060.png filter=lfs diff=lfs merge=lfs -text
80
+ backup/deblur/experiment/LamRes_L1/result/GOPR0869_11_00/blur_gamma/000070.png filter=lfs diff=lfs merge=lfs -text
81
+ backup/deblur/experiment/LamRes_L1/result/GOPR0869_11_00/blur_gamma/000100.png filter=lfs diff=lfs merge=lfs -text
82
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000020.png filter=lfs diff=lfs merge=lfs -text
83
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000030.png filter=lfs diff=lfs merge=lfs -text
84
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000040.png filter=lfs diff=lfs merge=lfs -text
85
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000050.png filter=lfs diff=lfs merge=lfs -text
86
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000060.png filter=lfs diff=lfs merge=lfs -text
87
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000070.png filter=lfs diff=lfs merge=lfs -text
88
+ backup/deblur/experiment/LamRes_L1/result/GOPR0871_11_00/blur_gamma/000080.png filter=lfs diff=lfs merge=lfs -text
89
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000001.png filter=lfs diff=lfs merge=lfs -text
90
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000011.png filter=lfs diff=lfs merge=lfs -text
91
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000021.png filter=lfs diff=lfs merge=lfs -text
92
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000031.png filter=lfs diff=lfs merge=lfs -text
93
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000041.png filter=lfs diff=lfs merge=lfs -text
94
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000051.png filter=lfs diff=lfs merge=lfs -text
95
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000061.png filter=lfs diff=lfs merge=lfs -text
96
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000071.png filter=lfs diff=lfs merge=lfs -text
97
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000081.png filter=lfs diff=lfs merge=lfs -text
98
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0396_11_00/blur_gamma/000091.png filter=lfs diff=lfs merge=lfs -text
99
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0854_11_00/blur_gamma/000067.png filter=lfs diff=lfs merge=lfs -text
100
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0862_11_00/blur_gamma/000047.png filter=lfs diff=lfs merge=lfs -text
101
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0862_11_00/blur_gamma/000077.png filter=lfs diff=lfs merge=lfs -text
102
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0868_11_00/blur_gamma/000010.png filter=lfs diff=lfs merge=lfs -text
103
+ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0868_11_00/blur_gamma/000020.png filter=lfs diff=lfs merge=lfs -text
backup/bin/psrun2 ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ while true; do
4
+ date
5
+ #srun -p gpu_24h --gres=gpu:2 -c 20 --exclude=gpu42 --pty bash
6
+ srun -p gpu_24h --gres=gpu:2 -c 20 --constraint="rtx2080|titanv" --pty bash
7
+ #echo 1
8
+ sleep 5
9
+ done
backup/bin/psrun4 ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ while true; do
4
+ date
5
+ #srun -p gpu_24h --gres=gpu:4 -c 40 --exclude=gpu42 --pty bash
6
+ srun -p gpu_24h --gres=gpu:4 -c 40 --constraint="rtx2080|titanv" --pty bash
7
+ #echo 1
8
+ sleep 15
9
+ done
backup/deblur/.gitignore ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Django stuff:
54
+ *.log
55
+ local_settings.py
56
+
57
+ # Flask stuff:
58
+ instance/
59
+ .webassets-cache
60
+
61
+ # Scrapy stuff:
62
+ .scrapy
63
+
64
+ # Sphinx documentation
65
+ docs/_build/
66
+
67
+ # PyBuilder
68
+ target/
69
+
70
+ # Jupyter Notebook
71
+ .ipynb_checkpoints
72
+
73
+ # pyenv
74
+ .python-version
75
+
76
+ # celery beat schedule file
77
+ celerybeat-schedule
78
+
79
+ # SageMath parsed files
80
+ *.sage.py
81
+
82
+ # dotenv
83
+ .env
84
+
85
+ # virtualenv
86
+ .venv
87
+ venv/
88
+ ENV/
89
+
90
+ # Spyder project settings
91
+ .spyderproject
92
+ .spyproject
93
+
94
+ # Rope project settings
95
+ .ropeproject
96
+
97
+ # mkdocs documentation
98
+ /site
99
+
100
+ # mypy
101
+ .mypy_cache/
102
+
103
+ experiment
104
+ *.vscode
105
+ *.pt
106
+
107
+ src/test_ssim.py
108
+
109
+ *.sh
backup/deblur/DeepDeblur-PyTorch/.gitignore ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Django stuff:
54
+ *.log
55
+ local_settings.py
56
+
57
+ # Flask stuff:
58
+ instance/
59
+ .webassets-cache
60
+
61
+ # Scrapy stuff:
62
+ .scrapy
63
+
64
+ # Sphinx documentation
65
+ docs/_build/
66
+
67
+ # PyBuilder
68
+ target/
69
+
70
+ # Jupyter Notebook
71
+ .ipynb_checkpoints
72
+
73
+ # pyenv
74
+ .python-version
75
+
76
+ # celery beat schedule file
77
+ celerybeat-schedule
78
+
79
+ # SageMath parsed files
80
+ *.sage.py
81
+
82
+ # dotenv
83
+ .env
84
+
85
+ # virtualenv
86
+ .venv
87
+ venv/
88
+ ENV/
89
+
90
+ # Spyder project settings
91
+ .spyderproject
92
+ .spyproject
93
+
94
+ # Rope project settings
95
+ .ropeproject
96
+
97
+ # mkdocs documentation
98
+ /site
99
+
100
+ # mypy
101
+ .mypy_cache/
102
+
103
+ experiment
104
+ *.vscode
105
+ *.pt
106
+
107
+ src/test_ssim.py
108
+
109
+ *.sh
backup/deblur/DeepDeblur-PyTorch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Seungjun Nah
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
backup/deblur/DeepDeblur-PyTorch/README.md ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepDeblur-PyTorch
2
+
3
+ This is a pytorch implementation of our research. Please refer to our CVPR 2017 paper for details:
4
+
5
+ Deep Multi-scale Convolutional Neural Network for Dynamic Scene Deblurring
6
+ [[paper](http://openaccess.thecvf.com/content_cvpr_2017/papers/Nah_Deep_Multi-Scale_Convolutional_CVPR_2017_paper.pdf)]
7
+ [[supplementary](http://openaccess.thecvf.com/content_cvpr_2017/supplemental/Nah_Deep_Multi-Scale_Convolutional_2017_CVPR_supplemental.zip)]
8
+ [[slide](https://drive.google.com/file/d/1sj7l2tGgJR-8wTyauvnSDGpiokjOzX_C/view?usp=sharing)]
9
+
10
+ If you find our work useful in your research or publication, please cite our work:
11
+ ```
12
+ @InProceedings{Nah_2017_CVPR,
13
+ author = {Nah, Seungjun and Kim, Tae Hyun and Lee, Kyoung Mu},
14
+ title = {Deep Multi-Scale Convolutional Neural Network for Dynamic Scene Deblurring},
15
+ booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
16
+ month = {July},
17
+ year = {2017}
18
+ }
19
+ ```
20
+
21
+ Original Torch7 implementaion is available [here](https://github.com/SeungjunNah/DeepDeblur_release).
22
+
23
+ ## Dependencies
24
+
25
+ * python 3 (tested with anaconda3)
26
+ * PyTorch 1.6
27
+ * tqdm
28
+ * imageio
29
+ * scikit-image
30
+ * numpy
31
+ * matplotlib
32
+ * readline
33
+
34
+ Please refer to [this issue](https://github.com/SeungjunNah/DeepDeblur-PyTorch/issues/5#issuecomment-651177352) for the versions.
35
+
36
+ ## Datasets
37
+
38
+ * GOPRO_Large: [link](https://seungjunnah.github.io/Datasets/gopro)
39
+ * REDS: [link](https://seungjunnah.github.io/Datasets/reds)
40
+
41
+ ## Usage examples
42
+
43
+ * Preparing dataset
44
+
45
+ Before running the code, put the datasets on a desired directory. By default, the data root is set as '~/Research/dataset'
46
+ See: [src/option.py](src/option.py)
47
+ ```python
48
+ group_data.add_argument('--data_root', type=str, default='~/Research/dataset', help='dataset root location')
49
+ ```
50
+ Put your dataset under ```args.data_root```.
51
+
52
+ The dataset location should be like:
53
+ ```bash
54
+ # GOPRO_Large dataset
55
+ ~/Research/dataset/GOPRO_Large/train/GOPR0372_07_00/blur_gamma/....
56
+ # REDS dataset
57
+ ~/Research/dataset/REDS/train/train_blur/000/...
58
+ ```
59
+
60
+ * Example commands
61
+
62
+ ```bash
63
+ # single GPU training
64
+ python main.py --n_GPUs 1 --batch_size 8 # save the results in default experiment/YYYY-MM-DD_hh-mm-ss
65
+ python main.py --n_GPUs 1 --batch_size 8 --save_dir GOPRO_L1 # save the results in experiment/GOPRO_L1
66
+
67
+ # adversarial training
68
+ python main.py --n_GPUs 1 --batch_size 8 --loss 1*L1+1*ADV
69
+ python main.py --n_GPUs 1 --batch_size 8 --loss 1*L1+3*ADV
70
+ python main.py --n_GPUs 1 --batch_size 8 --loss 1*L1+0.1*ADV
71
+
72
+ # train with GOPRO_Large dataset
73
+ python main.py --n_GPUs 1 --batch_size 8 --dataset GOPRO_Large
74
+ # train with REDS dataset (always set --do_test false)
75
+ python main.py --n_GPUs 1 --batch_size 8 --dataset REDS --do_test false --milestones 100 150 180 --end_epoch 200
76
+
77
+ # save part of the evaluation results (default)
78
+ python main.py --n_GPUs 1 --batch_size 8 --dataset GOPRO_Large --save_results part
79
+ # save no evaluation results (faster at test time)
80
+ python main.py --n_GPUs 1 --batch_size 8 --dataset GOPRO_Large --save_results none
81
+ # save all of the evaluation results
82
+ python main.py --n_GPUs 1 --batch_size 8 --dataset GOPRO_Large --save_results all
83
+ ```
84
+
85
+ ```bash
86
+ # multi-GPU training (DataParallel)
87
+ python main.py --n_GPUs 2 --batch_size 16
88
+ ```
89
+
90
+ ```bash
91
+ # multi-GPU training (DistributedDataParallel), recommended for the best speed
92
+ # single command version (do not set ranks)
93
+ python launch.py --n_GPUs 2 main.py --batch_size 16
94
+
95
+ # multi-command version (type in independent shells with the corresponding ranks, useful for debugging)
96
+ python main.py --batch_size 16 --distributed true --n_GPUs 2 --rank 0 # shell 0
97
+ python main.py --batch_size 16 --distributed true --n_GPUs 2 --rank 1 # shell 1
98
+ ```
99
+
100
+ ```bash
101
+ # single precision inference (default)
102
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --precision single
103
+
104
+ # half precision inference (faster and requires less memory)
105
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --precision half
106
+
107
+ # half precision inference with AMP
108
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true
109
+ ```
110
+
111
+ ```bash
112
+ # optional mixed-precision training
113
+ # mixed precision training may result in different accuracy
114
+ python main.py --n_GPUs 1 --batch_size 16 --amp true
115
+ python main.py --n_GPUs 2 --batch_size 16 --amp true
116
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true
117
+ ```
118
+
119
+ ```bash
120
+ # Advanced usage examples
121
+ # using launch.py is recommended for the best speed and convenience
122
+ python launch.py --n_GPUs 4 main.py --dataset GOPRO_Large
123
+ python launch.py --n_GPUs 4 main.py --dataset GOPRO_Large --milestones 500 750 900 --end_epoch 1000 --save_results none
124
+ python launch.py --n_GPUs 4 main.py --dataset GOPRO_Large --milestones 500 750 900 --end_epoch 1000 --save_results part
125
+ python launch.py --n_GPUs 4 main.py --dataset GOPRO_Large --milestones 500 750 900 --end_epoch 1000 --save_results all
126
+ python launch.py --n_GPUs 4 main.py --dataset GOPRO_Large --milestones 500 750 900 --end_epoch 1000 --save_results all --amp true
127
+
128
+ python launch.py --n_GPUs 4 main.py --dataset REDS --milestones 100 150 180 --end_epoch 200 --save_results all --do_test false
129
+ python launch.py --n_GPUs 4 main.py --dataset REDS --milestones 100 150 180 --end_epoch 200 --save_results all --do_test false --do_validate false
130
+ ```
131
+
132
+ ```bash
133
+ # Commands used to generate the below results
134
+ python launch.py --n_GPUs 2 main.py --dataset GOPRO_Large --milestones 500 750 900 --end_epoch 1000
135
+ python launch.py --n_GPUs 4 main.py --dataset REDS --milestones 100 150 180 --end_epoch 200 --do_test false
136
+ ```
137
+
138
+ For more advanced usage, please take a look at src/option.py
139
+
140
+ ## Results
141
+
142
+ * Single-precision training results
143
+
144
+ Dataset | GOPRO_Large | REDS
145
+ :--:|:--:|:--:
146
+ PSNR | 30.40 | 32.89
147
+ SSIM | 0.9018 | 0.9207
148
+ Download | [link](https://drive.google.com/file/d/1-wGC6s2D2ba-PSV60AeHf48HtYd9JkQ4/view?usp=sharing) | [link](https://drive.google.com/file/d/1aSPgVsNcPNqeGPn0Y2uGmgIwaIn5Njkv/view?usp=sharing)
149
+
150
+ * Mixed-precision training results
151
+
152
+ Dataset | GOPRO_Large | REDS | REDS (GOPRO_Large pretrained)
153
+ :--:|:--:|:--:|:--:
154
+ PSNR| 30.42 | 32.95 | 33.13
155
+ SSIM| 0.9021 | 0.9209 | 0.9237
156
+ Download | [link](https://drive.google.com/file/d/1TgiiiB-4lwWIIy8c-oSSkIy5g4GvDBKB/view?usp=sharing) | [link](https://drive.google.com/file/d/10hH5vtfGUUpy8jLvIBRCBqRoEhWRO1va/view?usp=sharing) | [link](https://drive.google.com/file/d/1YV6uhGLDBbvaiWN2_cYgUhYakmvLMAM9/view?usp=sharing)
157
+
158
+ Mixed-precision training uses less memory and is faster, especially on NVIDIA Turing-generation GPUs.
159
+ Loss scaling technique is adopted to cope with the narrow representation range of fp16.
160
+ This could improve/degrade accuracy.
161
+
162
+ * Inference speed on RTX 2080 Ti (resolution: 1280x720)
163
+
164
+ Inference in half precision has negligible effect on accuracy while it requires less memory and computation time.
165
+ type | FP32 | FP16
166
+ :--:|:--:|:--:
167
+ fps | 1.06 | 3.03
168
+ time (s) | 0.943 | 0.330
169
+
170
+ ## Demo
171
+
172
+ To use the trained models, download files, unzip, and put them under DeepDeblur-PyTorch/experiment
173
+ * [GOPRO_L1](https://drive.google.com/file/d/1AfZhyUXEA8_UdZco9EdtpWjTBAb8BbWv/view?usp=sharing)
174
+ * [REDS_L1](https://drive.google.com/file/d/1UwFNXnGBz2rCBxhvq2gKt9Uhj5FeEsa4/view?usp=sharing)
175
+ * [GOPRO_L1_amp](https://drive.google.com/file/d/1ZcP3l2ZXj-C6yrDge5d3UxcaAKRN725w/view?usp=sharing)
176
+ * [REDS_L1_amp](https://drive.google.com/file/d/1do_HOjVFj2AYTX4BbwQ0enELRWtzhW6F/view?usp=sharing)
177
+ * [REDS_L1_amp_pretrained](https://drive.google.com/file/d/1BkEgUrFtOSymVnaADfptOvqfNOYiD3J1/view?usp=sharing)
178
+
179
+ ```bash
180
+ python main.py --save_dir SAVE_DIR --demo true --demo_input_dir INPUT_DIR_NAME --demo_output_dir OUTPUT_DIR_NAME
181
+ # SAVE_DIR is the experiment directory where the parameters are saved (GOPRO_L1, REDS_L1)
182
+ # SAVE_DIR is relative to DeepDeblur-PyTorch/experiment
183
+ # demo_output_dir is by default SAVE_DIR/results
184
+ # image dataloader looks into DEMO_INPUT_DIR, recursively
185
+
186
+ # example
187
+ # single GPU (GOPRO_Large, single precision)
188
+ python main.py --save_dir GOPRO_L1 --demo true --demo_input_dir ~/Research/dataset/GOPRO_Large/test/GOPR0384_11_00/blur_gamma
189
+ # single GPU (GOPRO_Large, amp-trained model, half precision)
190
+ python main.py --save_dir GOPRO_L1_amp --demo true --demo_input_dir ~/Research/dataset/GOPRO_Large/test/GOPR0384_11_00/blur_gamma --precision half
191
+ # multi-GPU (REDS, single precision)
192
+ python launch.py --n_GPUs 2 main.py --save_dir REDS_L1 --demo true --demo_input_dir ~/Research/dataset/REDS/test/test_blur --demo_output_dir OUTPUT_DIR_NAME
193
+ # multi-GPU (REDS, half precision)
194
+ python launch.py --n_GPUs 2 main.py --save_dir REDS_L1 --demo true --demo_input_dir ~/Research/dataset/REDS/test/test_blur --demo_output_dir OUTPUT_DIR_NAME --precision half
195
+ ```
196
+
197
+ ## Differences from the original code
198
+
199
+ The default options are different from the original paper.
200
+ * RGB range is [0, 255]
201
+ * L1 loss (without adversarial loss. Usage possible. See above examples)
202
+ * Batch size increased to 16.
203
+ * Distributed multi-gpu training is recommended.
204
+ * Mixed-precision training enabled. Accuracy not guaranteed.
205
+ * SSIM function changed from MATLAB to python
206
+
207
+ ## SSIM issue
208
+
209
+ There are many different SSIM implementations.
210
+ In this repository, SSIM metric is based on the following function:
211
+ ```python
212
+ from skimage.metrics import structural_similarity
213
+ ssim = structural_similarity(ref_im, res_im, multichannel=True, gaussian_weights=True, use_sample_covariance=False)
214
+ ```
215
+ `SSIM` class in [src/loss/metric.py](src/loss/metric.py) supports PyTorch.
216
+ SSIM function in MATLAB is not correct if applied to RGB images. See [this issue](https://github.com/SeungjunNah/DeepDeblur_release/issues/51) for details.
backup/deblur/DeepDeblur-PyTorch/experiment/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ !.gitignore
backup/deblur/DeepDeblur-PyTorch/src/__pycache__/option.cpython-37.pyc ADDED
Binary file (8.84 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/__pycache__/template.cpython-37.pyc ADDED
Binary file (468 Bytes). View file
 
backup/deblur/DeepDeblur-PyTorch/src/__pycache__/train.cpython-37.pyc ADDED
Binary file (5.32 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/__pycache__/utils.cpython-37.pyc ADDED
Binary file (6.6 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/__init__.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generic dataset loader"""
2
+
3
+ from importlib import import_module
4
+
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import SequentialSampler, RandomSampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from .sampler import DistributedEvalSampler
9
+
10
+ class Data():
11
+ def __init__(self, args):
12
+
13
+ self.modes = ['train', 'val', 'test', 'demo']
14
+
15
+ self.action = {
16
+ 'train': args.do_train,
17
+ 'val': args.do_validate,
18
+ 'test': args.do_test,
19
+ 'demo': args.demo
20
+ }
21
+
22
+ self.dataset_name = {
23
+ 'train': args.data_train,
24
+ 'val': args.data_val,
25
+ 'test': args.data_test,
26
+ 'demo': 'Demo'
27
+ }
28
+
29
+ self.args = args
30
+
31
+ def _get_data_loader(mode='train'):
32
+ dataset_name = self.dataset_name[mode]
33
+ dataset = import_module('data.' + dataset_name.lower())
34
+ dataset = getattr(dataset, dataset_name)(args, mode)
35
+
36
+ if mode == 'train':
37
+ if args.distributed:
38
+ batch_size = int(args.batch_size / args.n_GPUs) # batch size per GPU (single-node training)
39
+ sampler = DistributedSampler(dataset, shuffle=True, num_replicas=args.world_size, rank=args.rank)
40
+ num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs) # num_workers per GPU (single-node training)
41
+ else:
42
+ batch_size = args.batch_size
43
+ sampler = RandomSampler(dataset, replacement=False)
44
+ num_workers = args.num_workers
45
+ drop_last = True
46
+
47
+ elif mode in ('val', 'test', 'demo'):
48
+ if args.distributed:
49
+ batch_size = 1 # 1 image per GPU
50
+ sampler = DistributedEvalSampler(dataset, shuffle=False, num_replicas=args.world_size, rank=args.rank)
51
+ num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs) # num_workers per GPU (single-node training)
52
+ else:
53
+ batch_size = args.n_GPUs # 1 image per GPU
54
+ sampler = SequentialSampler(dataset)
55
+ num_workers = args.num_workers
56
+ drop_last = False
57
+
58
+ loader = DataLoader(
59
+ dataset=dataset,
60
+ batch_size=batch_size,
61
+ shuffle=False,
62
+ sampler=sampler,
63
+ num_workers=num_workers,
64
+ pin_memory=True,
65
+ drop_last=drop_last,
66
+ )
67
+
68
+ return loader
69
+
70
+ self.loaders = {}
71
+ for mode in self.modes:
72
+ if self.action[mode]:
73
+ self.loaders[mode] = _get_data_loader(mode)
74
+ print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
75
+ else:
76
+ self.loaders[mode] = None
77
+
78
+ def get_loader(self):
79
+ return self.loaders
backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (2.04 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/common.cpython-37.pyc ADDED
Binary file (5.14 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/dataset.cpython-37.pyc ADDED
Binary file (4.54 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/gopro_large.cpython-37.pyc ADDED
Binary file (1.25 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/__pycache__/sampler.cpython-37.pyc ADDED
Binary file (4.48 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/data/common.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from skimage.color import rgb2hsv, hsv2rgb
4
+ from skimage.transform import pyramid_gaussian
5
+
6
+ import torch
7
+
8
+ def _apply(func, x):
9
+
10
+ if isinstance(x, (list, tuple)):
11
+ return [_apply(func, x_i) for x_i in x]
12
+ elif isinstance(x, dict):
13
+ y = {}
14
+ for key, value in x.items():
15
+ y[key] = _apply(func, value)
16
+ return y
17
+ else:
18
+ return func(x)
19
+
20
def crop(*args, ps=256):  # ps = patch_size
    """Crop the same random ps x ps patch from every image in *args*.

    The crop location is drawn once from the first image found, so paired
    inputs/targets stay aligned. 2-D images gain a trailing channel axis.
    """
    def _probe_shape(x):
        # descend to the first leaf to learn the spatial size
        if isinstance(x, (list, tuple)):
            return _probe_shape(x[0])
        if isinstance(x, dict):
            return _probe_shape(next(iter(x.values())))
        return x.shape

    h, w, _ = _probe_shape(args)

    top = random.randrange(0, h - ps + 1)
    left = random.randrange(0, w - ps + 1)

    def _crop_one(img):
        patch = img[top:top + ps, left:left + ps]
        if img.ndim == 2:
            return patch[..., np.newaxis]
        return patch

    return _apply(_crop_one, args)
+
43
def add_noise(*args, sigma_sigma=2, rgb_range=255):
    """Add Gaussian noise with a randomly drawn level to the image(s).

    The noise level itself is sampled once per call (sigma ~ N(0, sigma_sigma),
    scaled to rgb_range), so all images in one call share it.
    """
    if len(args) == 1:  # usually there is only a single input
        args = args[0]

    noise_level = np.random.normal() * sigma_sigma * rgb_range / 255

    def _perturb(img):
        noise = np.random.randn(*img.shape).astype(np.float32) * noise_level
        return (img + noise).clip(0, rgb_range)

    return _apply(_perturb, args)
+
56
def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
    """Apply one randomly drawn augmentation consistently to all images.

    Flips/rotation, RGB channel shuffling and a saturation rescale are each
    sampled once, then applied identically to input and target.
    """
    coin = (False, True)

    do_hflip = hflip and random.choice(coin)
    do_vflip = rot and random.choice(coin)
    do_rot90 = rot and random.choice(coin)

    if shuffle:
        rgb_order = list(range(3))
        random.shuffle(rgb_order)
        if rgb_order == list(range(3)):
            shuffle = False  # identity permutation: skip the shuffle step

    if change_saturation:
        amp_factor = np.random.uniform(0.5, 1.5)

    def _transform(img):
        if do_hflip:
            img = img[:, ::-1, :]
        if do_vflip:
            img = img[::-1, :, :]
        if do_rot90:
            img = img.transpose(1, 0, 2)
        if shuffle and img.ndim > 2 and img.shape[-1] == 3:  # RGB images only
            img = img[..., rgb_order]

        if change_saturation:
            hsv = rgb2hsv(img)
            hsv[..., 1] *= amp_factor
            img = hsv2rgb(hsv).clip(0, 1) * rgb_range

        return img.astype(np.float32)

    return _apply(_transform, args)
+
93
def pad(img, divisor=4, pad_width=None, negative=False):
    """Pad *img* so its height/width become divisible by *divisor*.

    Args:
        img: HWC numpy array (edge padding) or NCHW torch tensor (reflect
            padding).
        divisor: target divisibility of height and width.
        pad_width: optional precomputed padding. Numpy-style nested pairs are
            converted for the tensor path; an already-flat F.pad tuple is kept.
        negative: tensor path only -- negate the padding (i.e. crop it back).

    Returns:
        (padded_img, pad_width) where pad_width can be fed back into a later
        call (e.g. to undo the padding with negative=True).
    """
    def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
        if pad_width is None:
            (h, w, _) = img.shape
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = ((0, pad_h), (0, pad_w), (0, 0))

        img = np.pad(img, pad_width, mode='edge')

        return img, pad_width

    def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
        n, c, h, w = img.shape
        if pad_width is None:
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = (0, pad_w, 0, pad_h)
        else:
            # convert numpy-style ((0,ph),(0,pw),(0,0)) to F.pad's flat format
            try:
                pad_h = pad_width[0][1]
                pad_w = pad_width[1][1]
                if isinstance(pad_h, torch.Tensor):
                    pad_h = pad_h.item()
                if isinstance(pad_w, torch.Tensor):
                    pad_w = pad_w.item()

                pad_width = (0, pad_w, 0, pad_h)
            except (TypeError, IndexError, KeyError):
                # pad_width is already flat (or a scalar dummy): keep as-is.
                # Was a bare `except:` which also swallowed KeyboardInterrupt.
                pass

        if negative:
            pad_width = [-val for val in pad_width]

        img = torch.nn.functional.pad(img, pad_width, 'reflect')

        return img, pad_width

    if isinstance(img, np.ndarray):
        return _pad_numpy(img, divisor, pad_width, negative)
    else:  # torch.Tensor
        return _pad_tensor(img, divisor, pad_width, negative)
+
138
def generate_pyramid(*args, n_scales):
    """Build an n_scales-level Gaussian pyramid for each input image."""
    def _build(img):
        img32 = img if img.dtype == np.float32 else img.astype(np.float32)
        return list(pyramid_gaussian(img32, n_scales - 1, multichannel=True))

    return _apply(_build, args)
+
149
def np2tensor(*args):
    """Convert HWC numpy image(s) into CHW torch tensor(s)."""
    def _convert(arr):
        chw = arr.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(chw))

    return _apply(_convert, args)
+
158
def to(*args, device=None, dtype=torch.float):
    """Move tensor(s) to *device*/*dtype* (non-blocking, no forced copy)."""
    def _move(tensor):
        return tensor.to(device=device, dtype=dtype, non_blocking=True, copy=False)

    return _apply(_move, args)
backup/deblur/DeepDeblur-PyTorch/src/data/dataset.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import imageio
4
+ import numpy as np
5
+ import torch.utils.data as data
6
+
7
+ from data import common
8
+
9
+ from utils import interact
10
+
11
class Dataset(data.Dataset):
    """Basic dataloader class.

    Walks ``self.subset_root`` collecting blurry images (and, when present,
    their sharp counterparts) and serves them as
    ``(blur, sharp, pad_width, idx, relpath)`` tuples.  Child classes
    customize the directory layout via ``set_modes()`` / ``set_keys()``.
    """
    def __init__(self, args, mode='train'):
        super(Dataset, self).__init__()
        self.args = args
        self.mode = mode  # one of self.modes

        self.modes = ()
        self.set_modes()   # child classes declare which modes they support
        self._check_mode()

        self.set_keys()    # child classes declare directory-name keys

        # pick the dataset name for the requested split
        if self.mode == 'train':
            dataset = args.data_train
        elif self.mode == 'val':
            dataset = args.data_val
        elif self.mode == 'test':
            dataset = args.data_test
        elif self.mode == 'demo':
            pass  # demo reads directly from args.demo_input_dir below
        else:
            raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))

        if self.mode == 'demo':
            self.subset_root = args.demo_input_dir
        else:
            self.subset_root = os.path.join(args.data_root, dataset, self.mode)

        self.blur_list = []   # blurry input image paths
        self.sharp_list = []  # ground-truth sharp image paths (may stay empty)

        self._scan()

    def set_modes(self):
        # to be overwritten by child classes that support fewer modes
        self.modes = ('train', 'val', 'test', 'demo')

    def _check_mode(self):
        """Should be called in the child class __init__() after super.

        Raises NotImplementedError when self.mode is not supported.
        """
        if self.mode not in self.modes:
            raise NotImplementedError('mode error: not for {}'.format(self.mode))

        return

    def set_keys(self):
        """Declare the directory-name keys used to classify scanned files."""
        self.blur_key = 'blur'  # to be overwritten by child class
        self.sharp_key = 'sharp'  # to be overwritten by child class

        # keys that must NOT appear in a matching path
        self.non_blur_keys = []
        self.non_sharp_keys = []

        return

    def _scan(self, root=None):
        """Should be called in the child class __init__() after super.

        Recursively walks *root* (default: self.subset_root) and fills
        self.blur_list / self.sharp_list with sorted file paths whose
        directories match the configured keys.
        """
        if root is None:
            root = self.subset_root

        if self.blur_key in self.non_blur_keys:
            self.non_blur_keys.remove(self.blur_key)
        if self.sharp_key in self.non_sharp_keys:
            self.non_sharp_keys.remove(self.sharp_key)

        def _key_check(path, true_key, false_keys):
            # trailing separator so keys match whole directory names
            path = os.path.join(path, '')
            if path.find(true_key) >= 0:
                for false_key in false_keys:
                    if path.find(false_key) >= 0:
                        return False

                return True
            else:
                return False

        def _get_list_by_key(root, true_key, false_keys):
            data_list = []
            for sub, dirs, files in os.walk(root):
                if not dirs:  # leaf directories only
                    file_list = [os.path.join(sub, f) for f in files]
                    if _key_check(sub, true_key, false_keys):
                        data_list += file_list

            data_list.sort()

            return data_list

        def _rectify_keys():
            # normalize every key to end with a path separator
            self.blur_key = os.path.join(self.blur_key, '')
            self.non_blur_keys = [os.path.join(non_blur_key, '') for non_blur_key in self.non_blur_keys]
            self.sharp_key = os.path.join(self.sharp_key, '')
            self.non_sharp_keys = [os.path.join(non_sharp_key, '') for non_sharp_key in self.non_sharp_keys]

        _rectify_keys()

        self.blur_list = _get_list_by_key(root, self.blur_key, self.non_blur_keys)
        self.sharp_list = _get_list_by_key(root, self.sharp_key, self.non_sharp_keys)

        if len(self.sharp_list) > 0:
            # paired datasets must align one-to-one after sorting
            assert(len(self.blur_list) == len(self.sharp_list))

        return

    def __getitem__(self, idx):
        """Return (blur, sharp, pad_width, idx, relpath) for sample *idx*.

        ``sharp`` is False when no ground truth exists; ``pad_width`` is the
        dummy 0 except in demo mode, where images are padded to a size
        divisible by 2**(n_scales-1).
        """
        blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
        if len(self.sharp_list) > 0:
            sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
            imgs = [blur, sharp]
        else:
            imgs = [blur]

        pad_width = 0  # dummy value
        if self.mode == 'train':
            imgs = common.crop(*imgs, ps=self.args.patch_size)
            if self.args.augment:
                imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
                imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
        elif self.mode == 'demo':
            imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))  # pad in case of non-divisible size
        else:
            pass  # deliver test image as is.

        if self.args.gaussian_pyramid:
            imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)

        imgs = common.np2tensor(*imgs)
        relpath = os.path.relpath(self.blur_list[idx], self.subset_root)

        blur = imgs[0]
        sharp = imgs[1] if len(imgs) > 1 else False

        return blur, sharp, pad_width, idx, relpath

    def __len__(self):
        return len(self.blur_list)
backup/deblur/DeepDeblur-PyTorch/src/data/demo.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.dataset import Dataset
2
+
3
+ from utils import interact
4
+
5
class Demo(Dataset):
    """Demo dataset: every file under args.demo_input_dir is a blur input.

    No sharp ground truth is collected (non_sharp_keys matches everything).
    """
    def __init__(self, args, mode='demo'):
        super(Demo, self).__init__(args, mode)

    def set_modes(self):
        # BUG FIX: was `('demo')`, which is a plain string, not a tuple;
        # membership tests only worked via accidental substring matching and
        # iterating it would yield single characters.
        self.modes = ('demo',)

    def set_keys(self):
        super(Demo, self).set_keys()
        self.blur_key = ''  # matches all the files
        self.non_sharp_keys = ['']  # excludes every file from the sharp list

    def __getitem__(self, idx):
        blur, sharp, pad_width, idx, relpath = super(Demo, self).__getitem__(idx)

        return blur, sharp, pad_width, idx, relpath
backup/deblur/DeepDeblur-PyTorch/src/data/gopro_large.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.dataset import Dataset
2
+
3
+ from utils import interact
4
+
5
class GOPRO_Large(Dataset):
    """GOPRO_Large train/test subset: gamma-corrected blur paired with sharp."""
    def __init__(self, args, mode='train'):
        super(GOPRO_Large, self).__init__(args, mode)

    def set_modes(self):
        self.modes = ('train', 'test')

    def set_keys(self):
        super(GOPRO_Large, self).set_keys()
        self.blur_key = 'blur_gamma'  # use the gamma-corrected blur images
        # self.sharp_key = 'sharp'

    def __getitem__(self, idx):
        blur, sharp, pad_width, idx, relpath = super(GOPRO_Large, self).__getitem__(idx)
        # NOTE(review): Dataset._scan rectifies self.blur_key to end with a
        # path separator, so '{}/'.format(self.blur_key) becomes
        # 'blur_gamma//' -- verify this replace actually matches relpath and
        # strips the key as intended.
        relpath = relpath.replace('{}/'.format(self.blur_key), '')

        return blur, sharp, pad_width, idx, relpath
backup/deblur/DeepDeblur-PyTorch/src/data/reds.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.dataset import Dataset
2
+
3
+ from utils import interact
4
+
5
class REDS(Dataset):
    """REDS train, val, test subset class."""
    def __init__(self, args, mode='train'):
        super(REDS, self).__init__(args, mode)

    def set_modes(self):
        self.modes = ('train', 'val', 'test')

    def set_keys(self):
        super(REDS, self).set_keys()
        # self.blur_key = 'blur'
        # self.sharp_key = 'sharp'

        # exclude the alternative degradations shipped with REDS
        self.non_blur_keys = ['blur', 'blur_comp', 'blur_bicubic']
        self.non_blur_keys.remove(self.blur_key)
        self.non_sharp_keys = ['sharp_bicubic', 'sharp']
        self.non_sharp_keys.remove(self.sharp_key)

    def __getitem__(self, idx):
        blur, sharp, pad_width, idx, relpath = super(REDS, self).__getitem__(idx)
        # NOTE(review): by this point Dataset._scan has rectified
        # self.blur_key to end with a path separator, so the replace target is
        # e.g. 'train/blur//' -- verify it actually matches relpath and strips
        # the prefix as intended.
        relpath = relpath.replace('{}/{}/'.format(self.mode, self.blur_key), '')

        return blur, sharp, pad_width, idx, relpath
backup/deblur/DeepDeblur-PyTorch/src/data/sampler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data import Sampler
4
+ import torch.distributed as dist
5
+
6
+
7
class DistributedEvalSampler(Sampler):
    r"""
    DistributedEvalSampler is different from DistributedSampler.
    It does NOT add extra samples to make it evenly divisible.
    DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever.
    See this issue for details: https://github.com/pytorch/pytorch/issues/22584
    shuffle is disabled by default

    DistributedEvalSampler is for evaluation purpose where synchronization does not happen every epoch.
    Synchronization should be done outside the dataloader loop.

    Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True``, sampler will shuffle the
            indices. Default: ``False``.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch(epoch) <set_epoch>` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    Example::

        >>> sampler = DistributedEvalSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        # self.total_size = self.num_samples * self.num_replicas
        self.total_size = len(self.dataset)  # true value without extra samples
        # every rank takes a strided slice, so per-rank counts may differ by 1
        indices = list(range(self.total_size))
        indices = indices[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(indices)  # true value without extra samples

        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))


        # # add extra samples to make it evenly divisible
        # indices += indices[:(self.total_size - len(indices))]
        # assert len(indices) == self.total_size

        # subsample: rank-strided slice, no padding
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Arguments:
            epoch (int): epoch number.
        """
        self.epoch = epoch
backup/deblur/DeepDeblur-PyTorch/src/launch.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ distributed launcher adopted from torch.distributed.launch
2
+ usage example: https://github.com/facebookresearch/maskrcnn-benchmark
3
+ This enables using multiprocessing for each spawned process (as they are treated as main processes)
4
+ """
5
+ import sys
6
+ import subprocess
7
+ from argparse import ArgumentParser, REMAINDER
8
+
9
+ from utils import str2bool, int2str
10
+
11
def parse_args(argv=None):
    """Parse launcher arguments.

    Args:
        argv: optional list of argument strings. Defaults to ``sys.argv[1:]``
            (backward-compatible with the original zero-argument call); being
            able to pass an explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with n_GPUs, training_script and
        training_script_args.
    """
    parser = ArgumentParser(description="PyTorch distributed training launch "
                                        "helper utility that will spawn up "
                                        "multiple distributed processes")


    parser.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')

    # positional
    parser.add_argument("training_script", type=str,
                        help="The full path to the single GPU training "
                             "program/script to be launched in parallel, "
                             "followed by all the arguments for the "
                             "training script")

    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args(argv)
29
+
30
def main():
    """Spawn one training process per GPU and wait for all of them.

    Each child re-runs the training script with --distributed/--launched set
    and its own --rank; a non-zero exit from any child is raised here.
    """
    args = parse_args()

    workers = []
    for rank in range(0, args.n_GPUs):
        cmd = [sys.executable, args.training_script]
        cmd.extend(args.training_script_args)

        cmd += ['--distributed', 'True']
        cmd += ['--launched', 'True']
        cmd += ['--n_GPUs', str(args.n_GPUs)]
        cmd += ['--rank', str(rank)]

        workers.append(subprocess.Popen(cmd))

    for worker in workers:
        worker.wait()
        if worker.returncode != 0:
            raise subprocess.CalledProcessError(returncode=worker.returncode,
                                                cmd=cmd)

if __name__ == "__main__":
    main()
backup/deblur/DeepDeblur-PyTorch/src/loss/__init__.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from importlib import import_module
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.distributed as dist
7
+
8
+ import matplotlib.pyplot as plt
9
+ plt.switch_backend('agg') # https://github.com/matplotlib/matplotlib/issues/3466
10
+
11
+ from .metric import PSNR, SSIM
12
+
13
+ from utils import interact
14
+
15
class Loss(torch.nn.modules.loss._Loss):
    """Combined training-loss and evaluation-metric bookkeeper.

    Parses ``args.loss`` ('1*MSE+0.5*VGG54' style) into weighted loss terms
    and ``args.metric`` ('PSNR,SSIM' style) into metrics, accumulates
    per-epoch statistics for train/val/test modes, and can synchronize
    (distributed), plot, save and reload them.
    """
    def __init__(self, args, epoch=None, model=None, optimizer=None):
        """
        input:
            args.loss       use '+' to sum over different loss functions
                            use '*' to specify the loss weight

            example:
                1*MSE+0.5*VGG54
                loss = sum of MSE and VGG54(weight=0.5)

            args.measure    similar to args.loss, but without weight

            example:
                MSE+PSNR
                measure MSE and PSNR, independently
        """
        super(Loss, self).__init__()

        self.args = args

        self.rgb_range = args.rgb_range
        self.device_type = args.device_type
        self.synchronized = False

        self.epoch = args.start_epoch if epoch is None else epoch
        self.save_dir = args.save_dir
        self.save_name = os.path.join(self.save_dir, 'loss.pt')

        # mode flags (self.training comes from nn.Module)
        self.validating = False
        self.testing = False
        self.mode = 'train'
        self.modes = ('train', 'val', 'test')

        # Loss
        self.loss = nn.ModuleDict()
        self.loss_types = []
        self.weight = {}

        self.loss_stat = {mode: {} for mode in self.modes}
        # loss_stat[mode][loss_type][epoch] = loss_value
        # loss_stat[mode]['Total'][epoch] = loss_total

        for weighted_loss in args.loss.split('+'):
            w, l = weighted_loss.split('*')
            l = l.upper()
            if l in ('ABS', 'L1'):
                loss_type = 'L1'
                func = nn.L1Loss()
            elif l in ('MSE', 'L2'):
                loss_type = 'L2'
                func = nn.MSELoss()
            elif l in ('ADV', 'GAN'):
                loss_type = 'ADV'
                m = import_module('loss.adversarial')
                func = getattr(m, 'Adversarial')(args, model, optimizer)
            else:
                loss_type = l
                # BUG FIX: was `import_module*'loss.{}'.format(l.lower())`,
                # which multiplies the function object by a string (TypeError)
                # instead of calling it.
                m = import_module('loss.{}'.format(l.lower()))
                func = getattr(m, l)(args)

            self.loss_types += [loss_type]
            self.loss[loss_type] = func
            self.weight[loss_type] = float(w)

        print('Loss function: {}'.format(args.loss))

        # Metrics
        self.do_measure = args.metric.lower() != 'none'

        self.metric = nn.ModuleDict()
        self.metric_types = []
        self.metric_stat = {mode: {} for mode in self.modes}
        # metric_stat[mode][metric_type][epoch] = metric_value

        if self.do_measure:
            for metric_type in args.metric.split(','):
                metric_type = metric_type.upper()
                if metric_type == 'PSNR':
                    metric_func = PSNR()
                elif metric_type == 'SSIM':
                    metric_func = SSIM(args.device_type)  # single precision
                else:
                    raise NotImplementedError

                self.metric_types += [metric_type]
                self.metric[metric_type] = metric_func

        print('Metrics: {}'.format(args.metric))

        if args.start_epoch != 1:
            self.load(args.start_epoch - 1)

        # make sure every mode has a record slot for every loss/metric type
        for mode in self.modes:
            for loss_type in self.loss:
                if loss_type not in self.loss_stat[mode]:
                    self.loss_stat[mode][loss_type] = {}  # initialize loss

            if 'Total' not in self.loss_stat[mode]:
                self.loss_stat[mode]['Total'] = {}

            if self.do_measure:
                for metric_type in self.metric:
                    if metric_type not in self.metric_stat[mode]:
                        self.metric_stat[mode][metric_type] = {}

        # running sample counts for the current epoch accumulation
        self.count = 0
        self.count_m = 0

        self.to(args.device, dtype=args.dtype)

    def train(self, mode=True):
        """Switch to train (mode=True) or test (mode=False) bookkeeping."""
        super(Loss, self).train(mode)
        if mode:
            self.validating = False
            self.testing = False
            self.mode = 'train'
        else:  # default test mode
            self.validating = False
            self.testing = True
            self.mode = 'test'

    def validate(self):
        """Switch to validation bookkeeping (eval mode)."""
        super(Loss, self).eval()
        self.validating = True
        self.testing = False
        self.mode = 'val'

    def test(self):
        """Switch to test bookkeeping (eval mode)."""
        super(Loss, self).eval()
        self.validating = False
        self.testing = True
        self.mode = 'test'

    def forward(self, input, target):
        """Compute the weighted sum of all loss terms and record statistics.

        NaN losses are skipped both in the records and (by the caller) at
        backprop. In eval modes, also accumulates the configured metrics.
        """
        self.synchronized = False

        loss = 0

        def _ms_forward(input, target, func):
            # apply func pairwise over multi-scale lists/dicts of outputs
            if isinstance(input, (list, tuple)):  # loss for list output
                _loss = []
                for (input_i, target_i) in zip(input, target):
                    _loss += [func(input_i, target_i)]
                return sum(_loss)
            elif isinstance(input, dict):  # loss for dict output
                _loss = []
                for key in input:
                    _loss += [func(input[key], target[key])]
                return sum(_loss)
            else:  # loss for tensor output
                return func(input, target)

        # initialize this epoch's accumulators on the first batch
        if self.count == 0:
            for loss_type in self.loss_types:
                self.loss_stat[self.mode][loss_type][self.epoch] = 0
            self.loss_stat[self.mode]['Total'][self.epoch] = 0

        if isinstance(input, list):
            count = input[0].shape[0]
        else:  # Tensor
            count = input.shape[0]  # batch size

        isnan = False
        for loss_type in self.loss_types:

            if loss_type == 'ADV':
                # adversarial loss only sees the finest scale
                _loss = self.loss[loss_type](input[0], target[0], self.training) * self.weight[loss_type]
            else:
                _loss = _ms_forward(input, target, self.loss[loss_type]) * self.weight[loss_type]

            if torch.isnan(_loss):
                isnan = True  # skip recording (will also be skipped at backprop)
            else:
                self.loss_stat[self.mode][loss_type][self.epoch] += _loss.item() * count
                self.loss_stat[self.mode]['Total'][self.epoch] += _loss.item() * count

            loss += _loss

        if not isnan:
            self.count += count

        if not self.training and self.do_measure:
            self.measure(input, target)

        return loss

    def measure(self, input, target):
        """Accumulate the configured metrics on the finest-scale output."""
        if isinstance(input, (list, tuple)):
            self.measure(input[0], target[0])
            return
        elif isinstance(input, dict):
            first_key = list(input.keys())[0]
            self.measure(input[first_key], target[first_key])
            return
        else:
            pass

        if self.count_m == 0:
            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][self.epoch] = 0

        if isinstance(input, list):
            count = input[0].shape[0]
        else:  # Tensor
            count = input.shape[0]  # batch size

        # quantize once (loop-invariant; was redundantly re-clamped per metric)
        input = input.clamp(0, self.rgb_range)  # not in_place
        if self.rgb_range == 255:
            input.round_()

        for metric_type in self.metric_stat[self.mode]:
            _metric = self.metric[metric_type](input, target)
            self.metric_stat[self.mode][metric_type][self.epoch] += _metric.item() * count

        self.count_m += count

        return

    def normalize(self):
        """Turn accumulated sums into per-sample averages for this epoch."""
        if self.args.distributed:
            dist.barrier()
            if not self.synchronized:
                self.all_reduce()

        if self.count > 0:
            for loss_type in self.loss_stat[self.mode]:  # including 'Total'
                self.loss_stat[self.mode][loss_type][self.epoch] /= self.count
            self.count = 0

        if self.count_m > 0:
            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][self.epoch] /= self.count_m
            self.count_m = 0

        return

    def all_reduce(self, epoch=None):
        """Sum loss/metric accumulators across distributed GPU processes."""
        if epoch is None:
            epoch = self.epoch

        def _reduce_value(value, ReduceOp=dist.ReduceOp.SUM):
            value_tensor = torch.Tensor([value]).to(self.args.device, self.args.dtype, non_blocking=True)
            dist.all_reduce(value_tensor, ReduceOp, async_op=False)
            value = value_tensor.item()
            del value_tensor

            return value

        dist.barrier()
        if self.count > 0:  # expected to be true whenever forward() ran
            self.count = _reduce_value(self.count, dist.ReduceOp.SUM)

            for loss_type in self.loss_stat[self.mode]:
                self.loss_stat[self.mode][loss_type][epoch] = _reduce_value(
                    self.loss_stat[self.mode][loss_type][epoch],
                    dist.ReduceOp.SUM
                )

        if self.count_m > 0:
            self.count_m = _reduce_value(self.count_m, dist.ReduceOp.SUM)

            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][epoch] = _reduce_value(
                    self.metric_stat[self.mode][metric_type][epoch],
                    dist.ReduceOp.SUM
                )

        self.synchronized = True

        return

    def print_metrics(self):
        """Print the current epoch's metric summary."""
        print(self.get_metric_desc())
        return

    def get_last_loss(self):
        """Return the recorded total loss for the current mode/epoch."""
        return self.loss_stat[self.mode]['Total'][self.epoch]

    def get_loss_desc(self):
        """Return a one-line progress string (loss, plus metrics in eval)."""
        if self.mode == 'train':
            desc_prefix = 'Train'
        elif self.mode == 'val':
            desc_prefix = 'Validation'
        else:
            desc_prefix = 'Test'

        loss = self.loss_stat[self.mode]['Total'][self.epoch]
        if self.count > 0:  # mid-epoch: divide by samples seen so far
            loss /= self.count
        desc = '{} Loss: {:.1f}'.format(desc_prefix, loss)

        if self.mode in ('val', 'test'):
            metric_desc = self.get_metric_desc()
            desc = '{}{}'.format(desc, metric_desc)

        return desc

    def get_metric_desc(self):
        """Return a formatted ' PSNR: .. SSIM: ..' style metric string."""
        desc = ''
        for metric_type in self.metric_stat[self.mode]:
            measured = self.metric_stat[self.mode][metric_type][self.epoch]
            if self.count_m > 0:  # mid-epoch: divide by samples seen so far
                measured /= self.count_m

            if metric_type == 'PSNR':
                desc += ' {}: {:2.2f}'.format(metric_type, measured)
            elif metric_type == 'SSIM':
                desc += ' {}: {:1.4f}'.format(metric_type, measured)
            else:
                desc += ' {}: {:2.4f}'.format(metric_type, measured)

        return desc

    def step(self, plot_name=None):
        """Finalize the epoch: normalize accumulators and refresh the plots."""
        self.normalize()
        self.plot(plot_name)
        if not self.training and self.do_measure:
            self.plot_metric()

        return

    def save(self):
        """Persist loss/metric records to self.save_name."""
        state = {
            'loss_stat': self.loss_stat,
            'metric_stat': self.metric_stat,
        }
        torch.save(state, self.save_name)

        return

    def load(self, epoch=None):
        """Restore loss/metric records, optionally setting the current epoch."""
        print('Loading loss record from {}'.format(self.save_name))
        if os.path.exists(self.save_name):
            state = torch.load(self.save_name, map_location=self.args.device)

            self.loss_stat = state['loss_stat']
            if 'metric_stat' in state:  # older checkpoints may lack metrics
                self.metric_stat = state['metric_stat']
            else:
                pass
        else:
            print('no loss record found for {}!'.format(self.save_name))

        if epoch is not None:
            self.epoch = epoch

        return

    def plot(self, plot_name=None, metric=False):
        """Plot the loss curve (and, optionally, the metric curves)."""
        self.plot_loss(plot_name)

        if metric:
            self.plot_metric(plot_name)

        return


    def plot_loss(self, plot_name=None):
        """Save a loss-vs-epoch plot for the current mode."""
        if plot_name is None:
            plot_name = os.path.join(self.save_dir, "{}_loss.pdf".format(self.mode))

        title = "{} loss".format(self.mode)

        fig = plt.figure()
        plt.title(title)
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.grid(True, linestyle=':')

        for loss_type, loss_record in self.loss_stat[self.mode].items():  # including Total
            axis = sorted([epoch for epoch in loss_record.keys() if epoch <= self.epoch])
            value = [self.loss_stat[self.mode][loss_type][epoch] for epoch in axis]
            label = loss_type

            plt.plot(axis, value, label=label)

        plt.xlim(0, self.epoch)
        plt.legend()
        plt.savefig(plot_name)
        plt.close(fig)

        return

    def plot_metric(self, plot_name=None):
        """Save a metric-vs-epoch plot (PSNR on the left axis, SSIM on a twin)."""
        # assume there are only max 2 metrics
        if plot_name is None:
            plot_name = os.path.join(self.save_dir, "{}_metric.pdf".format(self.mode))

        title = "{} metrics".format(self.mode)

        fig, ax1 = plt.subplots()
        plt.title(title)
        plt.grid(True, linestyle=':')
        ax1.set_xlabel('epochs')

        plots = None
        for metric_type, metric_record in self.metric_stat[self.mode].items():
            axis = sorted([epoch for epoch in metric_record.keys() if epoch <= self.epoch])
            value = [metric_record[epoch] for epoch in axis]
            label = metric_type

            if metric_type == 'PSNR':
                ax = ax1
                color = 'C0'
            elif metric_type == 'SSIM':
                ax2 = ax1.twinx()
                ax = ax2
                color = 'C1'

            ax.set_ylabel(metric_type)
            if plots is None:
                plots = ax.plot(axis, value, label=label, color=color)
            else:
                plots += ax.plot(axis, value, label=label, color=color)

        labels = [plot.get_label() for plot in plots]
        plt.legend(plots, labels)
        plt.xlim(0, self.epoch)
        plt.savefig(plot_name)
        plt.close(fig)

        return

    def sort(self):
        """Sort the loss/metric records by epoch (in place); returns self."""
        for mode in self.modes:
            for loss_type, loss_epochs in self.loss_stat[mode].items():
                self.loss_stat[mode][loss_type] = {epoch: loss_epochs[epoch] for epoch in sorted(loss_epochs)}

            for metric_type, metric_epochs in self.metric_stat[mode].items():
                self.metric_stat[mode][metric_type] = {epoch: metric_epochs[epoch] for epoch in sorted(metric_epochs)}

        return self
backup/deblur/DeepDeblur-PyTorch/src/loss/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (11.1 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/loss/__pycache__/metric.cpython-37.pyc ADDED
Binary file (3.37 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/loss/adversarial.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils import interact
5
+
6
+ import torch.cuda.amp as amp
7
+
8
class Adversarial(nn.modules.loss._Loss):
    """Adversarial (GAN) loss that also trains the discriminator.

    A pure loss module (no checkpoint save/load of its own): each forward
    call in training mode runs `gan_k` discriminator optimization steps with
    AMP support, then returns the generator's BCE-with-logits loss.
    """
    def __init__(self, args, model, optimizer):
        super(Adversarial, self).__init__()
        self.args = args
        self.model = model.model        # underlying ModuleDict holding .G / .D
        self.optimizer = optimizer      # provides optimizer.D for the discriminator
        # gradient scaler for mixed-precision discriminator updates
        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp
        )

        self.gan_k = 1  # discriminator updates per generator update

        self.BCELoss = nn.BCEWithLogitsLoss()

    def forward(self, fake, real, training=False):
        """Return the generator adversarial loss; update D when training.

        Args:
            fake: generator output batch.
            real: ground-truth image batch.
            training: when True, perform gan_k discriminator update steps.
        """
        if training:
            # update discriminator on a detached copy so no gradient
            # flows into the generator from the D update
            fake_detach = fake.detach()
            for _ in range(self.gan_k):
                self.optimizer.D.zero_grad()
                # d: B x 1 tensor of real/fake logits
                with amp.autocast(self.args.amp):
                    d_fake = self.model.D(fake_detach)
                    d_real = self.model.D(real)

                    label_fake = torch.zeros_like(d_fake)
                    label_real = torch.ones_like(d_real)

                    loss_d = self.BCELoss(d_fake, label_fake) + self.BCELoss(d_real, label_real)

                self.scaler.scale(loss_d).backward(retain_graph=False)
                self.scaler.step(self.optimizer.D)
                self.scaler.update()
        else:
            # evaluation: only need the real-label template for the G loss
            d_real = self.model.D(real)
            label_real = torch.ones_like(d_real)

        # update generator (outside here): gradient flows through `fake`;
        # the actual optimizer step is taken by the caller
        d_fake_bp = self.model.D(fake)
        loss_g = self.BCELoss(d_fake_bp, label_real)

        return loss_g
backup/deblur/DeepDeblur-PyTorch/src/loss/metric.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from skimage.metrics import peak_signal_noise_ratio, structural_similarity
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ def _expand(img):
7
+ if img.ndim < 4:
8
+ img = img.expand([1] * (4-img.ndim) + list(img.shape))
9
+
10
+ return img
11
+
12
class PSNR(nn.Module):
    """Peak signal-to-noise ratio, averaged over the batch dimension."""

    def __init__(self):
        super(PSNR, self).__init__()

    def forward(self, im1, im2, data_range=None):
        """Compute the mean PSNR between two image tensors (scalar tensor).

        Args:
            im1, im2: image tensors of identical shape.
            data_range: peak signal value; defaults to 255 when values
                exceed 1, otherwise 1.
        """
        if data_range is None:
            data_range = 255 if im1.max() > 1 else 1

        squared_error = _expand((im1 - im2) ** 2)
        # per-image MSE: reduce over every dimension except the batch
        per_image_mse = squared_error.mean(dim=list(range(1, squared_error.ndim)))

        return 10 * (data_range ** 2 / per_image_mse).log10().mean()
29
+
30
class SSIM(nn.Module):
    """Structural similarity index.

    Port of skimage.metrics.structural_similarity with multichannel=True,
    gaussian_weights=True, use_sample_covariance=False.
    """
    def __init__(self, device_type='cpu', dtype=torch.float32):
        super(SSIM, self).__init__()

        self.device_type = device_type
        self.dtype = dtype  # SSIM in half precision could be inaccurate

        def _get_ssim_weight():
            # 11x11 normalized Gaussian window (sigma=1.5, truncate=3.5),
            # replicated per channel for a depthwise (grouped) convolution
            truncate = 3.5
            sigma = 1.5
            r = int(truncate * sigma + 0.5) # radius as in ndimage
            win_size = 2 * r + 1
            # NOTE(review): channels hard-coded to 3, but forward() derives
            # groups from im1.shape[1] — non-RGB input would break the
            # grouped conv. Confirm inputs are always 3-channel.
            nch = 3

            # 1-D Gaussian -> outer product -> normalized 2-D kernel
            weight = torch.Tensor([-(x - win_size//2)**2/float(2*sigma**2) for x in range(win_size)]).exp().unsqueeze(1)
            weight = weight.mm(weight.t())
            weight /= weight.sum()
            weight = weight.repeat(nch, 1, 1, 1)

            return weight

        self.weight = _get_ssim_weight().to(self.device_type, dtype=self.dtype, non_blocking=True)

    def forward(self, im1, im2, data_range=None):
        """Implementation adopted from skimage.metrics.structural_similarity
        Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False
        """
        # move both images to the configured device/dtype before computing
        im1 = im1.to(self.device_type, dtype=self.dtype, non_blocking=True)
        im2 = im2.to(self.device_type, dtype=self.dtype, non_blocking=True)

        # standard SSIM stabilization constants
        K1 = 0.01
        K2 = 0.03
        sigma = 1.5

        truncate = 3.5
        r = int(truncate * sigma + 0.5) # radius as in ndimage
        win_size = 2 * r + 1

        # promote both inputs to 4-D (N, C, H, W)
        im1 = _expand(im1)
        im2 = _expand(im2)

        nch = im1.shape[1]

        if im1.shape[2] < win_size or im1.shape[3] < win_size:
            raise ValueError(
                "win_size exceeds image extent. If the input is a multichannel "
                "(color) image, set multichannel=True.")

        if data_range is None:
            # heuristic: values above 1 imply a 0-255 range
            data_range = 255 if im1.max() > 1 else 1

        def filter_func(img): # no padding
            # depthwise Gaussian filtering over the valid region only
            return nn.functional.conv2d(img, self.weight, groups=nch).to(self.dtype)

        # compute (weighted) means
        ux = filter_func(im1)
        uy = filter_func(im2)

        # compute (weighted) variances and covariances
        uxx = filter_func(im1 * im1)
        uyy = filter_func(im2 * im2)
        uxy = filter_func(im1 * im2)
        vx = (uxx - ux * ux)
        vy = (uyy - uy * uy)
        vxy = (uxy - ux * uy)

        R = data_range
        C1 = (K1 * R) ** 2
        C2 = (K2 * R) ** 2

        # SSIM map: ((2*ux*uy+C1)*(2*vxy+C2)) / ((ux^2+uy^2+C1)*(vx+vy+C2))
        A1, A2, B1, B2 = ((2 * ux * uy + C1,
                           2 * vxy + C2,
                           ux ** 2 + uy ** 2 + C1,
                           vx + vy + C2))
        D = B1 * B2
        S = (A1 * A2) / D

        # compute (weighted) mean of ssim
        mssim = S.mean()

        return mssim
backup/deblur/DeepDeblur-PyTorch/src/main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """main file that does everything"""
2
+ from utils import interact
3
+
4
+ from option import args, setup, cleanup
5
+ from data import Data
6
+ from model import Model
7
+ from loss import Loss
8
+ from optim import Optimizer
9
+ from train import Trainer
10
+
11
def main_worker(rank, args):
    """Run the full train/validate/test pipeline in a single process.

    Args:
        rank: process rank (0 for single-process runs).
        args: parsed options from option.py, finalized here via setup().
    """
    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    model.parallelize()
    optimizer = Optimizer(args, model)

    criterion = Loss(args, model=model, optimizer=optimizer)

    trainer = Trainer(args, model, criterion, optimizer, loaders)

    # debug aid: drop into an interactive shell instead of running
    if args.stay:
        interact(local=locals())
        exit()

    # demo mode: run inference once and quit
    if args.demo:
        trainer.evaluate(epoch=args.start_epoch, mode='demo')
        exit()

    # back-fill evaluation records for epochs before start_epoch so that
    # resumed runs keep complete metric curves
    for epoch in range(1, args.start_epoch):
        if args.do_validate:
            if epoch % args.validate_every == 0:
                trainer.fill_evaluation(epoch, 'val')
        if args.do_test:
            if epoch % args.test_every == 0:
                trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.start_epoch, args.end_epoch+1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate:
            if epoch % args.validate_every == 0:
                # reload checkpoint if the trainer is not at this epoch
                if trainer.epoch != epoch:
                    trainer.load(epoch)
                trainer.validate(epoch)

        if args.do_test:
            if epoch % args.test_every == 0:
                if trainer.epoch != epoch:
                    trainer.load(epoch)
                trainer.test(epoch)

        if args.rank == 0 or not args.launched:
            print('')

    trainer.imsaver.join_background()  # wait for async image writers to finish

    cleanup(args)
62
+
63
def main():
    """Entry point: run the worker with the pre-parsed global args."""
    worker_rank = args.rank
    main_worker(worker_rank, args)
65
+
66
+ if __name__ == "__main__":
67
+ main()
backup/deblur/DeepDeblur-PyTorch/src/model/MSResNet.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import common
5
+ from .ResNet import ResNet
6
+
7
+
8
def build_model(args):
    """Factory hook used by model/__init__.py to construct the network."""
    network = MSResNet(args)
    return network
10
+
11
class conv_end(nn.Module):
    """Conv + PixelShuffle tail: upsamples a scale's output so it can be
    concatenated with the next-finer input of the pyramid."""

    def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
        super(conv_end, self).__init__()

        self.uppath = nn.Sequential(
            common.default_conv(in_channels, out_channels, kernel_size),
            nn.PixelShuffle(ratio),
        )

    def forward(self, x):
        out = self.uppath(x)
        return out
24
+
25
class MSResNet(nn.Module):
    """Multi-scale residual network for dynamic scene deblurring.

    Each pyramid scale has its own ResNet body; coarser outputs are
    upsampled via conv_end (conv + PixelShuffle) and concatenated with the
    next-finer input, hence 6 input channels for all but the coarsest scale.
    """
    def __init__(self, args):
        super(MSResNet, self).__init__()

        self.rgb_range = args.rgb_range
        self.mean = self.rgb_range / 2  # mid-range value used for mean shift

        self.n_resblocks = args.n_resblocks
        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size

        self.n_scales = args.n_scales

        # index 0 = finest scale (takes image + upsampled coarse features)
        self.body_models = nn.ModuleList([
            ResNet(args, 3, 3, mean_shift=False),
        ])
        for _ in range(1, self.n_scales):
            self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))

        # one upsampler per non-finest scale; None placeholder at index 0
        self.conv_end_models = nn.ModuleList([None])
        for _ in range(1, self.n_scales):
            # 12 channels -> PixelShuffle(2) -> 3 channels at 2x resolution
            self.conv_end_models += [conv_end(3, 12)]

    def forward(self, input_pyramid):
        # NOTE(review): elements of input_pyramid are overwritten in place
        # with their mean-shifted versions — the caller's list is modified.
        scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse

        for s in scales:
            input_pyramid[s] = input_pyramid[s] - self.mean

        output_pyramid = [None] * self.n_scales

        input_s = input_pyramid[-1]  # start from the coarsest image
        for s in scales: # [2, 1, 0]
            output_pyramid[s] = self.body_models[s](input_s)
            if s > 0:
                # upsample the coarse output and fuse with next finer input
                up_feat = self.conv_end_models[s](output_pyramid[s])
                input_s = torch.cat((input_pyramid[s-1], up_feat), 1)

        # undo the mean shift on every output scale
        for s in scales:
            output_pyramid[s] = output_pyramid[s] + self.mean

        return output_pyramid
backup/deblur/DeepDeblur-PyTorch/src/model/ResNet.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from . import common
4
+
5
def build_model(args):
    """Factory hook used by model/__init__.py to construct the network."""
    network = ResNet(args)
    return network
7
+
8
class ResNet(nn.Module):
    """Plain sequential ResNet: conv -> n_resblocks x ResBlock -> conv.

    Any of n_feats / kernel_size / n_resblocks left as None falls back to
    the corresponding value in args.  When mean_shift is True, rgb_range/2
    is subtracted before the body and added back afterwards.
    """

    def __init__(self, args, in_channels=3, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True):
        super(ResNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.n_feats = n_feats if n_feats is not None else args.n_feats
        self.kernel_size = kernel_size if kernel_size is not None else args.kernel_size
        self.n_resblocks = n_resblocks if n_resblocks is not None else args.n_resblocks

        self.mean_shift = mean_shift
        self.rgb_range = args.rgb_range
        self.mean = self.rgb_range / 2

        layers = [common.default_conv(self.in_channels, self.n_feats, self.kernel_size)]
        layers.extend(
            common.ResBlock(self.n_feats, self.kernel_size)
            for _ in range(self.n_resblocks)
        )
        layers.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size))

        self.body = nn.Sequential(*layers)

    def forward(self, input):
        x = input - self.mean if self.mean_shift else input

        x = self.body(x)

        if self.mean_shift:
            x = x + self.mean

        return x
41
+
backup/deblur/DeepDeblur-PyTorch/src/model/__init__.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from importlib import import_module
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
8
+
9
+ import torch.distributed as dist
10
+ from torch.nn.utils import parameters_to_vector, vector_to_parameters
11
+
12
+ from .discriminator import Discriminator
13
+
14
+ from utils import interact
15
+
16
class Model(nn.Module):
    """Wrapper around the generator (G) and optional discriminator (D).

    Responsible for device placement, (distributed) data parallelism and
    checkpointing under <save_dir>/models/model-<epoch>.pt.
    """
    def __init__(self, args):
        super(Model, self).__init__()

        self.args = args
        self.device = args.device
        self.n_GPUs = args.n_GPUs
        self.save_dir = os.path.join(args.save_dir, 'models')
        os.makedirs(self.save_dir, exist_ok=True)

        # import model/<args.model>.py and build the generator
        module = import_module('model.' + args.model)

        self.model = nn.ModuleDict()
        self.model.G = module.build_model(args)
        # a discriminator only exists when an adversarial loss is configured
        if self.args.loss.lower().find('adv') >= 0:
            self.model.D = Discriminator(self.args)
        else:
            self.model.D = None

        self.to(args.device, dtype=args.dtype, non_blocking=True)
        self.load(args.load_epoch, path=args.pretrained)

    def parallelize(self):
        """Wrap submodels with (Distributed)DataParallel on CUDA devices."""
        if self.args.device_type == 'cuda':
            if self.args.distributed:
                Parallel = DistributedDataParallel
                parallel_args = {
                    "device_ids": [self.args.rank],
                    "output_device": self.args.rank,
                }
            else:
                Parallel = DataParallel
                parallel_args = {
                    'device_ids': list(range(self.n_GPUs)),
                    'output_device': self.args.rank  # always 0
                }

            for model_key in self.model:
                if self.model[model_key] is not None:
                    self.model[model_key] = Parallel(self.model[model_key], **parallel_args)

    def forward(self, input):
        return self.model.G(input)

    def _save_path(self, epoch):
        """Checkpoint file path for the given epoch."""
        model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch))
        return model_path

    def state_dict(self):
        """Collect per-submodel state dicts, unwrapping parallel wrappers."""
        state_dict = {}
        for model_key in self.model:
            if self.model[model_key] is not None:
                parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
                if parallelized:
                    state_dict[model_key] = self.model[model_key].module.state_dict()
                else:
                    state_dict[model_key] = self.model[model_key].state_dict()

        return state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Load per-submodel state dicts, respecting parallel wrappers."""
        for model_key in self.model:
            parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
            if model_key in state_dict:
                if parallelized:
                    self.model[model_key].module.load_state_dict(state_dict[model_key], strict)
                else:
                    self.model[model_key].load_state_dict(state_dict[model_key], strict)

    def save(self, epoch):
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch=None, path=None):
        """Load weights from an explicit path or a saved epoch number.

        epoch < 0 loads the latest checkpoint; epoch == 0 keeps the freshly
        initialized weights.
        """
        if path:
            model_name = path
        elif isinstance(epoch, int):
            if epoch < 0:
                epoch = self.get_last_epoch()
            if epoch == 0:  # epoch 0
                # make sure model parameters are synchronized at initial
                # for multi-node training (not in current implementation)
                # self.synchronize()
                return  # leave model as initialized

            model_name = self._save_path(epoch)
        else:
            raise Exception('no epoch number or model path specified!')

        print('Loading model from {}'.format(model_name))
        state_dict = torch.load(model_name, map_location=self.args.device)
        self.load_state_dict(state_dict)

        return

    def synchronize(self):
        """Broadcast rank-0 parameters to every other distributed process."""
        if self.args.distributed:
            # synchronize model parameters across nodes
            vector = parameters_to_vector(self.parameters())

            dist.broadcast(vector, 0)  # broadcast parameters to other processes
            if self.args.rank != 0:
                vector_to_parameters(vector, self.parameters())

            del vector

        return

    def get_last_epoch(self):
        """Return the newest saved epoch number (0 when no checkpoint exists).

        Epoch numbers are compared numerically: the previous lexicographic
        sort of filenames ranked 'model-99.pt' after 'model-100.pt', so
        resuming picked the wrong (older) checkpoint once epochs reached
        three digits.
        """
        epochs = [
            int(re.findall('\\d+', name)[0])
            for name in os.listdir(self.save_dir)
            if re.findall('\\d+', name)  # ignore files without an epoch number
        ]
        return max(epochs, default=0)

    def print(self):
        print(self.model)

        return
backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/MSResNet.cpython-37.pyc ADDED
Binary file (2.13 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/ResNet.cpython-37.pyc ADDED
Binary file (1.32 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (3.9 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/common.cpython-37.pyc ADDED
Binary file (5.95 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/model/__pycache__/discriminator.cpython-37.pyc ADDED
Binary file (1.42 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/model/common.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1):
    """2-D convolution with 'same'-style padding for odd kernel sizes."""
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=padding, bias=bias, groups=groups)
10
+
11
def default_norm(n_feats):
    """Default normalization layer: channel-wise batch normalization."""
    norm_layer = nn.BatchNorm2d(n_feats)
    return norm_layer
13
+
14
def default_act():
    """Default activation: in-place ReLU."""
    activation = nn.ReLU(inplace=True)
    return activation
16
+
17
def empty_h(x, n_feats):
    '''
    create an empty hidden state

    input
        x: B x T x 3 x H x W

    output
        h: B x C x H/4 x W/4
    '''
    batch = x.size(0)
    height, width = x.size()[-2:]
    # same device/dtype as x, quarter spatial resolution, all zeros
    return x.new_zeros((batch, n_feats, height // 4, width // 4))
30
+
31
class Normalization(nn.Conv2d):
    """Fixed 1x1 convolution that applies (x - mean) / std per RGB channel."""

    def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)):
        super(Normalization, self).__init__(3, 3, kernel_size=1)
        mean_t = torch.Tensor(mean)
        inv_std = torch.Tensor(std).reciprocal()

        # diagonal weight scales each channel by 1/std;
        # bias shifts the result by -mean/std
        self.weight.data = torch.eye(3).mul(inv_std).view(3, 3, 1, 1)
        self.bias.data = torch.Tensor(-mean_t.mul(inv_std))

        # constant transform: exclude from training
        for p in self.parameters():
            p.requires_grad = False
43
+
44
class BasicBlock(nn.Sequential):
    """Convolution optionally followed by normalization and activation.

    `norm` and `act` are layer factories (or falsy to skip the layer).
    """
    def __init__(
        self, in_channels, out_channels, kernel_size, bias=True,
        conv=default_conv, norm=False, act=default_act):

        layers = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if norm:
            layers.append(norm(out_channels))
        if act:
            layers.append(act())

        super(BasicBlock, self).__init__(*layers)
57
+
58
class ResBlock(nn.Module):
    """Residual block: conv -> act -> conv (optional norm) plus identity skip."""

    def __init__(
        self, n_feats, kernel_size, bias=True,
        conv=default_conv, norm=False, act=default_act):

        super(ResBlock, self).__init__()

        layers = []
        for i in range(2):
            layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if norm:
                layers.append(norm(n_feats))
            if act and i == 0:
                # activation only between the two convolutions
                layers.append(act())

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x) + x
78
+
79
class ResBlock_mobile(nn.Module):
    """Residual block built from depthwise-separable convolutions
    (depthwise conv + 1x1 pointwise conv), with optional dropout/norm."""

    def __init__(
        self, n_feats, kernel_size, bias=True,
        conv=default_conv, norm=False, act=default_act, dropout=False):

        super(ResBlock_mobile, self).__init__()

        layers = []
        for i in range(2):
            # depthwise then pointwise, both bias-free
            layers.append(conv(n_feats, n_feats, kernel_size, bias=False, groups=n_feats))
            layers.append(conv(n_feats, n_feats, 1, bias=False))
            if dropout and i == 0:
                layers.append(nn.Dropout2d(dropout))
            if norm:
                layers.append(norm(n_feats))
            if act and i == 0:
                layers.append(act())

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x) + x
101
+
102
class Upsampler(nn.Sequential):
    """Sub-pixel upsampler (conv + PixelShuffle) for scales 2^n or 3."""

    def __init__(
        self, scale, n_feats, bias=True,
        conv=default_conv, norm=False, act=False):

        layers = []
        if (scale & (scale - 1)) == 0:  # power of two: chain x2 stages
            for _ in range(int(math.log(scale, 2))):
                layers.append(conv(n_feats, 4 * n_feats, 3, bias))
                layers.append(nn.PixelShuffle(2))
                if norm:
                    layers.append(norm(n_feats))
                if act:
                    layers.append(act())
        elif scale == 3:
            layers.append(conv(n_feats, 9 * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(3))
            if norm:
                layers.append(norm(n_feats))
            if act:
                layers.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*layers)
123
+
124
+ # Only support 1 / 2
125
class PixelSort(nn.Module):
    """The inverse operation of PixelShuffle
    Reduces the spatial resolution, increasing the number of channels.
    Currently, scale 0.5 is supported only.
    Later, torch.nn.functional.pixel_sort may be implemented.
    Reference:
        http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/pixelshuffle.html#PixelShuffle
        http://pytorch.org/docs/0.3.0/_modules/torch/nn/functional.html#pixel_shuffle
    """
    def __init__(self, upscale_factor=0.5):
        super(PixelSort, self).__init__()
        # kept for reference only; forward() hard-codes the factor-2 case
        self.upscale_factor = upscale_factor

    def forward(self, x):
        # (b, c, h, w) -> (b, 4c, h/2, w/2)
        # NOTE(review): this view/permute ordering does not reproduce
        # F.pixel_unshuffle's element layout; it is some fixed, invertible
        # rearrangement. That appears sufficient here because a learned conv
        # always follows (see Downsampler), but confirm before reusing this
        # as a true PixelShuffle inverse elsewhere.
        b, c, h, w = x.size()
        x = x.view(b, c, 2, 2, h // 2, w // 2)
        x = x.permute(0, 1, 5, 3, 2, 4).contiguous()
        x = x.view(b, 4 * c, h // 2, w // 2)

        return x
145
+
146
class Downsampler(nn.Sequential):
    """Halve spatial resolution: PixelSort followed by a channel-reducing conv.

    Only scale 0.5 is supported.
    """
    def __init__(
        self, scale, n_feats, bias=True,
        conv=default_conv, norm=False, act=False):

        if scale != 0.5:
            raise NotImplementedError

        layers = [
            PixelSort(),
            conv(4 * n_feats, n_feats, 3, bias),
        ]
        if norm:
            layers.append(norm(n_feats))
        if act:
            layers.append(act())

        super(Downsampler, self).__init__(*layers)
161
+
backup/deblur/DeepDeblur-PyTorch/src/model/discriminator.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
class Discriminator(nn.Module):
    """Binary real/fake classifier: strided convs map a 256x256 RGB image
    down to a single logit (B x 1 x 1 x 1)."""

    def __init__(self, args):
        super(Discriminator, self).__init__()

        n_feats = args.n_feats
        kernel_size = args.kernel_size

        def conv(kernel_size, in_channel, n_feats, stride, pad=None):
            # bias-free conv; 'same'-style padding unless overridden
            if pad is None:
                pad = (kernel_size-1)//2
            return nn.Conv2d(in_channel, n_feats, kernel_size, stride=stride, padding=pad, bias=False)

        # trailing comments give the spatial size for a 256x256 input
        self.conv_layers = nn.ModuleList([
            conv(kernel_size, 3, n_feats//2, 1), # 256
            conv(kernel_size, n_feats//2, n_feats//2, 2), # 128
            conv(kernel_size, n_feats//2, n_feats, 1),
            conv(kernel_size, n_feats, n_feats, 2), # 64
            conv(kernel_size, n_feats, n_feats*2, 1),
            conv(kernel_size, n_feats*2, n_feats*2, 4), # 16
            conv(kernel_size, n_feats*2, n_feats*4, 1),
            conv(kernel_size, n_feats*4, n_feats*4, 4), # 4
            conv(kernel_size, n_feats*4, n_feats*8, 1),
            conv(4, n_feats*8, n_feats*8, 4, 0), # 1
        ])

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dense = nn.Conv2d(n_feats*8, 1, 1, bias=False)

    def forward(self, x):
        out = x
        for layer in self.conv_layers:
            out = self.act(layer(out))

        return self.dense(out)
41
+
backup/deblur/DeepDeblur-PyTorch/src/model/structure.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .common import ResBlock, default_conv
4
+
5
def encoder(in_channels, n_feats):
    """RGB / IR feature encoder.

    Three 5x5 convolutions producing, in order:
        B x n_feats  x H   x W
        B x 2n_feats x H/2 x W/2
        B x 3n_feats x H/4 x W/4
    """
    layers = [
        nn.Conv2d(in_channels, 1 * n_feats, 5, stride=1, padding=2),
        nn.Conv2d(1 * n_feats, 2 * n_feats, 5, stride=2, padding=2),
        nn.Conv2d(2 * n_feats, 3 * n_feats, 5, stride=2, padding=2),
    ]
    return nn.Sequential(*layers)
18
+
19
def decoder(out_channels, n_feats):
    """RGB / IR / Depth decoder.

    Two transposed convolutions (2x upsampling each) followed by a 5x5
    projection to out_channels at full resolution:
        B x 2n_feats x H/2 x W/2 -> B x n_feats x H x W -> B x out x H x W
    """
    up_kwargs = {'stride': 2, 'padding': 1, 'output_padding': 1}

    layers = [
        nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, 3, **up_kwargs),
        nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, 3, **up_kwargs),
        nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2),
    ]
    return nn.Sequential(*layers)
32
+
33
+ # def ResNet(n_feats, in_channels=None, out_channels=None):
34
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
    """sequential ResNet

    Builds [entry conv] + n_blocks ResBlocks + [exit conv]; the entry/exit
    convolutions are added only when in_channels / out_channels are given.
    """
    m = []

    if in_channels is not None:
        m += [default_conv(in_channels, n_feats, kernel_size)]

    # BUGFIX: the original `[ResBlock(n_feats, 3)] * n_blocks` repeated ONE
    # module instance, so all residual blocks shared the same weights.
    # Construct a fresh block per iteration instead.
    # NOTE: the blocks intentionally keep kernel size 3 (not kernel_size),
    # matching the original configuration.
    m += [ResBlock(n_feats, 3) for _ in range(n_blocks)]

    if out_channels is not None:
        m += [default_conv(n_feats, out_channels, kernel_size)]

    return nn.Sequential(*m)
56
+
backup/deblur/DeepDeblur-PyTorch/src/optim/__init__.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import torch.optim.lr_scheduler as lrs
4
+
5
+ import os
6
+ from collections import Counter
7
+
8
+ from model import Model
9
+ from utils import interact, Map
10
+
11
class Optimizer(object):
    """Builds per-network optimizers (G and optionally D) with an attached
    learning-rate scheduler, and handles joint save/load under
    <save_dir>/optim/optim-<epoch>.pt.
    """
    def __init__(self, args, model):
        self.args = args

        self.save_dir = os.path.join(self.args.save_dir, 'optim')
        os.makedirs(self.save_dir, exist_ok=True)

        # accept either the wrapper Model or its inner ModuleDict
        if isinstance(model, Model):
            model = model.model

        # set base arguments
        kwargs_optimizer = {
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSPROP':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon

        # scheduler
        if args.scheduler == 'step':
            scheduler_class = lrs.MultiStepLR
            kwargs_scheduler = {
                'milestones': args.milestones,
                'gamma': args.gamma,
            }
        elif args.scheduler == 'plateau':
            scheduler_class = lrs.ReduceLROnPlateau
            kwargs_scheduler = {
                'mode': 'min',
                'factor': args.gamma,
                'patience': 10,
                'verbose': True,
                'threshold': 0,
                'threshold_mode': 'abs',
                'cooldown': 10,
            }

        self.kwargs_optimizer = kwargs_optimizer
        self.scheduler_class = scheduler_class
        self.kwargs_scheduler = kwargs_scheduler

        def _get_optimizer(model):
            """Create an optimizer over the model's trainable parameters
            with a scheduler attached as `.scheduler`."""

            class _Optimizer(optimizer_class):
                def __init__(self, model, args, scheduler_class, kwargs_scheduler):
                    trainable = filter(lambda x: x.requires_grad, model.parameters())
                    super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)

                    self.args = args

                    self._register_scheduler(scheduler_class, kwargs_scheduler)

                def _register_scheduler(self, scheduler_class, kwargs_scheduler):
                    self.scheduler = scheduler_class(self, **kwargs_scheduler)

                def schedule(self, metrics=None):
                    # BUGFIX: the original checked isinstance(self, ReduceLROnPlateau),
                    # which is always False because `self` is the optimizer, not the
                    # scheduler — so plateau scheduling never received its metrics
                    # and ReduceLROnPlateau.step() failed without them.
                    if isinstance(self.scheduler, lrs.ReduceLROnPlateau):
                        self.scheduler.step(metrics)
                    else:
                        self.scheduler.step()

                def get_last_epoch(self):
                    return self.scheduler.last_epoch

                def get_lr(self):
                    return self.param_groups[0]['lr']

                def get_last_lr(self):
                    return self.scheduler.get_last_lr()[0]

                def state_dict(self):
                    # {'state': ..., 'param_groups': ...} plus scheduler state
                    state_dict = super(_Optimizer, self).state_dict()
                    state_dict['scheduler'] = self.scheduler.state_dict()

                    return state_dict

                def load_state_dict(self, state_dict, epoch=None):
                    # optimizer: load 'state' and 'param_groups' only
                    super(_Optimizer, self).load_state_dict(state_dict)
                    # scheduler: should work for plateau or simple resuming
                    self.scheduler.load_state_dict(state_dict['scheduler'])

                    # if the configured milestones/gamma differ from the saved
                    # ones, rebuild the learning rate as if the new schedule
                    # had been applied from the start
                    reschedule = False
                    if isinstance(self.scheduler, lrs.MultiStepLR):
                        if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
                            reschedule = True

                    if reschedule:
                        if epoch is None:
                            if self.scheduler.last_epoch > 1:
                                epoch = self.scheduler.last_epoch
                            else:
                                epoch = self.args.start_epoch - 1

                        # modify the existing scheduler in place
                        self.scheduler.milestones = Counter(self.args.milestones)
                        self.scheduler.gamma = self.args.gamma
                        for i, group in enumerate(self.param_groups):
                            self.param_groups[i]['lr'] = group['initial_lr']  # reset to initial
                            multiplier = 1
                            for milestone in self.scheduler.milestones:
                                if epoch >= milestone:
                                    multiplier *= self.scheduler.gamma

                            self.param_groups[i]['lr'] *= multiplier

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        self.G = _get_optimizer(model.G)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        # D's gradients are managed by the adversarial loss, not here
        self.G.zero_grad()

    def step(self):
        self.G.step()

    def schedule(self, metrics=None):
        """Advance the LR schedule(s) by one epoch."""
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        return self.G.get_last_epoch()

    def get_lr(self):
        return self.G.get_lr()

    def get_last_lr(self):
        return self.G.get_last_lr()

    def state_dict(self):
        """Joint state dict with 'G' and (when present) 'D' entries."""
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        """Restore optimizer/scheduler state saved at `epoch` (0 = fresh start)."""
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)

        elif epoch == 0:
            pass
        else:
            raise NotImplementedError

        return
206
+
backup/deblur/DeepDeblur-PyTorch/src/optim/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (6.44 kB). View file
 
backup/deblur/DeepDeblur-PyTorch/src/optim/warm_multi_step_lr.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from bisect import bisect_right
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+ # MultiStep learning rate scheduler with warm restart
6
class WarmMultiStepLR(_LRScheduler):
    """MultiStepLR with a linear warmup over the first 5 epochs.

    During warmup the learning rate ramps linearly from base_lr/scale up to
    base_lr; afterwards it decays by `gamma` at each milestone epoch.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, scale=1):
        if sorted(milestones) != list(milestones):
            raise ValueError(
                'Milestones should be a list of increasing integers. Got {}',
                milestones
            )

        self.milestones = milestones
        self.gamma = gamma
        self.scale = scale

        self.warmup_epochs = 5
        # per-epoch increment of the warmup multiplier
        self.gradual = (self.scale - 1) / self.warmup_epochs
        super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch < self.warmup_epochs:
            warm_factor = (1 + epoch * self.gradual) / self.scale
            return [base_lr * warm_factor for base_lr in self.base_lrs]
        decay = self.gamma ** bisect_right(self.milestones, epoch)
        return [base_lr * decay for base_lr in self.base_lrs]
backup/deblur/DeepDeblur-PyTorch/src/option.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Option (command-line argument) parsing for training/evaluation."""
# pylint: disable=C0103, C0301
import argparse
import datetime
import os
import re
import shutil
import time

import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn

from utils import interact
from utils import str2bool, int2str

import template

# Training settings
parser = argparse.ArgumentParser(description='Dynamic Scene Deblurring')

# Device specifications
group_device = parser.add_argument_group('Device specs')
group_device.add_argument('--seed', type=int, default=-1, help='random seed')
group_device.add_argument('--num_workers', type=int, default=7, help='the number of dataloader workers')
group_device.add_argument('--device_type', type=str, choices=('cpu', 'cuda'), default='cuda', help='device to run models')
group_device.add_argument('--device_index', type=int, default=0, help='device id to run models')
group_device.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
group_device.add_argument('--distributed', type=str2bool, default=False, help='use DistributedDataParallel instead of DataParallel for better speed')
group_device.add_argument('--launched', type=str2bool, default=False, help='identify if main.py was executed from launch.py. Do not set this to be true using main.py.')

# distributed rendezvous options (used by option.setup when --distributed)
group_device.add_argument('--master_addr', type=str, default='127.0.0.1', help='master address for distributed')
group_device.add_argument('--master_port', type=int2str, default='8023', help='master port for distributed')
group_device.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
group_device.add_argument('--init_method', type=str, default='env://', help='distributed init method URL to discover peers')
group_device.add_argument('--rank', type=int, default=0, help='rank of the distributed process (gpu id). 0 is the master process.')
group_device.add_argument('--world_size', type=int, default=1, help='world_size for distributed training (number of GPUs)')

# Data
group_data = parser.add_argument_group('Data specs')
group_data.add_argument('--data_root', type=str, default='/data/ssd/public/czli/deblur', help='dataset root location')
group_data.add_argument('--dataset', type=str, default=None, help='training/validation/test dataset name, has priority if not None')
group_data.add_argument('--data_train', type=str, default='GOPRO_Large', help='training dataset name')
group_data.add_argument('--data_val', type=str, default=None, help='validation dataset name')
group_data.add_argument('--data_test', type=str, default='GOPRO_Large', help='test dataset name')
group_data.add_argument('--blur_key', type=str, default='blur_gamma', choices=('blur', 'blur_gamma'), help='blur type from camera response function for GOPRO_Large dataset')
group_data.add_argument('--rgb_range', type=int, default=255, help='RGB pixel value ranging from 0')

# Model
group_model = parser.add_argument_group('Model specs')
group_model.add_argument('--model', type=str, default='MSResNet', help='model architecture')
group_model.add_argument('--pretrained', type=str, default='', help='pretrained model location')
group_model.add_argument('--n_scales', type=int, default=3, help='multi-scale deblurring level')
group_model.add_argument('--gaussian_pyramid', type=str2bool, default=True, help='gaussian pyramid input/target')
group_model.add_argument('--n_resblocks', type=int, default=19, help='number of residual blocks per scale')
group_model.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
group_model.add_argument('--kernel_size', type=int, default=5, help='size of conv kernel')
group_model.add_argument('--downsample', type=str, choices=('Gaussian', 'bicubic', 'stride'), default='Gaussian', help='input pyramid generation method')

group_model.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test(single | half)')

# amp
group_amp = parser.add_argument_group('AMP specs')
group_amp.add_argument('--amp', type=str2bool, default=False, help='use automatic mixed precision training')
group_amp.add_argument('--init_scale', type=float, default=1024., help='initial loss scale')

# Training
group_train = parser.add_argument_group('Training specs')
group_train.add_argument('--patch_size', type=int, default=256, help='training patch size')
group_train.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
group_train.add_argument('--split_batch', type=int, default=1, help='split a minibatch into smaller chunks')
group_train.add_argument('--augment', type=str2bool, default=True, help='train with data augmentation')

# Testing
group_test = parser.add_argument_group('Testing specs')
group_test.add_argument('--validate_every', type=int, default=10, help='do validation at every N epochs')
group_test.add_argument('--test_every', type=int, default=10, help='do test at every N epochs')
# group_test.add_argument('--chop', type=str2bool, default=False, help='memory-efficient forward')
# group_test.add_argument('--self_ensemble', type=str2bool, default=False, help='self-ensembled testing')

# Action
group_action = parser.add_argument_group('Source behavior')
group_action.add_argument('--do_train', type=str2bool, default=True, help='do train the model')
group_action.add_argument('--do_validate', type=str2bool, default=True, help='do validate the model')
group_action.add_argument('--do_test', type=str2bool, default=True, help='do test the model')
group_action.add_argument('--demo', type=str2bool, default=False, help='demo')
group_action.add_argument('--demo_input_dir', type=str, default='', help='demo input directory')
group_action.add_argument('--demo_output_dir', type=str, default='', help='demo output directory')

# Optimization
group_optim = parser.add_argument_group('Optimization specs')
group_optim.add_argument('--lr', type=float, default=1e-4, help='learning rate')
group_optim.add_argument('--milestones', type=int, nargs='+', default=[500, 750, 900], help='learning rate decay per N epochs')
group_optim.add_argument('--scheduler', default='step', choices=('step', 'plateau'), help='learning rate scheduler type')
group_optim.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
group_optim.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSProp)')
group_optim.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
group_optim.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
group_optim.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon')
group_optim.add_argument('--weight_decay', type=float, default=0, help='weight decay')

# Loss
group_loss = parser.add_argument_group('Loss specs')
group_loss.add_argument('--loss', type=str, default='1*L1', help='loss function configuration')
group_loss.add_argument('--metric', type=str, default='PSNR,SSIM', help='metric function configuration. ex) None | PSNR | SSIM | PSNR,SSIM')

# Logging
group_log = parser.add_argument_group('Logging specs')
group_log.add_argument('--save_dir', type=str, default='', help='subdirectory to save experiment logs')
# group_log.add_argument('--load_dir', type=str, default='', help='subdirectory to load experiment logs')
group_log.add_argument('--start_epoch', type=int, default=-1, help='(re)starting epoch number')
group_log.add_argument('--end_epoch', type=int, default=1000, help='ending epoch number')
group_log.add_argument('--load_epoch', type=int, default=-1, help='epoch number to load model (start_epoch-1 for training, start_epoch for testing)')
group_log.add_argument('--save_every', type=int, default=10, help='save model/optimizer at every N epochs')
group_log.add_argument('--save_results', type=str, default='part', choices=('none', 'part', 'all'), help='save none/part/all of result images')

# Debugging
group_debug = parser.add_argument_group('Debug specs')
group_debug.add_argument('--stay', type=str2bool, default=False, help='stay at interactive console after trainer initialization')

# template presets (see template.set_template) may override several options above
parser.add_argument('--template', type=str, default='', help='argument template option')
122
+
123
args = parser.parse_args()
template.set_template(args)

args.data_root = os.path.expanduser(args.data_root) # recognize home directory
now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
if args.save_dir == '':
    args.save_dir = now
args.save_dir = os.path.join('../experiment', args.save_dir)
os.makedirs(args.save_dir, exist_ok=True)

if args.start_epoch < 0: # start from scratch or continue from the last epoch
    # check if there are any models saved before
    model_dir = os.path.join(args.save_dir, 'models')
    model_prefix = 'model-'
    if os.path.exists(model_dir):
        # resume after the highest saved checkpoint number
        model_list = [name for name in os.listdir(model_dir) if name.startswith(model_prefix)]
        last_epoch = 0
        for name in model_list:
            epochNumber = int(re.findall('\\d+', name)[0]) # model example name model-100.pt
            if last_epoch < epochNumber:
                last_epoch = epochNumber

        args.start_epoch = last_epoch + 1
    else:
        # train from scratch
        args.start_epoch = 1
elif args.start_epoch == 0:
    # remove existing directory and start over
    if args.rank == 0: # maybe local rank
        shutil.rmtree(args.save_dir, ignore_errors=True)
        os.makedirs(args.save_dir, exist_ok=True)
    args.start_epoch = 1

if args.load_epoch < 0: # load_epoch == start_epoch when doing a post-training test for a specific epoch
    args.load_epoch = args.start_epoch - 1

if args.pretrained:
    if args.start_epoch <= 1:
        args.pretrained = os.path.join('../experiment', args.pretrained)
    else:
        # resuming a run: its own checkpoints take precedence over --pretrained
        print('starting from epoch {}! ignoring pretrained model path..'.format(args.start_epoch))
        args.pretrained = ''

if args.model == 'MSResNet':
    args.gaussian_pyramid = True

argname = os.path.join(args.save_dir, 'args.pt')
argname_txt = os.path.join(args.save_dir, 'args.txt')
if args.start_epoch > 1:
    # load previous arguments and keep the necessary ones same
    if os.path.exists(argname):
        args_old = torch.load(argname)

        load_list = [] # list of arguments that are fixed
        # training
        load_list += ['patch_size']
        load_list += ['batch_size']
        # data format
        load_list += ['rgb_range']
        load_list += ['blur_key']
        # model architecture
        load_list += ['n_scales']
        load_list += ['n_resblocks']
        load_list += ['n_feats']

        for arg_part in load_list:
            vars(args)[arg_part] = vars(args_old)[arg_part]

# --dataset overrides the individual train/val/test dataset names
if args.dataset is not None:
    args.data_train = args.dataset
    args.data_val = args.dataset if args.dataset != 'GOPRO_Large' else None
    args.data_test = args.dataset

if args.data_val is None:
    args.do_validate = False

if args.demo_input_dir:
    args.demo = True

# demo mode: inference only, on an existing experiment directory
if args.demo:
    assert os.path.basename(args.save_dir) != now, 'You should specify pretrained directory by setting --save_dir SAVE_DIR'

    args.data_train = ''
    args.data_val = ''
    args.data_test = ''

    args.do_train = False
    args.do_validate = False
    args.do_test = False

    assert len(args.demo_input_dir) > 0, 'Please specify demo_input_dir!'
    args.demo_input_dir = os.path.expanduser(args.demo_input_dir)
    if args.demo_output_dir:
        args.demo_output_dir = os.path.expanduser(args.demo_output_dir)

    args.save_results = 'all'

if args.amp:
    args.precision = 'single' # model parameters should stay in fp32

if args.seed < 0:
    args.seed = int(time.time())

# save arguments (binary snapshot + appended human-readable log)
if args.rank == 0:
    torch.save(args, argname)
    with open(argname_txt, 'a') as file:
        file.write('execution at {}\n'.format(now))

        for key in args.__dict__:
            file.write(key + ': ' + str(args.__dict__[key]) + '\n')

        file.write('\n')

# device and type
if args.device_type == 'cuda' and not torch.cuda.is_available():
    raise Exception("GPU not available!")

if not args.distributed:
    args.rank = 0
244
+
245
def setup(args):
    """Per-process runtime setup: cudnn, distributed init, device and seeds.

    Mutates and returns ``args`` with ``.device``, ``.dtype`` and
    ``.dtype_eval`` filled in. Must be called once in every (distributed)
    process before building models.
    """
    cudnn.benchmark = True

    if args.distributed:
        # rendezvous info consumed by the 'env://' init method
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port

        args.device_index = args.rank
        args.world_size = args.n_GPUs # consider single-node training

        # initialize the process group
        dist.init_process_group(args.dist_backend, init_method=args.init_method, rank=args.rank, world_size=args.world_size)

    args.device = torch.device(args.device_type, args.device_index)
    args.dtype = torch.float32
    args.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16

    # set seed for processes (distributed: different seed for each process)
    # model parameters are synchronized explicitly at initial
    torch.manual_seed(args.seed)
    if args.device_type == 'cuda':
        torch.cuda.set_device(args.device)
        # NOTE(review): indentation reconstructed — rank-0-only seeding of all
        # CUDA devices appears intended; confirm against the original file.
        if args.rank == 0:
            torch.cuda.manual_seed_all(args.seed)

    return args
271
+
272
def cleanup(args):
    """Tear down the distributed process group if one was initialized."""
    if args.distributed:
        dist.destroy_process_group()
backup/deblur/DeepDeblur-PyTorch/src/prepare.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Stage the GOPRO_Large dataset onto the local SSD (if not already present),
# then return to the directory the script was started from.
workplace=`pwd -P`  # physical (symlink-resolved) path of the starting directory

if [ ! -d "/data/ssd/public/czli/deblur/GOPRO_Large" ]; then
    # NOTE(review): overwrites the -P path with the logical pwd — confirm intended
    workplace=`pwd`
    echo "Copying dataset to ssd"
    cd /research/dept7/czli/deblur/dataset
    mkdir -p /data/ssd/public/czli/deblur/GOPRO_Large
    cp GOPRO_Large.zip /data/ssd/public/czli/deblur/GOPRO_Large
    cd /data/ssd/public/czli/deblur/GOPRO_Large
    echo "Dumping zip in data path:" `pwd`
    # unpack every staged archive in place
    for f in *.zip; do unzip "$f"; done
fi

cd "$workplace"
echo "Workplace:" `pwd`
backup/deblur/DeepDeblur-PyTorch/src/template.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
def set_template(args):
    """Fill in dataset-specific defaults based on the --template option.

    Recognizes the substrings 'gopro' and 'reds' (checked in that order) in
    ``args.template`` and overrides the dataset name, learning-rate milestones
    and final epoch accordingly; any other template string leaves ``args``
    untouched.
    """
    presets = (
        ('gopro', 'GOPRO_Large', [500, 750, 900], 1000),
        ('reds', 'REDS', [100, 150, 180], 200),
    )
    for keyword, dataset, milestones, end_epoch in presets:
        if keyword in args.template:
            args.dataset = dataset
            args.milestones = milestones
            args.end_epoch = end_epoch
            break
backup/deblur/DeepDeblur-PyTorch/src/train.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+ import torch
5
+
6
+ import data.common
7
+ from utils import interact, MultiSaver
8
+
9
+ import torch.cuda.amp as amp
10
+
11
class Trainer():
    """Coordinates the training, validation, test and demo loops.

    Holds the model, criterion (loss + metrics), optimizer and data loaders,
    and takes care of checkpoint save/load, AMP gradient scaling and
    asynchronous result-image saving.
    """

    def __init__(self, args, model, criterion, optimizer, loaders):
        print('===> Initializing trainer')
        self.args = args
        self.mode = 'train' # 'val', 'test'
        self.epoch = args.start_epoch
        self.save_dir = args.save_dir

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loaders = loaders  # mapping: mode name -> DataLoader

        self.do_train = args.do_train
        self.do_validate = args.do_validate
        self.do_test = args.do_test

        self.device = args.device
        self.dtype = args.dtype
        # evaluation may run in fp16 when args.precision == 'half'
        self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16

        # demo output goes to the user-specified directory when given
        if self.args.demo and self.args.demo_output_dir:
            self.result_dir = self.args.demo_output_dir
        else:
            self.result_dir = os.path.join(self.save_dir, 'result')
        os.makedirs(self.result_dir, exist_ok=True)
        print('results are saved in {}'.format(self.result_dir))

        # background (multiprocessing) image writer
        self.imsaver = MultiSaver(self.result_dir)

        # non-master ranks of a launched distributed run suppress console output
        self.is_slave = self.args.launched and self.args.rank != 0

        # loss scaler for mixed precision (acts as identity when amp is off)
        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp
        )

    def save(self, epoch=None):
        """Checkpoint model/optimizer/loss state every args.save_every epochs."""
        epoch = self.epoch if epoch is None else epoch
        if epoch % self.args.save_every == 0:
            if self.mode == 'train':
                self.model.save(epoch)
                self.optimizer.save(epoch)
            self.criterion.save()

        return

    def load(self, epoch=None, pretrained=None):
        """Restore model/optimizer/loss state from a checkpointed epoch."""
        if epoch is None:
            epoch = self.args.load_epoch
        self.epoch = epoch
        self.model.load(epoch, pretrained)
        self.optimizer.load(epoch)
        self.criterion.load(epoch)

        return

    def train(self, epoch):
        """Run one training epoch (forward, scaled backward, step, logging)."""
        self.mode = 'train'
        self.epoch = epoch

        self.model.train()
        self.model.to(dtype=self.dtype)

        self.criterion.train()
        self.criterion.epoch = epoch

        if not self.is_slave:
            print('[Epoch {} / lr {:.2e}]'.format(
                epoch, self.optimizer.get_lr()
            ))

        if self.args.distributed:
            # reshuffle the per-rank shards differently every epoch
            self.loaders[self.mode].sampler.set_epoch(epoch)
        if self.is_slave:
            tq = self.loaders[self.mode]  # no progress bar on worker ranks
        else:
            tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')

        torch.set_grad_enabled(True)
        for idx, batch in enumerate(tq):
            self.optimizer.zero_grad()

            input, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype)

            with amp.autocast(self.args.amp):
                output = self.model(input)
                loss = self.criterion(output, target)

            # scaled backward + optimizer step (no-op scaling when amp is off)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer.G)
            self.scaler.update()

            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())

        self.criterion.normalize()
        if isinstance(tq, tqdm):
            tq.set_description(self.criterion.get_loss_desc())
            tq.display(pos=-1) # overwrite with synchronized loss

        self.criterion.step()
        # plateau schedulers need the epoch loss; step schedulers ignore it
        self.optimizer.schedule(self.criterion.get_last_loss())

        if self.args.rank == 0:
            self.save(epoch)

        return

    def evaluate(self, epoch, mode='val'):
        """Run a no-grad evaluation epoch ('val', 'test' or 'demo'), saving images."""
        self.mode = mode
        self.epoch = epoch

        self.model.eval()
        self.model.to(dtype=self.dtype_eval)

        if mode == 'val':
            self.criterion.validate()
        elif mode == 'test':
            self.criterion.test()
        self.criterion.epoch = epoch

        # flush any images still queued from a previous evaluation
        self.imsaver.join_background()

        if self.is_slave:
            tq = self.loaders[self.mode]
        else:
            tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')

        compute_loss = True
        torch.set_grad_enabled(False)
        for idx, batch in enumerate(tq):
            input, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
            with amp.autocast(self.args.amp):
                output = self.model(input)

            if mode == 'demo': # remove padded part
                pad_width = batch[2]
                output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)

            # a Bool placeholder target marks "no ground truth available"
            # NOTE(review): isinstance check is against torch.BoolTensor (CPU
            # type) — confirm targets are CPU tensors at this point.
            if isinstance(batch[1], torch.BoolTensor):
                compute_loss = False

            if compute_loss:
                self.criterion(output, target)
                if isinstance(tq, tqdm):
                    tq.set_description(self.criterion.get_loss_desc())

            if self.args.save_results != 'none':
                if isinstance(output, (list, tuple)):
                    result = output[0] # select last output in a pyramid
                elif isinstance(output, torch.Tensor):
                    result = output

                names = batch[-1]

                if self.args.save_results == 'part' and compute_loss: # save all when GT not available
                    indices = batch[-2]
                    save_ids = [save_id for save_id, idx in enumerate(indices) if idx % 10 == 0]

                    result = result[save_ids]
                    names = [names[save_id] for save_id in save_ids]

                self.imsaver.save_image(result, names)

        if compute_loss:
            self.criterion.normalize()
            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())
                tq.display(pos=-1) # overwrite with synchronized loss

            self.criterion.step()
            if self.args.rank == 0:
                self.save()

        self.imsaver.end_background()

    def validate(self, epoch):
        """Evaluate on the validation split."""
        self.evaluate(epoch, 'val')
        return

    def test(self, epoch):
        """Evaluate on the test split."""
        self.evaluate(epoch, 'test')
        return

    def fill_evaluation(self, epoch, mode=None, force=False):
        """Re-run evaluation for epochs with missing loss/metric statistics."""
        if epoch <= 0:
            return

        if mode is not None:
            self.mode = mode

        do_eval = force
        if not force:
            loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total'] # should it switch to all loss types?

            metric_missing = False
            for metric_type in self.criterion.metric:
                # NOTE(review): indexes metric_stat with the `mode` argument,
                # not self.mode — this raises if mode is None; confirm callers
                # always pass an explicit mode.
                if epoch not in self.criterion.metric_stat[mode][metric_type]:
                    metric_missing = True

            do_eval = loss_missing or metric_missing

        if do_eval:
            try:
                self.load(epoch)
                self.evaluate(epoch, self.mode)
            except:
                # best-effort: checkpoint for this epoch may be absent
                # print('saved model/optimizer at epoch {} not found!'.format(epoch))
                pass

        return
backup/deblur/DeepDeblur-PyTorch/src/utils.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import readline
2
+ import rlcompleter
3
+ readline.parse_and_bind("tab: complete")
4
+ import code
5
+ import pdb
6
+
7
+ import time
8
+ import argparse
9
+ import os
10
+ import imageio
11
+ import torch
12
+ import torch.multiprocessing as mp
13
+
14
+ # debugging tools
15
def interact(local=None):
    """Interactive console with autocomplete function. Useful for debugging.

    Usage:
        interact(locals())

    Args:
        local (dict, optional): namespace exposed to the console. Defaults to
            the module globals (note: pass the caller's ``locals()`` explicitly
            to inspect the caller's scope).
    """
    if local is None:
        local=dict(globals(), **locals())

    readline.set_completer(rlcompleter.Completer(local).complete)
    code.interact(local=local)
24
+
25
def set_trace(local=None):
    """Debugging with pdb, with tab-completion bound to the given namespace.

    Args:
        local (dict, optional): namespace used for completion. Defaults to the
            module globals (pass the caller's ``locals()`` explicitly when
            needed).
    """
    if local is None:
        local=dict(globals(), **locals())

    pdb.Pdb.complete = rlcompleter.Completer(local).complete
    pdb.set_trace()
33
+
34
+ # timer
35
class Timer():
    """Simple stopwatch with an accumulator.

    Brought from https://github.com/thstkdgus35/EDSR-PyTorch
    """

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Restart the stopwatch."""
        self.t0 = time.time()

    def toc(self):
        """Return seconds elapsed since the last tic()."""
        return time.time() - self.t0

    def hold(self):
        """Add the elapsed time since the last tic() to the accumulator."""
        self.acc = self.acc + self.toc()

    def release(self):
        """Return the accumulated time and clear the accumulator."""
        elapsed, self.acc = self.acc, 0
        return elapsed

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
59
+
60
+
61
+ # argument parser type casting functions
62
def str2bool(val):
    """enable default constant true arguments"""
    # https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    if isinstance(val, bool):
        return val
    lowered = val.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')
73
+
74
def int2str(val):
    """convert int to str for environment variable related arguments"""
    if isinstance(val, str):
        return val
    if isinstance(val, int):
        return str(val)
    raise argparse.ArgumentTypeError('number value expected')
82
+
83
+
84
# image saver using multiprocessing queue
class MultiSaver():
    """Asynchronous image writer backed by a pool of worker processes.

    Images queued via save_image() are written to disk as PNG files by
    background processes so the training/evaluation loop is not blocked
    by disk I/O.
    """

    def __init__(self, result_dir=None):
        self.queue = None       # mp.Queue feeding the workers (None = not started)
        self.process = None     # list of worker processes
        self.result_dir = result_dir

    def begin_background(self):
        """Create the shared queue and spawn the worker processes."""
        self.queue = mp.Queue()

        def t(queue):
            # BUG FIX: the original loop polled queue.empty() and `continue`d,
            # busy-waiting at 100% CPU per worker; a blocking get() waits idly
            # and preserves the (None, None) sentinel shutdown protocol.
            while True:
                img, name = queue.get()
                if name:
                    try:
                        # force a .png extension so imageio picks the PNG writer
                        basename, ext = os.path.splitext(name)
                        if ext != '.png':
                            name = '{}.png'.format(basename)
                        imageio.imwrite(name, img)
                    except Exception as e:
                        print(e)  # best-effort: report and keep serving the queue
                else:
                    return  # (None, None) sentinel -> shut down this worker

        worker = lambda: mp.Process(target=t, args=(self.queue,), daemon=False)
        cpu_count = min(8, mp.cpu_count() - 1)
        self.process = [worker() for _ in range(cpu_count)]
        for p in self.process:
            p.start()

    def end_background(self):
        """Ask every worker to exit by sending one sentinel per process."""
        if self.queue is None:
            return

        for _ in self.process:
            self.queue.put((None, None))

    def join_background(self):
        """Block until all queued images are written and workers have exited.

        Only meaningful after end_background() has queued the sentinels.
        """
        if self.queue is None:
            return

        while not self.queue.empty():
            time.sleep(0.5)

        for p in self.process:
            p.join()

        self.queue = None

    def save_image(self, output, save_names, result_dir=None):
        """Queue a batch of images for asynchronous saving.

        Args:
            output: image tensor, assumed NCHW with values in [0, 255];
                2D/3D inputs are promoted to 4D.
            save_names: iterable of file names (relative to the result dir),
                one per image in the batch.
            result_dir (str, optional): fallback directory used only when the
                instance was constructed without one.
                NOTE(review): the instance-level result_dir takes precedence
                over this argument — confirm that precedence is intended.

        Raises:
            Exception: if no result directory is available at all.
        """
        result_dir = result_dir if self.result_dir is None else self.result_dir
        if result_dir is None:
            raise Exception('no result dir specified!')

        # lazily start the worker pool on first use
        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                print(e)
                return

        # assume NCHW format
        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            # assume image range [0, 255]; round, clamp, and convert to HWC uint8
            output_img = output_img.add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            os.makedirs(save_dir, exist_ok=True)

            self.queue.put((output_img, save_name))

        return
164
+
165
class Map(dict):
    """Dictionary with attribute-style access (m.key <-> m['key']).

    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    Example:
        m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        # re-assign every entry through __setitem__ so __dict__ mirrors the dict
        for mapping in args:
            if isinstance(mapping, dict):
                for key, value in mapping.items():
                    self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def __getattr__(self, attr):
        # missing attributes resolve to None instead of raising AttributeError
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]

    def toDict(self):
        return self.__dict__