hyliu committed on
Commit 8cb1339 · verified · 1 Parent(s): e98653e

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +90 -0
  2. Demosaic/README.md +99 -0
  3. Demosaic/code/LICENSE +21 -0
  4. Demosaic/code/__init__.py +0 -0
  5. Demosaic/code/__pycache__/option.cpython-37.pyc +0 -0
  6. Demosaic/code/__pycache__/template.cpython-37.pyc +0 -0
  7. Demosaic/code/__pycache__/trainer.cpython-37.pyc +0 -0
  8. Demosaic/code/__pycache__/utility.cpython-37.pyc +0 -0
  9. Demosaic/code/data/__init__.py +52 -0
  10. Demosaic/code/data/__pycache__/__init__.cpython-37.pyc +0 -0
  11. Demosaic/code/data/__pycache__/benchmark.cpython-37.pyc +0 -0
  12. Demosaic/code/data/__pycache__/common.cpython-37.pyc +0 -0
  13. Demosaic/code/data/__pycache__/div2k.cpython-37.pyc +0 -0
  14. Demosaic/code/data/__pycache__/srdata.cpython-37.pyc +0 -0
  15. Demosaic/code/data/benchmark.py +25 -0
  16. Demosaic/code/data/common.py +72 -0
  17. Demosaic/code/data/demo.py +39 -0
  18. Demosaic/code/data/div2k.py +32 -0
  19. Demosaic/code/data/div2kjpeg.py +20 -0
  20. Demosaic/code/data/sr291.py +6 -0
  21. Demosaic/code/data/srdata.py +157 -0
  22. Demosaic/code/data/video.py +44 -0
  23. Demosaic/code/dataloader.py +158 -0
  24. Demosaic/code/demo.sb +6 -0
  25. Demosaic/code/lambda_networks/__init__.py +4 -0
  26. Demosaic/code/lambda_networks/__pycache__/__init__.cpython-37.pyc +0 -0
  27. Demosaic/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc +0 -0
  28. Demosaic/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc +0 -0
  29. Demosaic/code/lambda_networks/lambda_networks.py +140 -0
  30. Demosaic/code/lambda_networks/rlambda_networks.py +93 -0
  31. Demosaic/code/loss/__init__.py +173 -0
  32. Demosaic/code/loss/__pycache__/__init__.cpython-37.pyc +0 -0
  33. Demosaic/code/loss/adversarial.py +112 -0
  34. Demosaic/code/loss/discriminator.py +55 -0
  35. Demosaic/code/loss/vgg.py +36 -0
  36. Demosaic/code/main.py +35 -0
  37. Demosaic/code/model/LICENSE +21 -0
  38. Demosaic/code/model/__init__.py +190 -0
  39. Demosaic/code/model/__pycache__/__init__.cpython-37.pyc +0 -0
  40. Demosaic/code/model/__pycache__/attention.cpython-37.pyc +0 -0
  41. Demosaic/code/model/__pycache__/common.cpython-37.pyc +0 -0
  42. Demosaic/code/model/__pycache__/lambdanet.cpython-37.pyc +0 -0
  43. Demosaic/code/model/__pycache__/raftnet.cpython-37.pyc +0 -0
  44. Demosaic/code/model/__pycache__/raftnetlayer.cpython-37.pyc +0 -0
  45. Demosaic/code/model/__pycache__/raftnets.cpython-37.pyc +0 -0
  46. Demosaic/code/model/__pycache__/raftnetsingle.cpython-37.pyc +0 -0
  47. Demosaic/code/model/attention.py +94 -0
  48. Demosaic/code/model/betalambdanet.py +97 -0
  49. Demosaic/code/model/common.py +93 -0
  50. Demosaic/code/model/ddbpn.py +131 -0
.gitattributes CHANGED
@@ -101,3 +101,93 @@ backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0862_11_00/blur_gamma/000
101
  backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0862_11_00/blur_gamma/000077.png filter=lfs diff=lfs merge=lfs -text
102
  backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0868_11_00/blur_gamma/000010.png filter=lfs diff=lfs merge=lfs -text
103
  backup/deblur/experiment/RAFTNET_L1_R6_test/result/GOPR0868_11_00/blur_gamma/000020.png filter=lfs diff=lfs merge=lfs -text
104
+ Demosaic/experiment/test/results-Urban100/img001_x1_DM.png filter=lfs diff=lfs merge=lfs -text
105
+ Demosaic/experiment/test/results-Urban100/img002_x1_DM.png filter=lfs diff=lfs merge=lfs -text
106
+ Demosaic/experiment/test/results-Urban100/img003_x1_DM.png filter=lfs diff=lfs merge=lfs -text
107
+ Demosaic/experiment/test/results-Urban100/img004_x1_DM.png filter=lfs diff=lfs merge=lfs -text
108
+ Demosaic/experiment/test/results-Urban100/img005_x1_DM.png filter=lfs diff=lfs merge=lfs -text
109
+ Demosaic/experiment/test/results-Urban100/img006_x1_DM.png filter=lfs diff=lfs merge=lfs -text
110
+ Demosaic/experiment/test/results-Urban100/img007_x1_DM.png filter=lfs diff=lfs merge=lfs -text
111
+ Demosaic/experiment/test/results-Urban100/img008_x1_DM.png filter=lfs diff=lfs merge=lfs -text
112
+ Demosaic/experiment/test/results-Urban100/img010_x1_DM.png filter=lfs diff=lfs merge=lfs -text
113
+ Demosaic/experiment/test/results-Urban100/img012_x1_DM.png filter=lfs diff=lfs merge=lfs -text
114
+ Demosaic/experiment/test/results-Urban100/img013_x1_DM.png filter=lfs diff=lfs merge=lfs -text
115
+ Demosaic/experiment/test/results-Urban100/img014_x1_DM.png filter=lfs diff=lfs merge=lfs -text
116
+ Demosaic/experiment/test/results-Urban100/img015_x1_DM.png filter=lfs diff=lfs merge=lfs -text
117
+ Demosaic/experiment/test/results-Urban100/img016_x1_DM.png filter=lfs diff=lfs merge=lfs -text
118
+ Demosaic/experiment/test/results-Urban100/img017_x1_DM.png filter=lfs diff=lfs merge=lfs -text
119
+ Demosaic/experiment/test/results-Urban100/img018_x1_DM.png filter=lfs diff=lfs merge=lfs -text
120
+ Demosaic/experiment/test/results-Urban100/img019_x1_DM.png filter=lfs diff=lfs merge=lfs -text
121
+ Demosaic/experiment/test/results-Urban100/img020_x1_DM.png filter=lfs diff=lfs merge=lfs -text
122
+ Demosaic/experiment/test/results-Urban100/img021_x1_DM.png filter=lfs diff=lfs merge=lfs -text
123
+ Demosaic/experiment/test/results-Urban100/img022_x1_DM.png filter=lfs diff=lfs merge=lfs -text
124
+ Demosaic/experiment/test/results-Urban100/img023_x1_DM.png filter=lfs diff=lfs merge=lfs -text
125
+ Demosaic/experiment/test/results-Urban100/img024_x1_DM.png filter=lfs diff=lfs merge=lfs -text
126
+ Demosaic/experiment/test/results-Urban100/img025_x1_DM.png filter=lfs diff=lfs merge=lfs -text
127
+ Demosaic/experiment/test/results-Urban100/img026_x1_DM.png filter=lfs diff=lfs merge=lfs -text
128
+ Demosaic/experiment/test/results-Urban100/img027_x1_DM.png filter=lfs diff=lfs merge=lfs -text
129
+ Demosaic/experiment/test/results-Urban100/img029_x1_DM.png filter=lfs diff=lfs merge=lfs -text
130
+ Demosaic/experiment/test/results-Urban100/img030_x1_DM.png filter=lfs diff=lfs merge=lfs -text
131
+ Demosaic/experiment/test/results-Urban100/img031_x1_DM.png filter=lfs diff=lfs merge=lfs -text
132
+ Demosaic/experiment/test/results-Urban100/img032_x1_DM.png filter=lfs diff=lfs merge=lfs -text
133
+ Demosaic/experiment/test/results-Urban100/img033_x1_DM.png filter=lfs diff=lfs merge=lfs -text
134
+ Demosaic/experiment/test/results-Urban100/img034_x1_DM.png filter=lfs diff=lfs merge=lfs -text
135
+ Demosaic/experiment/test/results-Urban100/img035_x1_DM.png filter=lfs diff=lfs merge=lfs -text
136
+ Demosaic/experiment/test/results-Urban100/img037_x1_DM.png filter=lfs diff=lfs merge=lfs -text
137
+ Demosaic/experiment/test/results-Urban100/img038_x1_DM.png filter=lfs diff=lfs merge=lfs -text
138
+ Demosaic/experiment/test/results-Urban100/img039_x1_DM.png filter=lfs diff=lfs merge=lfs -text
139
+ Demosaic/experiment/test/results-Urban100/img041_x1_DM.png filter=lfs diff=lfs merge=lfs -text
140
+ Demosaic/experiment/test/results-Urban100/img043_x1_DM.png filter=lfs diff=lfs merge=lfs -text
141
+ Demosaic/experiment/test/results-Urban100/img044_x1_DM.png filter=lfs diff=lfs merge=lfs -text
142
+ Demosaic/experiment/test/results-Urban100/img045_x1_DM.png filter=lfs diff=lfs merge=lfs -text
143
+ Demosaic/experiment/test/results-Urban100/img046_x1_DM.png filter=lfs diff=lfs merge=lfs -text
144
+ Demosaic/experiment/test/results-Urban100/img047_x1_DM.png filter=lfs diff=lfs merge=lfs -text
145
+ Demosaic/experiment/test/results-Urban100/img048_x1_DM.png filter=lfs diff=lfs merge=lfs -text
146
+ Demosaic/experiment/test/results-Urban100/img049_x1_DM.png filter=lfs diff=lfs merge=lfs -text
147
+ Demosaic/experiment/test/results-Urban100/img050_x1_DM.png filter=lfs diff=lfs merge=lfs -text
148
+ Demosaic/experiment/test/results-Urban100/img051_x1_DM.png filter=lfs diff=lfs merge=lfs -text
149
+ Demosaic/experiment/test/results-Urban100/img052_x1_DM.png filter=lfs diff=lfs merge=lfs -text
150
+ Demosaic/experiment/test/results-Urban100/img053_x1_DM.png filter=lfs diff=lfs merge=lfs -text
151
+ Demosaic/experiment/test/results-Urban100/img054_x1_DM.png filter=lfs diff=lfs merge=lfs -text
152
+ Demosaic/experiment/test/results-Urban100/img055_x1_DM.png filter=lfs diff=lfs merge=lfs -text
153
+ Demosaic/experiment/test/results-Urban100/img056_x1_DM.png filter=lfs diff=lfs merge=lfs -text
154
+ Demosaic/experiment/test/results-Urban100/img057_x1_DM.png filter=lfs diff=lfs merge=lfs -text
155
+ Demosaic/experiment/test/results-Urban100/img058_x1_DM.png filter=lfs diff=lfs merge=lfs -text
156
+ Demosaic/experiment/test/results-Urban100/img059_x1_DM.png filter=lfs diff=lfs merge=lfs -text
157
+ Demosaic/experiment/test/results-Urban100/img060_x1_DM.png filter=lfs diff=lfs merge=lfs -text
158
+ Demosaic/experiment/test/results-Urban100/img061_x1_DM.png filter=lfs diff=lfs merge=lfs -text
159
+ Demosaic/experiment/test/results-Urban100/img062_x1_DM.png filter=lfs diff=lfs merge=lfs -text
160
+ Demosaic/experiment/test/results-Urban100/img063_x1_DM.png filter=lfs diff=lfs merge=lfs -text
161
+ Demosaic/experiment/test/results-Urban100/img064_x1_DM.png filter=lfs diff=lfs merge=lfs -text
162
+ Demosaic/experiment/test/results-Urban100/img065_x1_DM.png filter=lfs diff=lfs merge=lfs -text
163
+ Demosaic/experiment/test/results-Urban100/img066_x1_DM.png filter=lfs diff=lfs merge=lfs -text
164
+ Demosaic/experiment/test/results-Urban100/img067_x1_DM.png filter=lfs diff=lfs merge=lfs -text
165
+ Demosaic/experiment/test/results-Urban100/img068_x1_DM.png filter=lfs diff=lfs merge=lfs -text
166
+ Demosaic/experiment/test/results-Urban100/img069_x1_DM.png filter=lfs diff=lfs merge=lfs -text
167
+ Demosaic/experiment/test/results-Urban100/img070_x1_DM.png filter=lfs diff=lfs merge=lfs -text
168
+ Demosaic/experiment/test/results-Urban100/img071_x1_DM.png filter=lfs diff=lfs merge=lfs -text
169
+ Demosaic/experiment/test/results-Urban100/img072_x1_DM.png filter=lfs diff=lfs merge=lfs -text
170
+ Demosaic/experiment/test/results-Urban100/img073_x1_DM.png filter=lfs diff=lfs merge=lfs -text
171
+ Demosaic/experiment/test/results-Urban100/img074_x1_DM.png filter=lfs diff=lfs merge=lfs -text
172
+ Demosaic/experiment/test/results-Urban100/img075_x1_DM.png filter=lfs diff=lfs merge=lfs -text
173
+ Demosaic/experiment/test/results-Urban100/img076_x1_DM.png filter=lfs diff=lfs merge=lfs -text
174
+ Demosaic/experiment/test/results-Urban100/img077_x1_DM.png filter=lfs diff=lfs merge=lfs -text
175
+ Demosaic/experiment/test/results-Urban100/img078_x1_DM.png filter=lfs diff=lfs merge=lfs -text
176
+ Demosaic/experiment/test/results-Urban100/img079_x1_DM.png filter=lfs diff=lfs merge=lfs -text
177
+ Demosaic/experiment/test/results-Urban100/img081_x1_DM.png filter=lfs diff=lfs merge=lfs -text
178
+ Demosaic/experiment/test/results-Urban100/img082_x1_DM.png filter=lfs diff=lfs merge=lfs -text
179
+ Demosaic/experiment/test/results-Urban100/img083_x1_DM.png filter=lfs diff=lfs merge=lfs -text
180
+ Demosaic/experiment/test/results-Urban100/img084_x1_DM.png filter=lfs diff=lfs merge=lfs -text
181
+ Demosaic/experiment/test/results-Urban100/img087_x1_DM.png filter=lfs diff=lfs merge=lfs -text
182
+ Demosaic/experiment/test/results-Urban100/img088_x1_DM.png filter=lfs diff=lfs merge=lfs -text
183
+ Demosaic/experiment/test/results-Urban100/img089_x1_DM.png filter=lfs diff=lfs merge=lfs -text
184
+ Demosaic/experiment/test/results-Urban100/img091_x1_DM.png filter=lfs diff=lfs merge=lfs -text
185
+ Demosaic/experiment/test/results-Urban100/img092_x1_DM.png filter=lfs diff=lfs merge=lfs -text
186
+ Demosaic/experiment/test/results-Urban100/img093_x1_DM.png filter=lfs diff=lfs merge=lfs -text
187
+ Demosaic/experiment/test/results-Urban100/img094_x1_DM.png filter=lfs diff=lfs merge=lfs -text
188
+ Demosaic/experiment/test/results-Urban100/img095_x1_DM.png filter=lfs diff=lfs merge=lfs -text
189
+ Demosaic/experiment/test/results-Urban100/img096_x1_DM.png filter=lfs diff=lfs merge=lfs -text
190
+ Demosaic/experiment/test/results-Urban100/img097_x1_DM.png filter=lfs diff=lfs merge=lfs -text
191
+ Demosaic/experiment/test/results-Urban100/img098_x1_DM.png filter=lfs diff=lfs merge=lfs -text
192
+ Demosaic/experiment/test/results-Urban100/img099_x1_DM.png filter=lfs diff=lfs merge=lfs -text
193
+ Demosaic/experiment/test/results-Urban100/img100_x1_DM.png filter=lfs diff=lfs merge=lfs -text
Demosaic/README.md ADDED
@@ -0,0 +1,99 @@
+ # Pyramid Attention for Image Restoration
+ This repository is for PANet and PA-EDSR introduced in the following paper:
+
+ [Yiqun Mei](http://yiqunm2.web.illinois.edu/), [Yuchen Fan](https://scholar.google.com/citations?user=BlfdYL0AAAAJ&hl=en), [Yulun Zhang](http://yulunzhang.com/), [Jiahui Yu](https://jiahuiyu.com/), [Yuqian Zhou](https://yzhouas.github.io/), [Ding Liu](https://scholar.google.com/citations?user=PGtHUI0AAAAJ&hl=en), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Thomas S. Huang](http://ifp-uiuc.github.io/) and [Honghui Shi](https://www.humphreyshi.com/), "Pyramid Attention for Image Restoration", [[Arxiv]](https://arxiv.org/abs/2004.13824)
+
+ The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) & [RNAN](https://github.com/yulunzhang/RNAN) and tested on Ubuntu 18.04 (Python 3.6, PyTorch 1.1) with Titan X/1080Ti/V100 GPUs.
+
+ ## Contents
+ 1. [Train](#train)
+ 2. [Test](#test)
+ 3. [Results](#results)
+ 4. [Citation](#citation)
+ 5. [Acknowledgements](#acknowledgements)
+
+ ## Train
+ ### Prepare training data
+
+ 1. Download the DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
+
+ 2. Specify '--dir_data' based on the HR and LR image paths.
+
+ 3. Organize the training data like:
+ ```bash
+ DIV2K/
+ ├── DIV2K_train_HR
+ ├── DIV2K_train_LR_bicubic
+ │   └── X1
+ ├── DIV2K_valid_HR
+ └── DIV2K_valid_LR_bicubic
+     └── X1
+ ```
+ For more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
+
+ ### Begin to train
+
+ 1. (optional) All the pretrained models and visual results can be downloaded from [Google Drive](https://drive.google.com/open?id=1q9iUzqYX0fVRzDu4J6fvSPRosgOZoJJE).
+
+ 2. Cd to 'PANet-PyTorch/[Task]/code' and run the following scripts to train the models.
+
+ **You can use the scripts in 'demo.sb' to train and test the models for our paper.**
+
+ ```bash
+ # Example Usage:
+ python main.py --n_GPUs 1 --lr 1e-4 --decay 200-400-600-800 --epoch 1000 --batch_size 16 --n_resblocks 80 --save_models --model PANET --scale 1 --patch_size 48 --save PANET_DEMOSAIC --n_feats 64 --data_train DIV2K --chop
+ ```
+
+ ## Test
+ ### Quick start
+
+ 1. Cd to 'PANet-PyTorch/[Task]/code' and run the following scripts.
+
+ **You can use the scripts in 'demo.sb' to produce the results for our paper.**
+
+ ```bash
+ # No self-ensemble; use the different test sets to reproduce the results in the paper.
+ # Example Usage:
+ python main.py --model PANET --save_results --n_GPUs 1 --chop --n_resblocks 80 --n_feats 64 --data_test McM+Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../model_best.pt --test_only
+ ```
+
+ ### The whole test pipeline
+ 1. Prepare the test data and organize it like:
+ ```bash
+ benchmark/
+ ├── testset1
+ │   ├── HR
+ │   └── LR_bicubic
+ │       └── X1
+ │           └── ..
+ ├── testset2
+ ```
+
+ 2. Conduct image demosaicing.
+
+    See **Quick start**.
+ 3. Evaluate the results.
+
+    Run 'Evaluate_PSNR_SSIM.m' to obtain the PSNR/SSIM values reported in the paper (a rough Python sketch is also given below).
+
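The reported PSNR/SSIM numbers come from the MATLAB script above. For a quick sanity check without MATLAB, here is a rough Python sketch; it assumes imageio and scikit-image (>= 0.19, for `channel_axis`) are installed, uses placeholder folder paths, and does not reproduce the MATLAB script's exact protocol (e.g. Y-channel evaluation or border cropping), so treat its numbers as approximate.

```python
# Rough RGB PSNR/SSIM check (NOT the script used for the paper's numbers).
# Folder paths are placeholders; adjust them to your layout.
import glob, os
import imageio
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

result_dir = '../experiment/test/results-Urban100'  # hypothetical results folder
gt_dir = '../../benchmark/Urban100/HR'              # hypothetical ground-truth folder

psnrs, ssims = [], []
for res_path in sorted(glob.glob(os.path.join(result_dir, '*_x1_DM.png'))):
    name = os.path.basename(res_path).split('_x1_DM')[0]
    sr = imageio.imread(res_path)
    hr = imageio.imread(os.path.join(gt_dir, name + '.png'))
    psnrs.append(peak_signal_noise_ratio(hr, sr, data_range=255))
    ssims.append(structural_similarity(hr, sr, channel_axis=-1, data_range=255))

print('PSNR: {:.2f} dB  SSIM: {:.4f}'.format(sum(psnrs) / len(psnrs), sum(ssims) / len(ssims)))
```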
+ ## Citation
+ If you find the code helpful in your research or work, please cite the following papers.
+ ```
+ @article{mei2020pyramid,
+   title={Pyramid Attention Networks for Image Restoration},
+   author={Mei, Yiqun and Fan, Yuchen and Zhang, Yulun and Yu, Jiahui and Zhou, Yuqian and Liu, Ding and Fu, Yun and Huang, Thomas S and Shi, Honghui},
+   journal={arXiv preprint arXiv:2004.13824},
+   year={2020}
+ }
+ @InProceedings{Lim_2017_CVPR_Workshops,
+   author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
+   title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
+   booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
+   month = {July},
+   year = {2017}
+ }
+ ```
+ ## Acknowledgements
+ This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [RNAN](https://github.com/yulunzhang/RNAN) and [generative-inpainting-pytorch](https://github.com/daa233/generative-inpainting-pytorch). We thank the authors for sharing their codes.
Demosaic/code/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Sanghyun Son
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Demosaic/code/__init__.py ADDED
File without changes
Demosaic/code/__pycache__/option.cpython-37.pyc ADDED
Binary file (4.92 kB). View file
 
Demosaic/code/__pycache__/template.cpython-37.pyc ADDED
Binary file (999 Bytes). View file
 
Demosaic/code/__pycache__/trainer.cpython-37.pyc ADDED
Binary file (4.87 kB). View file
 
Demosaic/code/__pycache__/utility.cpython-37.pyc ADDED
Binary file (9.11 kB). View file
 
Demosaic/code/data/__init__.py ADDED
@@ -0,0 +1,52 @@
1
+ from importlib import import_module
2
+ #from dataloader import MSDataLoader
3
+ from torch.utils.data import dataloader
4
+ from torch.utils.data import ConcatDataset
5
+
6
+ # This is a simple wrapper function for ConcatDataset
7
+ class MyConcatDataset(ConcatDataset):
8
+ def __init__(self, datasets):
9
+ super(MyConcatDataset, self).__init__(datasets)
10
+ self.train = datasets[0].train
11
+
12
+ def set_scale(self, idx_scale):
13
+ for d in self.datasets:
14
+ if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
15
+
16
+ class Data:
17
+ def __init__(self, args):
18
+ self.loader_train = None
19
+ if not args.test_only:
20
+ datasets = []
21
+ for d in args.data_train:
22
+ module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
23
+ m = import_module('data.' + module_name.lower())
24
+ datasets.append(getattr(m, module_name)(args, name=d))
25
+
26
+ self.loader_train = dataloader.DataLoader(
27
+ MyConcatDataset(datasets),
28
+ batch_size=args.batch_size,
29
+ shuffle=True,
30
+ pin_memory=not args.cpu,
31
+ num_workers=args.n_threads,
32
+ )
33
+
34
+ self.loader_test = []
35
+ for d in args.data_test:
36
+ if d in ['CBSD68','Kodak24','McM','Set5', 'Set14', 'B100', 'Urban100']:
37
+ m = import_module('data.benchmark')
38
+ testset = getattr(m, 'Benchmark')(args, train=False, name=d)
39
+ else:
40
+ module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
41
+ m = import_module('data.' + module_name.lower())
42
+ testset = getattr(m, module_name)(args, train=False, name=d)
43
+
44
+ self.loader_test.append(
45
+ dataloader.DataLoader(
46
+ testset,
47
+ batch_size=1,
48
+ shuffle=False,
49
+ pin_memory=not args.cpu,
50
+ num_workers=args.n_threads,
51
+ )
52
+ )
Demosaic/code/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.79 kB). View file
 
Demosaic/code/data/__pycache__/benchmark.cpython-37.pyc ADDED
Binary file (1.07 kB). View file
 
Demosaic/code/data/__pycache__/common.cpython-37.pyc ADDED
Binary file (2.74 kB). View file
 
Demosaic/code/data/__pycache__/div2k.cpython-37.pyc ADDED
Binary file (1.75 kB). View file
 
Demosaic/code/data/__pycache__/srdata.cpython-37.pyc ADDED
Binary file (4.83 kB). View file
 
Demosaic/code/data/benchmark.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+
3
+ from data import common
4
+ from data import srdata
5
+
6
+ import numpy as np
7
+
8
+ import torch
9
+ import torch.utils.data as data
10
+
11
+ class Benchmark(srdata.SRData):
12
+ def __init__(self, args, name='', train=True, benchmark=True):
13
+ super(Benchmark, self).__init__(
14
+ args, name=name, train=train, benchmark=True
15
+ )
16
+
17
+ def _set_filesystem(self, dir_data):
18
+ self.apath = os.path.join(dir_data, 'benchmark', self.name)
19
+ self.dir_hr = os.path.join(self.apath, 'HR')
20
+ if self.input_large:
21
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22
+ else:
23
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24
+ self.ext = ('', '.png')
25
+
Demosaic/code/data/common.py ADDED
@@ -0,0 +1,72 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import skimage.color as sc
5
+
6
+ import torch
7
+
8
+ def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False):
9
+ ih, iw = args[0].shape[:2]
10
+
11
+ if not input_large:
12
+ p = 1 if multi else 1
13
+ tp = p * patch_size
14
+ ip = tp // 1
15
+ else:
16
+ tp = patch_size
17
+ ip = patch_size
18
+
19
+ ix = random.randrange(0, iw - ip + 1)
20
+ iy = random.randrange(0, ih - ip + 1)
21
+
22
+ if not input_large:
23
+ tx, ty = 1 * ix, 1 * iy
24
+ else:
25
+ tx, ty = ix, iy
26
+
27
+ ret = [
28
+ args[0][iy:iy + ip, ix:ix + ip, :],
29
+ *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
30
+ ]
31
+
32
+ return ret
33
+
34
+ def set_channel(*args, n_channels=3):
35
+ def _set_channel(img):
36
+ if img.ndim == 2:
37
+ img = np.expand_dims(img, axis=2)
38
+
39
+ c = img.shape[2]
40
+ if n_channels == 1 and c == 3:
41
+ img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
42
+ elif n_channels == 3 and c == 1:
43
+ img = np.concatenate([img] * n_channels, 2)
44
+
45
+ return img
46
+
47
+ return [_set_channel(a) for a in args]
48
+
49
+ def np2Tensor(*args, rgb_range=255):
50
+ def _np2Tensor(img):
51
+ np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
52
+ tensor = torch.from_numpy(np_transpose).float()
53
+ tensor.mul_(rgb_range / 255)
54
+
55
+ return tensor
56
+
57
+ return [_np2Tensor(a) for a in args]
58
+
59
+ def augment(*args, hflip=True, rot=True):
60
+ hflip = hflip and random.random() < 0.5
61
+ vflip = rot and random.random() < 0.5
62
+ rot90 = rot and random.random() < 0.5
63
+
64
+ def _augment(img):
65
+ if hflip: img = img[:, ::-1, :]
66
+ if vflip: img = img[::-1, :, :]
67
+ if rot90: img = img.transpose(1, 0, 2)
68
+
69
+ return img
70
+
71
+ return [_augment(a) for a in args]
72
+
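As a quick illustration of how the helpers above compose during training, the hedged sketch below runs them on random arrays; the sizes are arbitrary, and it assumes it is executed from Demosaic/code so that the `data` package is importable.

```python
# Illustrative use of data/common.py on random data (array sizes are arbitrary).
# Run from Demosaic/code so that `data.common` can be imported.
import numpy as np
from data import common

lr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
hr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # scale 1: same size as lr

lr_p, hr_p = common.get_patch(lr, hr, patch_size=48, scale=1)  # aligned random 48x48 crop
lr_p, hr_p = common.augment(lr_p, hr_p)                        # random flips / 90-degree rotation
lr_p, hr_p = common.set_channel(lr_p, hr_p, n_channels=3)      # enforce 3 channels
lr_t, hr_t = common.np2Tensor(lr_p, hr_p, rgb_range=255)       # HWC uint8 -> CHW float tensor
print(lr_t.shape, hr_t.shape)  # torch.Size([3, 48, 48]) for both
```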
Demosaic/code/data/demo.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+
3
+ from data import common
4
+
5
+ import numpy as np
6
+ import imageio
7
+
8
+ import torch
9
+ import torch.utils.data as data
10
+
11
+ class Demo(data.Dataset):
12
+ def __init__(self, args, name='Demo', train=False, benchmark=False):
13
+ self.args = args
14
+ self.name = name
15
+ self.scale = args.scale
16
+ self.idx_scale = 0
17
+ self.train = False
18
+ self.benchmark = benchmark
19
+
20
+ self.filelist = []
21
+ for f in os.listdir(args.dir_demo):
22
+ if f.find('.png') >= 0 or f.find('.jp') >= 0:
23
+ self.filelist.append(os.path.join(args.dir_demo, f))
24
+ self.filelist.sort()
25
+
26
+ def __getitem__(self, idx):
27
+ filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28
+ lr = imageio.imread(self.filelist[idx])
29
+ lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30
+ lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31
+
32
+ return lr_t, -1, filename
33
+
34
+ def __len__(self):
35
+ return len(self.filelist)
36
+
37
+ def set_scale(self, idx_scale):
38
+ self.idx_scale = idx_scale
39
+
Demosaic/code/data/div2k.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ from data import srdata
3
+
4
+ class DIV2K(srdata.SRData):
5
+ def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6
+ data_range = [r.split('-') for r in args.data_range.split('/')]
7
+ if train:
8
+ data_range = data_range[0]
9
+ else:
10
+ if args.test_only and len(data_range) == 1:
11
+ data_range = data_range[0]
12
+ else:
13
+ data_range = data_range[1]
14
+
15
+ self.begin, self.end = list(map(lambda x: int(x), data_range))
16
+ super(DIV2K, self).__init__(
17
+ args, name=name, train=train, benchmark=benchmark
18
+ )
19
+
20
+ def _scan(self):
21
+ names_hr, names_lr = super(DIV2K, self)._scan()
22
+ names_hr = names_hr[self.begin - 1:self.end]
23
+ names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24
+
25
+ return names_hr, names_lr
26
+
27
+ def _set_filesystem(self, dir_data):
28
+ super(DIV2K, self)._set_filesystem(dir_data)
29
+ self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30
+ self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31
+ if self.input_large: self.dir_lr += 'L'
32
+
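The first lines of DIV2K.__init__ above decide which image indices are used for training and testing. The standalone sketch below mirrors that parsing so the behaviour is easy to see; the '1-800/801-810' value is only the conventional EDSR-style split, not something this repository fixes.

```python
# Standalone mirror of DIV2K.__init__'s --data_range handling (illustrative only).
args_data_range = '1-800/801-810'   # "<train begin-end>/<test begin-end>"
train, test_only = False, True      # pretend we are building the test split

data_range = [r.split('-') for r in args_data_range.split('/')]  # [['1', '800'], ['801', '810']]
if train:
    data_range = data_range[0]
elif test_only and len(data_range) == 1:
    data_range = data_range[0]       # a single range serves both roles in test-only mode
else:
    data_range = data_range[1]
begin, end = map(int, data_range)
print(begin, end)  # 801 810 -> DIV2K images 0801-0810 are used for evaluation
```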
Demosaic/code/data/div2kjpeg.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+ from data import srdata
3
+ from data import div2k
4
+
5
+ class DIV2KJPEG(div2k.DIV2K):
6
+ def __init__(self, args, name='', train=True, benchmark=False):
7
+ self.q_factor = int(name.replace('DIV2K-Q', ''))
8
+ super(DIV2KJPEG, self).__init__(
9
+ args, name=name, train=train, benchmark=benchmark
10
+ )
11
+
12
+ def _set_filesystem(self, dir_data):
13
+ self.apath = os.path.join(dir_data, 'DIV2K')
14
+ self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15
+ self.dir_lr = os.path.join(
16
+ self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17
+ )
18
+ if self.input_large: self.dir_lr += 'L'
19
+ self.ext = ('.png', '.jpg')
20
+
Demosaic/code/data/sr291.py ADDED
@@ -0,0 +1,6 @@
1
+ from data import srdata
2
+
3
+ class SR291(srdata.SRData):
4
+ def __init__(self, args, name='SR291', train=True, benchmark=False):
5
+ super(SR291, self).__init__(args, name=name)
6
+
Demosaic/code/data/srdata.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import glob
3
+ import random
4
+ import pickle
5
+
6
+ from data import common
7
+
8
+ import numpy as np
9
+ import imageio
10
+ import torch
11
+ import torch.utils.data as data
12
+
13
+ class SRData(data.Dataset):
14
+ def __init__(self, args, name='', train=True, benchmark=False):
15
+ self.args = args
16
+ self.name = name
17
+ self.train = train
18
+ self.split = 'train' if train else 'test'
19
+ self.do_eval = True
20
+ self.benchmark = benchmark
21
+ self.input_large = (args.model == 'VDSR')
22
+ self.scale = args.scale
23
+ self.idx_scale = 0
24
+
25
+ self._set_filesystem(args.dir_data)
26
+ if args.ext.find('img') < 0:
27
+ path_bin = os.path.join(self.apath, 'bin')
28
+ os.makedirs(path_bin, exist_ok=True)
29
+
30
+ list_hr, list_lr = self._scan()
31
+ if args.ext.find('img') >= 0 or benchmark:
32
+ self.images_hr, self.images_lr = list_hr, list_lr
33
+ elif args.ext.find('sep') >= 0:
34
+ os.makedirs(
35
+ self.dir_hr.replace(self.apath, path_bin),
36
+ exist_ok=True
37
+ )
38
+ for s in self.scale:
39
+ os.makedirs(
40
+ os.path.join(
41
+ self.dir_lr.replace(self.apath, path_bin),
42
+ 'X{}'.format(s)
43
+ ),
44
+ exist_ok=True
45
+ )
46
+
47
+ self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48
+ for h in list_hr:
49
+ b = h.replace(self.apath, path_bin)
50
+ b = b.replace(self.ext[0], '.pt')
51
+ self.images_hr.append(b)
52
+ self._check_and_load(args.ext, h, b, verbose=True)
53
+ for i, ll in enumerate(list_lr):
54
+ for l in ll:
55
+ b = l.replace(self.apath, path_bin)
56
+ b = b.replace(self.ext[1], '.pt')
57
+ self.images_lr[i].append(b)
58
+ self._check_and_load(args.ext, l, b, verbose=True)
59
+ if train:
60
+ n_patches = args.batch_size * args.test_every
61
+ n_images = len(args.data_train) * len(self.images_hr)
62
+ if n_images == 0:
63
+ self.repeat = 0
64
+ else:
65
+ self.repeat = max(n_patches // n_images, 1)
66
+
67
+ # Below functions are used to prepare images
68
+ def _scan(self):
69
+ names_hr = sorted(
70
+ glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71
+ )
72
+ names_lr = [[] for _ in self.scale]
73
+ for f in names_hr:
74
+ filename, _ = os.path.splitext(os.path.basename(f))
75
+ for si, s in enumerate(self.scale):
76
+ names_lr[si].append(os.path.join(
77
+ self.dir_lr, 'X{}/{}x{}{}'.format(
78
+ s, filename, s, self.ext[1]
79
+ )
80
+ ))
81
+
82
+ return names_hr, names_lr
83
+
84
+ def _set_filesystem(self, dir_data):
85
+ self.apath = os.path.join(dir_data, self.name)
86
+ self.dir_hr = os.path.join(self.apath, 'HR')
87
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
88
+ if self.input_large: self.dir_lr += 'L'
89
+ self.ext = ('.png', '.png')
90
+
91
+ def _check_and_load(self, ext, img, f, verbose=True):
92
+ if not os.path.isfile(f) or ext.find('reset') >= 0:
93
+ if verbose:
94
+ print('Making a binary: {}'.format(f))
95
+ with open(f, 'wb') as _f:
96
+ pickle.dump(imageio.imread(img), _f)
97
+
98
+ def __getitem__(self, idx):
99
+ lr, hr, filename = self._load_file(idx)
100
+ pair = self.get_patch(lr, hr)
101
+ pair = common.set_channel(*pair, n_channels=self.args.n_colors)
102
+ pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
103
+
104
+ return pair_t[0], pair_t[1], filename
105
+
106
+ def __len__(self):
107
+ if self.train:
108
+ return len(self.images_hr) * self.repeat
109
+ else:
110
+ return len(self.images_hr)
111
+
112
+ def _get_index(self, idx):
113
+ if self.train:
114
+ return idx % len(self.images_hr)
115
+ else:
116
+ return idx
117
+
118
+ def _load_file(self, idx):
119
+ idx = self._get_index(idx)
120
+ f_hr = self.images_hr[idx]
121
+ f_lr = self.images_lr[self.idx_scale][idx]
122
+
123
+ filename, _ = os.path.splitext(os.path.basename(f_hr))
124
+ if self.args.ext == 'img' or self.benchmark:
125
+ hr = imageio.imread(f_hr)
126
+ lr = imageio.imread(f_lr)
127
+ elif self.args.ext.find('sep') >= 0:
128
+ with open(f_hr, 'rb') as _f:
129
+ hr = pickle.load(_f)
130
+ with open(f_lr, 'rb') as _f:
131
+ lr = pickle.load(_f)
132
+
133
+ return lr, hr, filename
134
+
135
+ def get_patch(self, lr, hr):
136
+ scale = self.scale[self.idx_scale]
137
+ if self.train:
138
+ lr, hr = common.get_patch(
139
+ lr, hr,
140
+ patch_size=self.args.patch_size,
141
+ scale=scale,
142
+ multi=(len(self.scale) > 1),
143
+ input_large=self.input_large
144
+ )
145
+ if not self.args.no_augment: lr, hr = common.augment(lr, hr)
146
+ else:
147
+ ih, iw = lr.shape[:2]
148
+ hr = hr[0:ih * scale, 0:iw * scale]
149
+
150
+ return lr, hr
151
+
152
+ def set_scale(self, idx_scale):
153
+ if not self.input_large:
154
+ self.idx_scale = idx_scale
155
+ else:
156
+ self.idx_scale = random.randint(0, len(self.scale) - 1)
157
+
Demosaic/code/data/video.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+
3
+ from data import common
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import imageio
8
+
9
+ import torch
10
+ import torch.utils.data as data
11
+
12
+ class Video(data.Dataset):
13
+ def __init__(self, args, name='Video', train=False, benchmark=False):
14
+ self.args = args
15
+ self.name = name
16
+ self.scale = args.scale
17
+ self.idx_scale = 0
18
+ self.train = False
19
+ self.do_eval = False
20
+ self.benchmark = benchmark
21
+
22
+ self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23
+ self.vidcap = cv2.VideoCapture(args.dir_demo)
24
+ self.n_frames = 0
25
+ self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26
+
27
+ def __getitem__(self, idx):
28
+ success, lr = self.vidcap.read()
29
+ if success:
30
+ self.n_frames += 1
31
+ lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32
+ lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33
+
34
+ return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35
+ else:
36
+ self.vidcap.release()
37
+ return None
38
+
39
+ def __len__(self):
40
+ return self.total_frames
41
+
42
+ def set_scale(self, idx_scale):
43
+ self.idx_scale = idx_scale
44
+
Demosaic/code/dataloader.py ADDED
@@ -0,0 +1,158 @@
1
+ import sys  # needed for sys.exc_info() in the worker loop below
+ import threading
2
+ import random
3
+
4
+ import torch
5
+ import torch.multiprocessing as multiprocessing
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data import SequentialSampler
8
+ from torch.utils.data import RandomSampler
9
+ from torch.utils.data import BatchSampler
10
+ from torch.utils.data import _utils
11
+ from torch.utils.data.dataloader import _DataLoaderIter
12
+
13
+ from torch.utils.data._utils import collate
14
+ from torch.utils.data._utils import signal_handling
15
+ from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16
+ from torch.utils.data._utils import ExceptionWrapper
17
+ from torch.utils.data._utils import IS_WINDOWS
18
+ from torch.utils.data._utils.worker import ManagerWatchdog
19
+
20
+ from torch._six import queue
21
+
22
+ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23
+ try:
24
+ collate._use_shared_memory = True
25
+ signal_handling._set_worker_signal_handlers()
26
+
27
+ torch.set_num_threads(1)
28
+ random.seed(seed)
29
+ torch.manual_seed(seed)
30
+
31
+ data_queue.cancel_join_thread()
32
+
33
+ if init_fn is not None:
34
+ init_fn(worker_id)
35
+
36
+ watchdog = ManagerWatchdog()
37
+
38
+ while watchdog.is_alive():
39
+ try:
40
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41
+ except queue.Empty:
42
+ continue
43
+
44
+ if r is None:
45
+ assert done_event.is_set()
46
+ return
47
+ elif done_event.is_set():
48
+ continue
49
+
50
+ idx, batch_indices = r
51
+ try:
52
+ idx_scale = 0
53
+ if len(scale) > 1 and dataset.train:
54
+ idx_scale = random.randrange(0, len(scale))
55
+ dataset.set_scale(idx_scale)
56
+
57
+ samples = collate_fn([dataset[i] for i in batch_indices])
58
+ samples.append(idx_scale)
59
+ except Exception:
60
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61
+ else:
62
+ data_queue.put((idx, samples))
63
+ del samples
64
+
65
+ except KeyboardInterrupt:
66
+ pass
67
+
68
+ class _MSDataLoaderIter(_DataLoaderIter):
69
+
70
+ def __init__(self, loader):
71
+ self.dataset = loader.dataset
72
+ self.scale = loader.scale
73
+ self.collate_fn = loader.collate_fn
74
+ self.batch_sampler = loader.batch_sampler
75
+ self.num_workers = loader.num_workers
76
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
77
+ self.timeout = loader.timeout
78
+
79
+ self.sample_iter = iter(self.batch_sampler)
80
+
81
+ base_seed = torch.LongTensor(1).random_().item()
82
+
83
+ if self.num_workers > 0:
84
+ self.worker_init_fn = loader.worker_init_fn
85
+ self.worker_queue_idx = 0
86
+ self.worker_result_queue = multiprocessing.Queue()
87
+ self.batches_outstanding = 0
88
+ self.worker_pids_set = False
89
+ self.shutdown = False
90
+ self.send_idx = 0
91
+ self.rcvd_idx = 0
92
+ self.reorder_dict = {}
93
+ self.done_event = multiprocessing.Event()
94
+
95
+ base_seed = torch.LongTensor(1).random_()[0]
96
+
97
+ self.index_queues = []
98
+ self.workers = []
99
+ for i in range(self.num_workers):
100
+ index_queue = multiprocessing.Queue()
101
+ index_queue.cancel_join_thread()
102
+ w = multiprocessing.Process(
103
+ target=_ms_loop,
104
+ args=(
105
+ self.dataset,
106
+ index_queue,
107
+ self.worker_result_queue,
108
+ self.done_event,
109
+ self.collate_fn,
110
+ self.scale,
111
+ base_seed + i,
112
+ self.worker_init_fn,
113
+ i
114
+ )
115
+ )
116
+ w.daemon = True
117
+ w.start()
118
+ self.index_queues.append(index_queue)
119
+ self.workers.append(w)
120
+
121
+ if self.pin_memory:
122
+ self.data_queue = queue.Queue()
123
+ pin_memory_thread = threading.Thread(
124
+ target=_utils.pin_memory._pin_memory_loop,
125
+ args=(
126
+ self.worker_result_queue,
127
+ self.data_queue,
128
+ torch.cuda.current_device(),
129
+ self.done_event
130
+ )
131
+ )
132
+ pin_memory_thread.daemon = True
133
+ pin_memory_thread.start()
134
+ self.pin_memory_thread = pin_memory_thread
135
+ else:
136
+ self.data_queue = self.worker_result_queue
137
+
138
+ _utils.signal_handling._set_worker_pids(
139
+ id(self), tuple(w.pid for w in self.workers)
140
+ )
141
+ _utils.signal_handling._set_SIGCHLD_handler()
142
+ self.worker_pids_set = True
143
+
144
+ for _ in range(2 * self.num_workers):
145
+ self._put_indices()
146
+
147
+
148
+ class MSDataLoader(DataLoader):
149
+
150
+ def __init__(self, cfg, *args, **kwargs):
151
+ super(MSDataLoader, self).__init__(
152
+ *args, **kwargs, num_workers=cfg.n_threads
153
+ )
154
+ self.scale = cfg.scale
155
+
156
+ def __iter__(self):
157
+ return _MSDataLoaderIter(self)
158
+
Demosaic/code/demo.sb ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/bash
2
+
3
+ # PANET Train
4
+ #python main.py --n_GPUs 4 --lr 1e-4 --decay 200-400-600-800 --epoch 1000 --batch_size 16 --n_resblocks 80 --save_models --model PANET --scale 1 --patch_size 48 --save PANET_DEMOSAIC --n_feats 64 --data_train DIV2K --chop
5
+ # Test
6
+ python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test McM+Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../model_best.pt --test_only
Demosaic/code/lambda_networks/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from lambda_networks.lambda_networks import LambdaLayer
2
+ from lambda_networks.lambda_networks import Recursion
3
+ from lambda_networks.rlambda_networks import RLambdaLayer
4
+ λLayer = LambdaLayer
Demosaic/code/lambda_networks/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (343 Bytes). View file
 
Demosaic/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc ADDED
Binary file (4.89 kB). View file
 
Demosaic/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc ADDED
Binary file (2.93 kB). View file
 
Demosaic/code/lambda_networks/lambda_networks.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ # helpers functions
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+ def default(val, d):
12
+ return val if exists(val) else d
13
+
14
+ # lambda layer
15
+
16
+ class LambdaLayer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim,
20
+ *,
21
+ dim_k,
22
+ n = None,
23
+ r = None,
24
+ heads = 4,
25
+ dim_out = None,
26
+ dim_u = 1,
27
+ norm="batch"):
28
+ super().__init__()
29
+ dim_out = default(dim_out, dim)
30
+ self.u = dim_u # intra-depth dimension
31
+ self.heads = heads
32
+
33
+ assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
34
+ dim_v = dim_out // heads
35
+
36
+ self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
37
+ self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
38
+ self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
39
+ if norm=="instance":
40
+ self.norm_q = nn.InstanceNorm2d(dim_k * heads)
41
+ self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
42
+ else:
43
+ self.norm_q = nn.BatchNorm2d(dim_k * heads)
44
+ self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
45
+ self.local_contexts = exists(r)
46
+ if exists(r):
47
+ assert (r % 2) == 1, 'Receptive kernel size should be odd'
48
+ self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
49
+ else:
50
+ assert exists(n), 'You must specify the total sequence length (h x w)'
51
+ self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
52
+
53
+
54
+ def forward(self, x):
55
+ b, c, hh, ww, u, h = *x.shape, self.u, self.heads
56
+
57
+ q = self.to_q(x)
58
+ k = self.to_k(x)
59
+ v = self.to_v(x)
60
+
61
+ q = self.norm_q(q)
62
+ v = self.norm_v(v)
63
+
64
+ q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
65
+ k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
66
+ v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
67
+
68
+ k = k.softmax(dim=-1)
69
+
70
+ λc = einsum('b u k m, b u v m -> b k v', k, v)
71
+ Yc = einsum('b h k n, b k v -> b h v n', q, λc)
72
+
73
+ if self.local_contexts:
74
+ v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
75
+ λp = self.pos_conv(v)
76
+ Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
77
+ else:
78
+ λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
79
+ Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
80
+
81
+ Y = Yc + Yp
82
+ out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
83
+ return out
84
+
85
+
86
+ # i'm not sure whether this will work or not
87
+ class Recursion(nn.Module):
88
+ def __init__(self, N: int, hidden_dim:int=64):
89
+ super(Recursion,self).__init__()
90
+ self.N = N
91
+ self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
92
+ # merge upstream information here
93
+ self.lambdaNxN_merge = LambdaLayer(dim=2*hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
94
+ self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
95
+ self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3,padding=1)
96
+ self.pixel_shuffle = nn.PixelShuffle(N)
97
+
98
+ def forward(self, x: torch.Tensor):
99
+ N = self.N
100
+
101
+ def to_patch(blocks:torch.Tensor)->torch.Tensor:
102
+ shape = blocks.shape
103
+ blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
104
+ blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
105
+ num_patch = blocks_patch.shape[-1]
106
+ blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
107
+ return blocks_patch, num_patch
108
+
109
+ def combine_patch(processed_patch,shape,num_patch):
110
+ processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
111
+ processed_patch=processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0],shape[1] * N * N,num_patch).contiguous()
112
+ processed=F.fold(processed_patch,output_size=(shape[-2],shape[-1]),kernel_size=N,stride=N)
113
+ return processed
114
+
115
+ def process(blocks:torch.Tensor)->torch.Tensor:
116
+ shape = blocks.shape
117
+ if blocks.shape[-1] == N:
118
+ processed = self.lambdaNxN_identity(blocks)
119
+ return processed
120
+ # to NxN patchs
121
+ blocks_patch,num_patch=to_patch(blocks)
122
+ # pass through identity
123
+ processed_patch = self.lambdaNxN_identity(blocks_patch)
124
+ # back to HxW
125
+ processed=combine_patch(processed_patch,shape,num_patch)
126
+ # get feedback
127
+ feedback = process(self.downscale_conv(processed))
128
+ # upscale feedback
129
+ upscale_feedback = self.upscale_conv(feedback)
130
+ upscale_feedback=self.pixel_shuffle(upscale_feedback)
131
+ # combine results
132
+ combined = torch.cat([processed, upscale_feedback], dim=1)
133
+ combined_shape=combined.shape
134
+ combined_patch,num_patch=to_patch(combined)
135
+ combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
136
+ ret_shape=(combined_shape[0],combined_shape[1]//2,combined_shape[2],combined_shape[3])
137
+ ret=combine_patch(combined_patch_reduced,ret_shape,num_patch)
138
+ return ret
139
+
140
+ return process(x)
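A minimal shape check for the LambdaLayer above, assuming `einops` is installed and the snippet is run from Demosaic/code so that the `lambda_networks` package is importable; the tensor sizes are illustrative only.

```python
# Shape check for LambdaLayer with a local 7x7 context (illustrative sizes).
import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(dim=32, dim_out=32, dim_k=16, heads=4, dim_u=1, r=7)
layer.eval()                     # BatchNorm running stats, so a single sample is fine
x = torch.randn(1, 32, 24, 24)   # (batch, channels, height, width)
with torch.no_grad():
    y = layer(x)
print(y.shape)                   # torch.Size([1, 32, 24, 24])
```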
Demosaic/code/lambda_networks/rlambda_networks.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+
7
+ # helpers functions
8
+
9
+ def exists(val):
10
+ return val is not None
11
+
12
+
13
+ def default(val, d):
14
+ return val if exists(val) else d
15
+
16
+
17
+ # lambda layer
18
+
19
+ class RLambdaLayer(nn.Module):
20
+ def __init__(
21
+ self,
22
+ dim,
23
+ *,
24
+ dim_k,
25
+ n=None,
26
+ r=None,
27
+ heads=4,
28
+ dim_out=None,
29
+ dim_u=1,
30
+ recurrence=None
31
+ ):
32
+ super().__init__()
33
+ dim_out = default(dim_out, dim)
34
+ self.u = dim_u # intra-depth dimension
35
+ self.heads = heads
36
+
37
+ assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
38
+ dim_v = dim_out // heads
39
+
40
+ self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
41
+ self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
42
+ self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
43
+
44
+ self.norm_q = nn.BatchNorm2d(dim_k * heads)
45
+ self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
46
+
47
+ self.local_contexts = exists(r)
48
+ self.recurrence = recurrence
49
+ if exists(r):
50
+ assert (r % 2) == 1, 'Receptive kernel size should be odd'
51
+ self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
52
+ else:
53
+ assert exists(n), 'You must specify the total sequence length (h x w)'
54
+ self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
55
+
56
+ def apply_lambda(self, lambda_c, lambda_p, x):
57
+ b, c, hh, ww, u, h = *x.shape, self.u, self.heads
58
+ q = self.to_q(x)
59
+ q = self.norm_q(q)
60
+ q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
61
+ Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
62
+ if self.local_contexts:
63
+ Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
64
+ else:
65
+ Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
66
+ Y = Yc + Yp
67
+ out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
68
+ return out
69
+
70
+ def forward(self, x):
71
+ b, c, hh, ww, u, h = *x.shape, self.u, self.heads
72
+
73
+ k = self.to_k(x)
74
+ v = self.to_v(x)
75
+
76
+ v = self.norm_v(v)
77
+
78
+ k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
79
+ v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
80
+
81
+ k = k.softmax(dim=-1)
82
+
83
+ λc = einsum('b u k m, b u v m -> b k v', k, v)
84
+
85
+ if self.local_contexts:
86
+ v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
87
+ λp = self.pos_conv(v)
88
+ else:
89
+ λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
90
+ out = x
91
+ for i in range(self.recurrence):
92
+ out = self.apply_lambda(λc, λp, out)
93
+ return out
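A similar hedged shape check for RLambdaLayer. Note that `recurrence` defaults to None while forward() iterates `range(self.recurrence)`, so an integer must be supplied, and because the output is fed back through to_q when recurrence > 1, dim_out should equal dim in that case.

```python
# Shape check for RLambdaLayer (illustrative sizes; assumes einops is installed).
import torch
from lambda_networks import RLambdaLayer

layer = RLambdaLayer(dim=32, dim_out=32, dim_k=16, heads=4, dim_u=1, r=7, recurrence=2)
layer.eval()
x = torch.randn(1, 32, 24, 24)
with torch.no_grad():
    y = layer(x)                 # the lambdas are computed once, then applied twice
print(y.shape)                   # torch.Size([1, 32, 24, 24])
```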
Demosaic/code/loss/__init__.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ from importlib import import_module
3
+
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import matplotlib.pyplot as plt
7
+
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
15
+ """ Loss function defined over sequence of flow predictions """
16
+
17
+ n_recurrence = len(sr)
18
+ total_loss = 0.0
19
+ buffer=[0.0]*n_recurrence
20
+ # exclude invalid pixels and extremely large displacements
21
+ for i in range(n_recurrence):
22
+ i_weight = gamma**(n_recurrence - i - 1)
23
+ i_loss = loss_func(sr[i],hr)
24
+ buffer[i]=i_loss.item()
25
+ # total_loss += i_weight * (valid[:, None] * i_loss).mean()
26
+ total_loss += i_weight * (i_loss)
27
+ return total_loss,buffer
28
+
29
+ class Loss(nn.modules.loss._Loss):
30
+ def __init__(self, args, ckp):
31
+ super(Loss, self).__init__()
32
+ print('Preparing loss function:')
33
+ self.buffer=[0.0]*args.recurrence
34
+ self.n_GPUs = args.n_GPUs
35
+ self.loss = []
36
+ self.loss_module = nn.ModuleList()
37
+ for loss in args.loss.split('+'):
38
+ weight, loss_type = loss.split('*')
39
+ if loss_type == 'MSE':
40
+ loss_function = nn.MSELoss()
41
+ elif loss_type == 'L1':
42
+ loss_function = nn.L1Loss()
43
+ elif loss_type.find('VGG') >= 0:
44
+ module = import_module('loss.vgg')
45
+ loss_function = getattr(module, 'VGG')(
46
+ loss_type[3:],
47
+ rgb_range=args.rgb_range
48
+ )
49
+ elif loss_type.find('GAN') >= 0:
50
+ module = import_module('loss.adversarial')
51
+ loss_function = getattr(module, 'Adversarial')(
52
+ args,
53
+ loss_type
54
+ )
55
+
56
+ self.loss.append({
57
+ 'type': loss_type,
58
+ 'weight': float(weight),
59
+ 'function': loss_function}
60
+ )
61
+ if loss_type.find('GAN') >= 0:
62
+ self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
63
+
64
+ if len(self.loss) > 1:
65
+ self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
66
+
67
+ for l in self.loss:
68
+ if l['function'] is not None:
69
+ print('{:.3f} * {}'.format(l['weight'], l['type']))
70
+ # self.loss_module.append(l['function'])
71
+
72
+ self.log = torch.Tensor()
73
+
74
+ device = torch.device('cpu' if args.cpu else 'cuda')
75
+ self.loss_module.to(device)
76
+ if args.precision == 'half': self.loss_module.half()
77
+ if not args.cpu and args.n_GPUs > 1:
78
+ self.loss_module = nn.DataParallel(
79
+ self.loss_module, range(args.n_GPUs)
80
+ )
81
+
82
+ if args.load != '': self.load(ckp.dir, cpu=args.cpu)
83
+
84
+ def forward(self, sr, hr):
85
+ losses = []
86
+ for i, l in enumerate(self.loss):
87
+ if l['function'] is not None:
88
+ if isinstance(sr,list):
89
+ # weights=[0.32,0.08,0.02,0.01,0.005]
90
+ # weights=weights[::-1]
91
+ # weights=[0.01,0.02,0.08,0.32]
92
+ # self.buffer=[]
93
+ effective_loss,buffer_lst=sequence_loss(sr,hr,l['function'])
94
+ # for k in range(len(sr)):
95
+ # loss=l['function'](sr[k], hr)
96
+ # self.buffer.append(loss.item())
97
+ # effective_loss=loss*weights[k]*l['weight']
98
+ losses.append(effective_loss)
99
+ self.buffer=buffer_lst
100
+ self.log[-1, i] += effective_loss.item()
101
+ else:
102
+ loss = l['function'](sr, hr)
103
+ effective_loss = l['weight'] * loss
104
+ losses.append(effective_loss)
105
+ self.buffer[0]=effective_loss.item()
106
+ self.log[-1, i] += effective_loss.item()
107
+ elif l['type'] == 'DIS':
108
+ self.log[-1, i] += self.loss[i - 1]['function'].loss
109
+
110
+ loss_sum = sum(losses)
111
+ if len(self.loss) > 1:
112
+ self.log[-1, -1] += loss_sum.item()
113
+
114
+ return loss_sum
115
+
116
+ def step(self):
117
+ for l in self.get_loss_module():
118
+ if hasattr(l, 'scheduler'):
119
+ l.scheduler.step()
120
+
121
+ def start_log(self):
122
+ self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
123
+
124
+ def end_log(self, n_batches):
125
+ self.log[-1].div_(n_batches)
126
+
127
+ def display_loss(self, batch):
128
+ n_samples = batch + 1
129
+ log = []
130
+ for l, c in zip(self.loss, self.log[-1]):
131
+ log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
132
+
133
+ return ''.join(log)
134
+
135
+ def plot_loss(self, apath, epoch):
136
+ axis = np.linspace(1, epoch, epoch)
137
+ for i, l in enumerate(self.loss):
138
+ label = '{} Loss'.format(l['type'])
139
+ fig = plt.figure()
140
+ plt.title(label)
141
+ plt.plot(axis, self.log[:, i].numpy(), label=label)
142
+ plt.legend()
143
+ plt.xlabel('Epochs')
144
+ plt.ylabel('Loss')
145
+ plt.grid(True)
146
+ plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
147
+ plt.close(fig)
148
+
149
+ def get_loss_module(self):
150
+ if self.n_GPUs == 1:
151
+ return self.loss_module
152
+ else:
153
+ return self.loss_module.module
154
+
155
+ def save(self, apath):
156
+ torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
157
+ torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
158
+
159
+ def load(self, apath, cpu=False):
160
+ if cpu:
161
+ kwargs = {'map_location': lambda storage, loc: storage}
162
+ else:
163
+ kwargs = {}
164
+
165
+ self.load_state_dict(torch.load(
166
+ os.path.join(apath, 'loss.pt'),
167
+ **kwargs
168
+ ))
169
+ self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
170
+ for l in self.get_loss_module():
171
+ if hasattr(l, 'scheduler'):
172
+ for _ in range(len(self.log)): l.scheduler.step()
173
+
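To make the gamma weighting in sequence_loss concrete, the small sketch below feeds it a three-element list of constant predictions. It assumes matplotlib is installed (loss/__init__.py imports it at module level) and that it is run from Demosaic/code.

```python
# sequence_loss weights later recurrent outputs more heavily: gamma ** (n - 1 - i).
import torch
import torch.nn as nn
from loss import sequence_loss

hr = torch.zeros(2, 3, 8, 8)
sr_list = [torch.full((2, 3, 8, 8), 3.0),  # earliest output, weight 0.8**2 = 0.64
           torch.full((2, 3, 8, 8), 2.0),  # weight 0.8
           torch.full((2, 3, 8, 8), 1.0)]  # final output, weight 1.0
total, per_step = sequence_loss(sr_list, hr, nn.L1Loss(), gamma=0.8)
print(per_step)                 # [3.0, 2.0, 1.0] -- unweighted per-step L1 losses
print(round(total.item(), 2))   # 0.64*3 + 0.8*2 + 1.0*1 = 4.52
```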
Demosaic/code/loss/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (5.14 kB). View file
 
Demosaic/code/loss/adversarial.py ADDED
@@ -0,0 +1,112 @@
1
+ import utility
2
+ from types import SimpleNamespace
3
+
4
+ from model import common
5
+ from loss import discriminator
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+
12
+ class Adversarial(nn.Module):
13
+ def __init__(self, args, gan_type):
14
+ super(Adversarial, self).__init__()
15
+ self.gan_type = gan_type
16
+ self.gan_k = args.gan_k
17
+ self.dis = discriminator.Discriminator(args)
18
+ if gan_type == 'WGAN_GP':
19
+ # see https://arxiv.org/pdf/1704.00028.pdf pp.4
20
+ optim_dict = {
21
+ 'optimizer': 'ADAM',
22
+ 'betas': (0, 0.9),
23
+ 'epsilon': 1e-8,
24
+ 'lr': 1e-5,
25
+ 'weight_decay': args.weight_decay,
26
+ 'decay': args.decay,
27
+ 'gamma': args.gamma
28
+ }
29
+ optim_args = SimpleNamespace(**optim_dict)
30
+ else:
31
+ optim_args = args
32
+
33
+ self.optimizer = utility.make_optimizer(optim_args, self.dis)
34
+
35
+ def forward(self, fake, real):
36
+ # updating discriminator...
37
+ self.loss = 0
38
+ fake_detach = fake.detach() # do not backpropagate through G
39
+ for _ in range(self.gan_k):
40
+ self.optimizer.zero_grad()
41
+ # d: B x 1 tensor
42
+ d_fake = self.dis(fake_detach)
43
+ d_real = self.dis(real)
44
+ retain_graph = False
45
+ if self.gan_type == 'GAN':
46
+ loss_d = self.bce(d_real, d_fake)
47
+ elif self.gan_type.find('WGAN') >= 0:
48
+ loss_d = (d_fake - d_real).mean()
49
+ if self.gan_type.find('GP') >= 0:
50
+ epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
51
+ hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
52
+ hat.requires_grad = True
53
+ d_hat = self.dis(hat)
54
+ gradients = torch.autograd.grad(
55
+ outputs=d_hat.sum(), inputs=hat,
56
+ retain_graph=True, create_graph=True, only_inputs=True
57
+ )[0]
58
+ gradients = gradients.view(gradients.size(0), -1)
59
+ gradient_norm = gradients.norm(2, dim=1)
60
+ gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
61
+ loss_d += gradient_penalty
62
+ # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
63
+ elif self.gan_type == 'RGAN':
64
+ better_real = d_real - d_fake.mean(dim=0, keepdim=True)
65
+ better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
66
+ loss_d = self.bce(better_real, better_fake)
67
+ retain_graph = True
68
+
69
+ # Discriminator update
70
+ self.loss += loss_d.item()
71
+ loss_d.backward(retain_graph=retain_graph)
72
+ self.optimizer.step()
73
+
74
+ if self.gan_type == 'WGAN':
75
+ for p in self.dis.parameters():
76
+ p.data.clamp_(-1, 1)
77
+
78
+ self.loss /= self.gan_k
79
+
80
+ # updating generator...
81
+ d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
82
+ if self.gan_type == 'GAN':
83
+ label_real = torch.ones_like(d_fake_bp)
84
+ loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
85
+ elif self.gan_type.find('WGAN') >= 0:
86
+ loss_g = -d_fake_bp.mean()
87
+ elif self.gan_type == 'RGAN':
88
+ better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
89
+ better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
90
+ loss_g = self.bce(better_fake, better_real)
91
+
92
+ # Generator loss
93
+ return loss_g
94
+
95
+ def state_dict(self, *args, **kwargs):
96
+ state_discriminator = self.dis.state_dict(*args, **kwargs)
97
+ state_optimizer = self.optimizer.state_dict()
98
+
99
+ return dict(**state_discriminator, **state_optimizer)
100
+
101
+ def bce(self, real, fake):
102
+ label_real = torch.ones_like(real)
103
+ label_fake = torch.zeros_like(fake)
104
+ bce_real = F.binary_cross_entropy_with_logits(real, label_real)
105
+ bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
106
+ bce_loss = bce_real + bce_fake
107
+ return bce_loss
108
+
109
+ # Some references
110
+ # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
111
+ # OR
112
+ # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
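For reference, the gradient-penalty term computed in the WGAN_GP branch above can be isolated as a standalone function. This is a minimal sketch with the same conventions (penalty weight 10, one interpolation weight per sample); the tiny critic and random tensors are placeholders, not the repository's `Discriminator`:

import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, weight=10.0):
    # (||grad_x critic(x_hat)||_2 - 1)^2 averaged over the batch
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    grads = torch.autograd.grad(
        outputs=critic(hat).sum(), inputs=hat,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return weight * (grad_norm - 1).pow(2).mean()

critic = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1),
    nn.Flatten(), nn.Linear(8, 1)
)
real, fake = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
print(gradient_penalty(critic, real, fake).item())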
Demosaic/code/loss/discriminator.py ADDED
@@ -0,0 +1,55 @@
1
+ from model import common
2
+
3
+ import torch.nn as nn
4
+
5
+ class Discriminator(nn.Module):
6
+ '''
7
+ output is not normalized
8
+ '''
9
+ def __init__(self, args):
10
+ super(Discriminator, self).__init__()
11
+
12
+ in_channels = args.n_colors
13
+ out_channels = 64
14
+ depth = 7
15
+
16
+ def _block(_in_channels, _out_channels, stride=1):
17
+ return nn.Sequential(
18
+ nn.Conv2d(
19
+ _in_channels,
20
+ _out_channels,
21
+ 3,
22
+ padding=1,
23
+ stride=stride,
24
+ bias=False
25
+ ),
26
+ nn.BatchNorm2d(_out_channels),
27
+ nn.LeakyReLU(negative_slope=0.2, inplace=True)
28
+ )
29
+
30
+ m_features = [_block(in_channels, out_channels)]
31
+ for i in range(depth):
32
+ in_channels = out_channels
33
+ if i % 2 == 1:
34
+ stride = 1
35
+ out_channels *= 2
36
+ else:
37
+ stride = 2
38
+ m_features.append(_block(in_channels, out_channels, stride=stride))
39
+
40
+ patch_size = args.patch_size // (2**((depth + 1) // 2))
41
+ m_classifier = [
42
+ nn.Linear(out_channels * patch_size**2, 1024),
43
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
44
+ nn.Linear(1024, 1)
45
+ ]
46
+
47
+ self.features = nn.Sequential(*m_features)
48
+ self.classifier = nn.Sequential(*m_classifier)
49
+
50
+ def forward(self, x):
51
+ features = self.features(x)
52
+ output = self.classifier(features.view(features.size(0), -1))
53
+
54
+ return output
55
+
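Because every other block above uses stride 2, the spatial size is halved (depth + 1) // 2 times before the classifier, which is where `patch_size // (2**((depth + 1) // 2))` comes from. A quick sanity check of the sizes under the defaults in this file (depth 7, 64 starting channels); `args.patch_size = 96` is an assumed value for illustration:

depth = 7
patch_size = 96      # assumed stand-in for args.patch_size
out_channels = 64
for i in range(depth):
    if i % 2 == 1:   # channels double on the stride-1 blocks
        out_channels *= 2

final_size = patch_size // (2 ** ((depth + 1) // 2))   # four stride-2 blocks
print(out_channels, final_size, out_channels * final_size ** 2)
# 512 6 18432  -> in_features of the first nn.Linear in m_classifier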
Demosaic/code/loss/vgg.py ADDED
@@ -0,0 +1,36 @@
1
+ from model import common
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.models as models
7
+
8
+ class VGG(nn.Module):
9
+ def __init__(self, conv_index, rgb_range=1):
10
+ super(VGG, self).__init__()
11
+ vgg_features = models.vgg19(pretrained=True).features
12
+ modules = [m for m in vgg_features]
13
+ if conv_index.find('22') >= 0:
14
+ self.vgg = nn.Sequential(*modules[:8])
15
+ elif conv_index.find('54') >= 0:
16
+ self.vgg = nn.Sequential(*modules[:35])
17
+
18
+ vgg_mean = (0.485, 0.456, 0.406)
19
+ vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20
+ self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21
+ for p in self.parameters():
22
+ p.requires_grad = False
23
+
24
+ def forward(self, sr, hr):
25
+ def _forward(x):
26
+ x = self.sub_mean(x)
27
+ x = self.vgg(x)
28
+ return x
29
+
30
+ vgg_sr = _forward(sr)
31
+ with torch.no_grad():
32
+ vgg_hr = _forward(hr.detach())
33
+
34
+ loss = F.mse_loss(vgg_sr, vgg_hr)
35
+
36
+ return loss
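`conv_index` picks how deep the frozen VGG-19 feature extractor goes: '22' keeps the first 8 modules of `vgg19().features` (up to conv2_2), '54' keeps the first 35 (up to conv5_4). A short sketch of just that slicing, with illustrative shapes; it downloads torchvision's pretrained weights on first use:

import torch
import torchvision.models as models

features = models.vgg19(pretrained=True).features
conv2_2 = torch.nn.Sequential(*list(features)[:8])    # conv_index '22'
conv5_4 = torch.nn.Sequential(*list(features)[:35])   # conv_index '54'

x = torch.rand(1, 3, 96, 96)
print(conv2_2(x).shape)   # torch.Size([1, 128, 48, 48])
print(conv5_4(x).shape)   # torch.Size([1, 512, 6, 6])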
Demosaic/code/main.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+
3
+ import utility
4
+ import data
5
+ import model
6
+ import loss
7
+ from option import args
8
+ from trainer import Trainer
9
+
10
+ torch.manual_seed(args.seed)
11
+ checkpoint = utility.checkpoint(args)
12
+
13
+ def main():
14
+ global model
15
+ if args.data_test == ['video']:
16
+ from videotester import VideoTester
17
+ model = model.Model(args,checkpoint)
18
+ print('total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
19
+ t = VideoTester(args, model, checkpoint)
20
+ t.test()
21
+ else:
22
+ if checkpoint.ok:
23
+ loader = data.Data(args)
24
+ _model = model.Model(args, checkpoint)
25
+ print('total params: %.5fM' % (sum(p.numel() for p in _model.parameters())/1000000.0))
26
+ _loss = loss.Loss(args, checkpoint) if not args.test_only else None
27
+ t = Trainer(args, loader, _model, _loss, checkpoint)
28
+ while not t.terminate():
29
+ t.train()
30
+ t.test()
31
+
32
+ checkpoint.done()
33
+
34
+ if __name__ == '__main__':
35
+ main()
Demosaic/code/model/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Sanghyun Son
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Demosaic/code/model/__init__.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ from importlib import import_module
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Variable
7
+
8
+ class Model(nn.Module):
9
+ def __init__(self, args, ckp):
10
+ super(Model, self).__init__()
11
+ print('Making model...')
12
+
13
+ self.scale = args.scale
14
+ self.idx_scale = 0
15
+ self.self_ensemble = args.self_ensemble
16
+ self.chop = args.chop
17
+ self.precision = args.precision
18
+ self.cpu = args.cpu
19
+ self.device = torch.device('cpu' if args.cpu else 'cuda')
20
+ self.n_GPUs = args.n_GPUs
21
+ self.save_models = args.save_models
22
+
23
+ module = import_module('model.' + args.model.lower())
24
+ self.model = module.make_model(args).to(self.device)
25
+ if args.precision == 'half': self.model.half()
26
+
27
+ if not args.cpu and args.n_GPUs > 1:
28
+ self.model = nn.DataParallel(self.model, range(args.n_GPUs))
29
+
30
+ self.load(
31
+ ckp.dir,
32
+ pre_train=args.pre_train,
33
+ resume=args.resume,
34
+ cpu=args.cpu
35
+ )
36
+ print(self.model, file=ckp.log_file)
37
+
38
+ def forward(self, x, idx_scale):
39
+ self.idx_scale = idx_scale
40
+ target = self.get_model()
41
+ if hasattr(target, 'set_scale'):
42
+ target.set_scale(idx_scale)
43
+
44
+ if self.self_ensemble and not self.training:
45
+ if self.chop:
46
+ forward_function = self.forward_chop
47
+ else:
48
+ forward_function = self.model.forward
49
+
50
+ return self.forward_x8(x, forward_function)
51
+ elif self.chop and not self.training:
52
+ return self.forward_chop(x)
53
+ else:
54
+ return self.model(x)
55
+
56
+ def get_model(self):
57
+ if self.n_GPUs == 1:
58
+ return self.model
59
+ else:
60
+ return self.model.module
61
+
62
+ def state_dict(self, **kwargs):
63
+ target = self.get_model()
64
+ return target.state_dict(**kwargs)
65
+
66
+ def save(self, apath, epoch, is_best=False):
67
+ target = self.get_model()
68
+ torch.save(
69
+ target.state_dict(),
70
+ os.path.join(apath, 'model_latest.pt')
71
+ )
72
+ if is_best:
73
+ torch.save(
74
+ target.state_dict(),
75
+ os.path.join(apath, 'model_best.pt')
76
+ )
77
+
78
+ if self.save_models:
79
+ torch.save(
80
+ target.state_dict(),
81
+ os.path.join(apath, 'model_{}.pt'.format(epoch))
82
+ )
83
+
84
+ def load(self, apath, pre_train='.', resume=-1, cpu=False):
85
+ if cpu:
86
+ kwargs = {'map_location': lambda storage, loc: storage}
87
+ else:
88
+ kwargs = {}
89
+
90
+ if resume == -1:
91
+ self.get_model().load_state_dict(
92
+ torch.load(
93
+ os.path.join(apath,'model', 'model_latest.pt'),
94
+ **kwargs
95
+ ),
96
+ strict=False
97
+ )
98
+ elif resume == 0:
99
+ if pre_train != '.':
100
+ print('Loading model from {}'.format(pre_train))
101
+ self.get_model().load_state_dict(
102
+ torch.load(pre_train, **kwargs),
103
+ strict=False
104
+ )
105
+ else:
106
+ self.get_model().load_state_dict(
107
+ torch.load(
108
+ os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
109
+ **kwargs
110
+ ),
111
+ strict=False
112
+ )
113
+
114
+ def forward_chop(self, x, shave=10, min_size=6400):
115
+ scale = self.scale[self.idx_scale]
116
+ scale = 1  # demosaicing keeps the input resolution, so the SR scale is overridden here
117
+ n_GPUs = min(self.n_GPUs, 4)
118
+ b, c, h, w = x.size()
119
+ h_half, w_half = h // 2, w // 2
120
+ h_size, w_size = h_half + shave, w_half + shave
121
+ lr_list = [
122
+ x[:, :, 0:h_size, 0:w_size],
123
+ x[:, :, 0:h_size, (w - w_size):w],
124
+ x[:, :, (h - h_size):h, 0:w_size],
125
+ x[:, :, (h - h_size):h, (w - w_size):w]]
126
+
127
+ if w_size * h_size < min_size:
128
+ sr_list = []
129
+ for i in range(0, 4, n_GPUs):
130
+ lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
131
+ sr_batch = self.model(lr_batch)
132
+ sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
133
+ else:
134
+ sr_list = [
135
+ self.forward_chop(patch, shave=shave, min_size=min_size) \
136
+ for patch in lr_list
137
+ ]
138
+
139
+ h, w = scale * h, scale * w
140
+ h_half, w_half = scale * h_half, scale * w_half
141
+ h_size, w_size = scale * h_size, scale * w_size
142
+ shave *= scale
143
+
144
+ output = x.new(b, c, h, w)
145
+ output[:, :, 0:h_half, 0:w_half] \
146
+ = sr_list[0][:, :, 0:h_half, 0:w_half]
147
+ output[:, :, 0:h_half, w_half:w] \
148
+ = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
149
+ output[:, :, h_half:h, 0:w_half] \
150
+ = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
151
+ output[:, :, h_half:h, w_half:w] \
152
+ = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
153
+
154
+ return output
155
+
156
+ def forward_x8(self, x, forward_function):
157
+ def _transform(v, op):
158
+ if self.precision != 'single': v = v.float()
159
+
160
+ v2np = v.data.cpu().numpy()
161
+ if op == 'v':
162
+ tfnp = v2np[:, :, :, ::-1].copy()
163
+ elif op == 'h':
164
+ tfnp = v2np[:, :, ::-1, :].copy()
165
+ elif op == 't':
166
+ tfnp = v2np.transpose((0, 1, 3, 2)).copy()
167
+
168
+ ret = torch.Tensor(tfnp).to(self.device)
169
+ if self.precision == 'half': ret = ret.half()
170
+
171
+ return ret
172
+
173
+ lr_list = [x]
174
+ for tf in 'v', 'h', 't':
175
+ lr_list.extend([_transform(t, tf) for t in lr_list])
176
+
177
+ sr_list = [forward_function(aug) for aug in lr_list]
178
+ for i in range(len(sr_list)):
179
+ if i > 3:
180
+ sr_list[i] = _transform(sr_list[i], 't')
181
+ if i % 4 > 1:
182
+ sr_list[i] = _transform(sr_list[i], 'h')
183
+ if (i % 4) % 2 == 1:
184
+ sr_list[i] = _transform(sr_list[i], 'v')
185
+
186
+ output_cat = torch.cat(sr_list, dim=0)
187
+ output = output_cat.mean(dim=0, keepdim=True)
188
+
189
+ return output
190
+
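`forward_x8` implements the 8-fold geometric self-ensemble: it augments the input with a width flip, a height flip and a transpose (plus their compositions), runs the model on each variant, undoes the transforms in reverse order, and averages. A compact sketch of the same idea with plain tensor ops; the identity lambda stands in for a real model:

import torch

def self_ensemble_x8(x, model):
    # x: (N, C, H, W) with H == W so the transpose variant is valid
    variants = [x]
    for op in ('v', 'h', 't'):
        if op == 'v':
            variants += [torch.flip(t, dims=[3]) for t in variants]
        elif op == 'h':
            variants += [torch.flip(t, dims=[2]) for t in variants]
        else:
            variants += [t.transpose(2, 3) for t in variants]

    outputs = [model(v) for v in variants]
    for i in range(len(outputs)):      # invert the augmentations
        if i > 3:
            outputs[i] = outputs[i].transpose(2, 3)
        if i % 4 > 1:
            outputs[i] = torch.flip(outputs[i], dims=[2])
        if (i % 4) % 2 == 1:
            outputs[i] = torch.flip(outputs[i], dims=[3])
    return torch.stack(outputs, dim=0).mean(dim=0)

x = torch.rand(1, 3, 32, 32)
print(torch.allclose(self_ensemble_x8(x, lambda t: t), x))   # True for an identity model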
Demosaic/code/model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (5.42 kB).
 
Demosaic/code/model/__pycache__/attention.cpython-37.pyc ADDED
Binary file (3.35 kB).
 
Demosaic/code/model/__pycache__/common.cpython-37.pyc ADDED
Binary file (3.37 kB).
 
Demosaic/code/model/__pycache__/lambdanet.cpython-37.pyc ADDED
Binary file (2.81 kB).
 
Demosaic/code/model/__pycache__/raftnet.cpython-37.pyc ADDED
Binary file (4.8 kB).
 
Demosaic/code/model/__pycache__/raftnetlayer.cpython-37.pyc ADDED
Binary file (4.96 kB).
 
Demosaic/code/model/__pycache__/raftnets.cpython-37.pyc ADDED
Binary file (4.91 kB).
 
Demosaic/code/model/__pycache__/raftnetsingle.cpython-37.pyc ADDED
Binary file (4.88 kB).
 
Demosaic/code/model/attention.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from torchvision import utils as vutils
6
+ from model import common
7
+ from utils.tools import extract_image_patches,\
8
+ reduce_mean, reduce_sum, same_padding
9
+
10
+ class PyramidAttention(nn.Module):
11
+ def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=common.default_conv):
12
+ super(PyramidAttention, self).__init__()
13
+ self.ksize = ksize
14
+ self.stride = stride
15
+ self.res_scale = res_scale
16
+ self.softmax_scale = softmax_scale
17
+ self.scale = [1-i/10 for i in range(level)]
18
+ self.average = average
19
+ escape_NaN = torch.FloatTensor([1e-4])
20
+ self.register_buffer('escape_NaN', escape_NaN)
21
+ self.conv_match_L_base = common.BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
22
+ self.conv_match = common.BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
23
+ self.conv_assembly = common.BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())
24
+
25
+ def forward(self, input):
26
+ res = input
27
+ #theta
28
+ match_base = self.conv_match_L_base(input)
29
+ shape_base = list(res.size())
30
+ input_groups = torch.split(match_base,1,dim=0)
31
+ # patch size for matching
32
+ kernel = self.ksize
33
+ # raw_w is for reconstruction
34
+ raw_w = []
35
+ # w is for matching
36
+ w = []
37
+ #build feature pyramid
38
+ for i in range(len(self.scale)):
39
+ ref = input
40
+ if self.scale[i]!=1:
41
+ ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
42
+ #feature transformation function f
43
+ base = self.conv_assembly(ref)
44
+ shape_input = base.shape
45
+ #sampling
46
+ raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
47
+ strides=[self.stride,self.stride],
48
+ rates=[1, 1],
49
+ padding='same') # [N, C*k*k, L]
50
+ raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
51
+ raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
52
+ raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
53
+ raw_w.append(raw_w_i_groups)
54
+
55
+ #feature transformation function g
56
+ ref_i = self.conv_match(ref)
57
+ shape_ref = ref_i.shape
58
+ #sampling
59
+ w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
60
+ strides=[self.stride, self.stride],
61
+ rates=[1, 1],
62
+ padding='same')
63
+ w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
64
+ w_i = w_i.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
65
+ w_i_groups = torch.split(w_i, 1, dim=0)
66
+ w.append(w_i_groups)
67
+
68
+ y = []
69
+ for idx, xi in enumerate(input_groups):
70
+ #group in a filter
71
+ wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0) # [L, C, k, k]
72
+ #normalize
73
+ max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
74
+ axis=[1, 2, 3],
75
+ keepdim=True)),
76
+ self.escape_NaN)
77
+ wi_normed = wi/ max_wi
78
+ #matching
79
+ xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
80
+ yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
81
+ yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3]) # (B=1, C=32*32, H=32, W=32)
82
+ # softmax matching score
83
+ yi = F.softmax(yi*self.softmax_scale, dim=1)
84
+
85
+ if self.average == False:
86
+ yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
87
+
88
+ # deconv for patch pasting
89
+ raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
90
+ yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.
91
+ y.append(yi)
92
+
93
+ y = torch.cat(y, dim=0)+res*self.res_scale # back to the mini-batch
94
+ return y
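At its core, PyramidAttention performs a normalized cross-correlation: 3x3 patches taken from (optionally downscaled) copies of the feature map are L2-normalized and used as convolution kernels over the input, the response is sharpened with a scaled softmax, and the matching raw patches are pasted back with a transposed convolution. A stripped-down, single-scale sketch of just the matching step, with toy sizes rather than the repository implementation:

import torch
import torch.nn.functional as F

feat = torch.rand(1, 16, 20, 20)                    # (B=1, C, H, W)

# every 3x3 patch becomes one matching kernel: (H*W, C, 3, 3)
patches = F.unfold(feat, kernel_size=3, padding=1)  # (1, C*9, H*W)
kernels = patches.transpose(1, 2).reshape(-1, 16, 3, 3)

# L2-normalize each kernel so correlation scores are comparable
norms = kernels.flatten(1).norm(dim=1).clamp(min=1e-4)
kernels = kernels / norms.view(-1, 1, 1, 1)

# one score map per candidate patch, then a sharpened softmax over candidates
scores = F.conv2d(feat, kernels, padding=1)          # (1, H*W, H, W)
attn = F.softmax(scores * 10, dim=1)                 # softmax_scale = 10
print(attn.shape)                                    # torch.Size([1, 400, 20, 20])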
Demosaic/code/model/betalambdanet.py ADDED
@@ -0,0 +1,97 @@
1
+ from model import common
2
+ import torch
3
+ import torch.nn as nn
4
+ from lambda_networks import LambdaLayer
5
+ import torch.cuda.amp as amp
6
+
7
+ def make_model(args, parent=False):
8
+ return BETALAMBDANET(args)
9
+
10
+ class BETALAMBDANET(nn.Module):
11
+ def __init__(self, args, conv=common.default_conv):
12
+ super(BETALAMBDANET, self).__init__()
13
+
14
+ n_resblocks = args.n_resblocks
15
+ n_feats = args.n_feats
16
+ kernel_size = 3
17
+ scale = args.scale[0]
18
+
19
+ rgb_mean = (0.4488, 0.4371, 0.4040)
20
+ rgb_std = (1.0, 1.0, 1.0)
21
+ self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
22
+ layer = LambdaLayer(
23
+ dim = n_feats,
24
+ dim_out = n_feats,
25
+ r = 23, # the receptive field for relative positional encoding (23 x 23)
26
+ dim_k = 16,
27
+ heads = 4,
28
+ dim_u = 4
29
+ )
30
+ # msa = attention.PyramidAttention()
31
+ # define head module
32
+ m_head = [conv(args.n_colors, n_feats, kernel_size)]
33
+
34
+ # define body module
35
+ m_body = [
36
+ common.ResBlock(
37
+ conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale
38
+ ) for _ in range(n_resblocks//2)
39
+ ]
40
+ # m_body.append(msa)
41
+ m_body.append(layer)
42
+ for i in range(n_resblocks//2):
43
+ m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale))
44
+
45
+ m_body.append(conv(n_feats, n_feats, kernel_size))
46
+
47
+ # define tail module
48
+ #m_tail = [
49
+ # common.Upsampler(conv, scale, n_feats, act=False),
50
+ # conv(n_feats, args.n_colors, kernel_size)
51
+ #]
52
+ m_tail = [
53
+ conv(n_feats, args.n_colors, kernel_size)
54
+ ]
55
+
56
+ self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
57
+
58
+ self.head = nn.Sequential(*m_head)
59
+ self.body = nn.Sequential(*m_body)
60
+ self.tail = nn.Sequential(*m_tail)
61
+
62
+ self.recurrence = args.recurrence
63
+ self.detach = args.detach
64
+ # self.step_detach = args.step_detach
65
+ self.amp = args.amp
66
+ self.beta=nn.Parameter(torch.ones(1)*0.5)
67
+
68
+ def forward(self, x):
69
+ with amp.autocast(self.amp):
70
+ out = self.head(x)
71
+ last_output=out
72
+ for i in range(self.recurrence):
73
+ res = self.body(last_output)
74
+ res = self.beta*res + (1-self.beta)*last_output
75
+ last_output=res
76
+ output = self.tail(last_output)
77
+ return [output]
78
+
79
+ def load_state_dict(self, state_dict, strict=True):
80
+ own_state = self.state_dict()
81
+ for name, param in state_dict.items():
82
+ if name in own_state:
83
+ if isinstance(param, nn.Parameter):
84
+ param = param.data
85
+ try:
86
+ own_state[name].copy_(param)
87
+ except Exception:
88
+ if name.find('tail') == -1:
89
+ raise RuntimeError('While copying the parameter named {}, '
90
+ 'whose dimensions in the model are {} and '
91
+ 'whose dimensions in the checkpoint are {}.'
92
+ .format(name, own_state[name].size(), param.size()))
93
+ elif strict:
94
+ if name.find('tail') == -1:
95
+ raise KeyError('unexpected key "{}" in state_dict'
96
+ .format(name))
97
+
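The forward pass above is a short recurrence: the same body is applied `args.recurrence` times and each pass is blended with the previous state through one learnable weight `beta`. A condensed sketch of that update rule; the two plain convolutions stand in for the ResBlock/LambdaLayer body:

import torch
import torch.nn as nn

class RecurrentBlend(nn.Module):
    def __init__(self, n_feats=64, recurrence=4):
        super().__init__()
        self.body = nn.Sequential(                   # stand-in body
            nn.Conv2d(n_feats, n_feats, 3, padding=1), nn.PReLU(),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
        )
        self.beta = nn.Parameter(torch.ones(1) * 0.5)
        self.recurrence = recurrence

    def forward(self, x):
        state = x
        for _ in range(self.recurrence):
            # blend refined features with the previous state
            state = self.beta * self.body(state) + (1 - self.beta) * state
        return state

print(RecurrentBlend()(torch.rand(2, 64, 16, 16)).shape)   # torch.Size([2, 64, 16, 16])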
Demosaic/code/model/common.py ADDED
@@ -0,0 +1,93 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
8
+ return nn.Conv2d(
9
+ in_channels, out_channels, kernel_size,
10
+ padding=(kernel_size//2),stride=stride, bias=bias)
11
+
12
+ def spectral_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
13
+ return nn.utils.spectral_norm(nn.Conv2d(
14
+ in_channels, out_channels, kernel_size,
15
+ padding=(kernel_size//2),stride=stride, bias=bias))
16
+
17
+ class MeanShift(nn.Conv2d):
18
+ def __init__(
19
+ self, rgb_range,
20
+ rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
21
+
22
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
23
+ std = torch.Tensor(rgb_std)
24
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
25
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
26
+ for p in self.parameters():
27
+ p.requires_grad = False
28
+
29
+ class BasicBlock(nn.Sequential):
30
+ def __init__(
31
+ self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
32
+ bn=False, act=nn.PReLU()):
33
+
34
+ m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
35
+ if bn:
36
+ m.append(nn.BatchNorm2d(out_channels))
37
+ if act is not None:
38
+ m.append(act)
39
+
40
+ super(BasicBlock, self).__init__(*m)
41
+
42
+ class ResBlock(nn.Module):
43
+ def __init__(
44
+ self, conv, n_feats, kernel_size,
45
+ bias=True, bn=False, act=nn.PReLU(), res_scale=1):
46
+
47
+ super(ResBlock, self).__init__()
48
+ m = []
49
+ for i in range(2):
50
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
51
+ if bn:
52
+ m.append(nn.BatchNorm2d(n_feats))
53
+ if i == 0:
54
+ m.append(act)
55
+
56
+ self.body = nn.Sequential(*m)
57
+ self.res_scale = res_scale
58
+
59
+ def forward(self, x):
60
+ res = self.body(x).mul(self.res_scale)
61
+ res += x
62
+
63
+ return res
64
+
65
+ class Upsampler(nn.Sequential):
66
+ def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
67
+
68
+ m = []
69
+ if (scale & (scale - 1)) == 0: # Is scale = 2^n?
70
+ for _ in range(int(math.log(scale, 2))):
71
+ m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))  # keyword arg: the conv's fourth positional parameter is stride, not bias
72
+ m.append(nn.PixelShuffle(2))
73
+ if bn:
74
+ m.append(nn.BatchNorm2d(n_feats))
75
+ if act == 'relu':
76
+ m.append(nn.ReLU(True))
77
+ elif act == 'prelu':
78
+ m.append(nn.PReLU(n_feats))
79
+
80
+ elif scale == 3:
81
+ m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))  # keyword arg: the conv's fourth positional parameter is stride, not bias
82
+ m.append(nn.PixelShuffle(3))
83
+ if bn:
84
+ m.append(nn.BatchNorm2d(n_feats))
85
+ if act == 'relu':
86
+ m.append(nn.ReLU(True))
87
+ elif act == 'prelu':
88
+ m.append(nn.PReLU(n_feats))
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ super(Upsampler, self).__init__(*m)
93
+
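`MeanShift` is a frozen 1x1 convolution whose weights and bias are set so that it computes `(x + sign * rgb_range * mean) / std` per channel (`sign=-1` subtracts the dataset mean, `sign=1` adds it back). A quick numerical check of that identity, rebuilding the same weights locally rather than importing the class:

import torch
import torch.nn as nn

rgb_range, sign = 255, -1
mean = torch.tensor([0.4488, 0.4371, 0.4040])
std = torch.tensor([1.0, 1.0, 1.0])

shift = nn.Conv2d(3, 3, kernel_size=1)
shift.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
shift.bias.data = sign * rgb_range * mean / std

x = torch.rand(1, 3, 8, 8) * rgb_range
expected = (x + sign * rgb_range * mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
print(torch.allclose(shift(x), expected, atol=1e-3))   # True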
Demosaic/code/model/ddbpn.py ADDED
@@ -0,0 +1,131 @@
1
+ # Deep Back-Projection Networks For Super-Resolution
2
+ # https://arxiv.org/abs/1803.02735
3
+
4
+ from model import common
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def make_model(args, parent=False):
11
+ return DDBPN(args)
12
+
13
+ def projection_conv(in_channels, out_channels, scale, up=True):
14
+ kernel_size, stride, padding = {
15
+ 2: (6, 2, 2),
16
+ 4: (8, 4, 2),
17
+ 8: (12, 8, 2)
18
+ }[scale]
19
+ if up:
20
+ conv_f = nn.ConvTranspose2d
21
+ else:
22
+ conv_f = nn.Conv2d
23
+
24
+ return conv_f(
25
+ in_channels, out_channels, kernel_size,
26
+ stride=stride, padding=padding
27
+ )
28
+
29
+ class DenseProjection(nn.Module):
30
+ def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
31
+ super(DenseProjection, self).__init__()
32
+ if bottleneck:
33
+ self.bottleneck = nn.Sequential(*[
34
+ nn.Conv2d(in_channels, nr, 1),
35
+ nn.PReLU(nr)
36
+ ])
37
+ inter_channels = nr
38
+ else:
39
+ self.bottleneck = None
40
+ inter_channels = in_channels
41
+
42
+ self.conv_1 = nn.Sequential(*[
43
+ projection_conv(inter_channels, nr, scale, up),
44
+ nn.PReLU(nr)
45
+ ])
46
+ self.conv_2 = nn.Sequential(*[
47
+ projection_conv(nr, inter_channels, scale, not up),
48
+ nn.PReLU(inter_channels)
49
+ ])
50
+ self.conv_3 = nn.Sequential(*[
51
+ projection_conv(inter_channels, nr, scale, up),
52
+ nn.PReLU(nr)
53
+ ])
54
+
55
+ def forward(self, x):
56
+ if self.bottleneck is not None:
57
+ x = self.bottleneck(x)
58
+
59
+ a_0 = self.conv_1(x)
60
+ b_0 = self.conv_2(a_0)
61
+ e = b_0.sub(x)
62
+ a_1 = self.conv_3(e)
63
+
64
+ out = a_0.add(a_1)
65
+
66
+ return out
67
+
68
+ class DDBPN(nn.Module):
69
+ def __init__(self, args):
70
+ super(DDBPN, self).__init__()
71
+ scale = args.scale[0]
72
+
73
+ n0 = 128
74
+ nr = 32
75
+ self.depth = 6
76
+
77
+ rgb_mean = (0.4488, 0.4371, 0.4040)
78
+ rgb_std = (1.0, 1.0, 1.0)
79
+ self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
80
+ initial = [
81
+ nn.Conv2d(args.n_colors, n0, 3, padding=1),
82
+ nn.PReLU(n0),
83
+ nn.Conv2d(n0, nr, 1),
84
+ nn.PReLU(nr)
85
+ ]
86
+ self.initial = nn.Sequential(*initial)
87
+
88
+ self.upmodules = nn.ModuleList()
89
+ self.downmodules = nn.ModuleList()
90
+ channels = nr
91
+ for i in range(self.depth):
92
+ self.upmodules.append(
93
+ DenseProjection(channels, nr, scale, True, i > 1)
94
+ )
95
+ if i != 0:
96
+ channels += nr
97
+
98
+ channels = nr
99
+ for i in range(self.depth - 1):
100
+ self.downmodules.append(
101
+ DenseProjection(channels, nr, scale, False, i != 0)
102
+ )
103
+ channels += nr
104
+
105
+ reconstruction = [
106
+ nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
107
+ ]
108
+ self.reconstruction = nn.Sequential(*reconstruction)
109
+
110
+ self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
111
+
112
+ def forward(self, x):
113
+ x = self.sub_mean(x)
114
+ x = self.initial(x)
115
+
116
+ h_list = []
117
+ l_list = []
118
+ for i in range(self.depth - 1):
119
+ if i == 0:
120
+ l = x
121
+ else:
122
+ l = torch.cat(l_list, dim=1)
123
+ h_list.append(self.upmodules[i](l))
124
+ l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
125
+
126
+ h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
127
+ out = self.reconstruction(torch.cat(h_list, dim=1))
128
+ out = self.add_mean(out)
129
+
130
+ return out
131
+
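Each `DenseProjection` above is DBPN's back-projection unit: project the features to the other resolution, project back, and use the reconstruction error to correct the first projection (`out = a_0 + conv_3(b_0 - x)`). A minimal sketch of an up-projection for scale 2 with the same (kernel 6, stride 2, padding 2) geometry; the channel count and input size are illustrative:

import torch
import torch.nn as nn

class UpProjection(nn.Module):
    def __init__(self, channels=32):
        super().__init__()
        self.up_1 = nn.ConvTranspose2d(channels, channels, 6, stride=2, padding=2)
        self.down = nn.Conv2d(channels, channels, 6, stride=2, padding=2)
        self.up_2 = nn.ConvTranspose2d(channels, channels, 6, stride=2, padding=2)

    def forward(self, x):
        a0 = self.up_1(x)                 # first guess at the upscaled features
        b0 = self.down(a0)                # back-project to the input resolution
        return a0 + self.up_2(b0 - x)     # correct the guess with the error

print(UpProjection()(torch.rand(1, 32, 24, 24)).shape)   # torch.Size([1, 32, 48, 48])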