amd
/

ONNX
PyTorch
English
RyzenAI
super resolution
SISR
zhengrongzhang commited on
Commit
2071132
1 Parent(s): d8d07c6

init model

Browse files
README.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - eugenesiow/Div2k
5
+ - eugenesiow/Set5
6
+ language:
7
+ - en
8
+ tags:
9
+ - RyzenAI
10
+ - super resolution
11
+ - SISR
12
+ - pytorch
13
+ ---
14
+ ## Model description
15
+ SESR is based on linear overparameterization of CNNs and creates an efficient model architecture for SISR. It was introduced in the paper [Collapsible Linear Blocks for Super-Efficient Super Resolution](https://arxiv.org/abs/2103.09404).
16
+ The official code for this work is available at this
17
+ https://github.com/ARM-software/sesr
18
+
19
+ We develop a modified version that could be supported by [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).
20
+
21
+ ## Intended uses & limitations
22
+
23
+ You can use the raw model for super resolution. See the [model hub](https://huggingface.co/models?search=amd/sesr) to look for all available models.
24
+
25
+
26
+ ## How to use
27
+
28
+ ### Installation
29
+
30
+ Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
31
+ Run the following script to install pre-requisites for this model.
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+
37
+ ### Data Preparation (optional: for accuracy evaluation)
38
+
39
+ 1. Download the benchmark(https://cv.snu.ac.kr/research/EDSR/benchmark.tar) dataset.
40
+ 2. Organize the dataset directory as follows:
41
+ ```Plain
42
+ └── dataset
43
+ └── benchmark
44
+ ├── Set5
45
+ ├── HR
46
+ | ├── baby.png
47
+ | ├── ...
48
+ └── LR_bicubic
49
+ └──X2
50
+ ├──babyx2.png
51
+ ├── ...
52
+ ├── Set14
53
+ ├── ...
54
+ ```
55
+
56
+ ### Test & Evaluation
57
+
58
+ - Code snippet from [`one_image_inference.py`](one_image_inference.py) on how to use
59
+ ```python
60
+ parser = argparse.ArgumentParser(description='EDSR and MDSR')
61
+ parser.add_argument('--onnx_path', type=str, default='SESR_int8.onnx',
62
+ help='onnx path')
63
+ parser.add_argument('--image_path', default='test_data/test.png',
64
+ help='path of your image')
65
+ parser.add_argument('--output_path', default='test_data/sr.png',
66
+ help='path of your image')
67
+ parser.add_argument('--ipu', action='store_true',
68
+ help='use ipu')
69
+ parser.add_argument('--provider_config', type=str, default=None,
70
+ help='provider config path')
71
+ args = parser.parse_args()
72
+ if args.ipu:
73
+ providers = ["VitisAIExecutionProvider"]
74
+ provider_options = [{"config_file": args.provider_config}]
75
+ else:
76
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
77
+ provider_options = None
78
+
79
+ onnx_file_name = args.onnx_path
80
+ image_path = args.image_path
81
+ output_path = args.output_path
82
+
83
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
84
+ lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
85
+ sr = tiling_inference(ort_session, lr, 8, (56, 56))
86
+ sr = np.clip(sr, 0, 255)
87
+ sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
88
+ sr = cv2.imwrite(output_path, sr)
89
+ ```
90
+
91
+ - Run inference for a single image
92
+ ```python
93
+ python one_image_inference.py --onnx_path SESR_int8.onnx --image_path /Path/To/Your/Image --ipu --provider_config Path/To/vaip_config.json
94
+ ```
95
+ Note: **vaip_config.json** is located at the setup package of Ryzen AI (refer to [Installation](https://huggingface.co/amd/yolox-s#installation))
96
+
97
+ - Test accuracy of the quantized model
98
+ ```python
99
+ python test.py --onnx_path SESR_int8.onnx --data_test Set5 --ipu --provider_config Path/To/vaip_config.json
100
+ ```
101
+
102
+
103
+
104
+ ### Performance
105
+ | Method | Scale | Flops | Set5 |
106
+ |------------|-------|-------|--------------|
107
+ |SESR-S (float) |X2 |10.22G |37.21|
108
+ |SESR-S (INT8) |X2 |10.22G |36.81|
109
+ - Note: the Flops is calculated with the input resolution is 256x256
110
+
111
+
112
+ ```bibtex
113
+ @misc{bhardwaj2022collapsible,
114
+ title={Collapsible Linear Blocks for Super-Efficient Super Resolution},
115
+ author={Kartikeya Bhardwaj and Milos Milosavljevic and Liam O'Neil and Dibakar Gope and Ramon Matas and Alex Chalfin and Naveen Suda and Lingchuan Meng and Danny Loh},
116
+ year={2022},
117
+ eprint={2103.09404},
118
+ archivePrefix={arXiv},
119
+ primaryClass={eess.IV}
120
+ }
121
+ ```
SESR_int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2b4dc9547653f01bec1ba53e42f6722ffa0cc74ee1c787c7e93174d17260de7
3
+ size 110994
data/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ #from dataloader import MSDataLoader
3
+ from torch.utils.data import dataloader
4
+ from torch.utils.data import ConcatDataset
5
+ import torch
6
+ import random
7
+ # This is a simple wrapper function for ConcatDataset
8
+ class MyConcatDataset(ConcatDataset):
9
+ def __init__(self, datasets):
10
+ super(MyConcatDataset, self).__init__(datasets)
11
+
12
+
13
+ def set_scale(self, idx_scale):
14
+ for d in self.datasets:
15
+ if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
16
+
17
+ class Data:
18
+ def __init__(self, args):
19
+ self.loader_train = None
20
+ self.loader_test = []
21
+ for d in args.data_test:
22
+ if d in ['Set5', 'Set14', 'B100', 'Urban100']:
23
+ m = import_module('data.benchmark')
24
+ testset = getattr(m, 'Benchmark')(args, name=d)
25
+ else:
26
+ assert NotImplementedError
27
+
28
+ self.loader_test.append(
29
+ dataloader.DataLoader(
30
+ testset,
31
+ batch_size=1,
32
+ shuffle=False,
33
+ pin_memory=False,
34
+ num_workers=args.n_threads,
35
+ )
36
+ )
data/benchmark.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ #from data import common
4
+ from data import srdata
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data as data
8
+
9
+ class Benchmark(srdata.SRData):
10
+ def __init__(self, args, name='', benchmark=True):
11
+ super(Benchmark, self).__init__(
12
+ args, name=name, benchmark=True
13
+ )
14
+
15
+ def _set_filesystem(self, dir_data):
16
+ self.apath = os.path.join(dir_data, 'benchmark', self.name)
17
+ self.dir_hr = os.path.join(self.apath, 'HR')
18
+ if self.input_large:
19
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
20
+ else:
21
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
22
+ self.ext = ('', '.png')
23
+
data/common.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import skimage.color as sc
5
+
6
+ import torch
7
+
8
+ def set_channel(*args, n_channels=3):
9
+ def _set_channel(img):
10
+ if img.ndim == 2:
11
+ img = np.expand_dims(img, axis=2)
12
+
13
+ c = img.shape[2]
14
+ if n_channels == 1 and c == 3:
15
+ img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
16
+ elif n_channels == 3 and c == 1:
17
+ img = np.concatenate([img] * n_channels, 2)
18
+
19
+ return img
20
+
21
+ return [_set_channel(a) for a in args]
22
+
23
+ def np2Tensor(*args, rgb_range=255):
24
+ def _np2Tensor(img):
25
+ np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
26
+ tensor = torch.from_numpy(np_transpose).float()
27
+ tensor.mul_(rgb_range / 255)
28
+
29
+ return tensor
30
+
31
+ return [_np2Tensor(a) for a in args]
32
+
33
+
34
+
data/data_tiling.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+
5
+ def tiling_inference(session, lr, overlapping, patch_size):
6
+ """
7
+ Parameters:
8
+ - session: an ONNX Runtime session object that contains the super-resolution model
9
+ - lr: the low-resolution image
10
+ - overlapping: the number of pixels to overlap between adjacent patches
11
+ - patch_size: a tuple of (height, width) that specifies the size of each patch
12
+ Returns: - a numpy array that represents the enhanced image
13
+ """
14
+ _, _, h, w = lr.shape
15
+ sr = np.zeros((1, 3, 2*h, 2*w))
16
+ n_h = math.ceil(h / float(patch_size[0] - overlapping))
17
+ n_w = math.ceil(w / float(patch_size[1] - overlapping))
18
+ #every tilling input has same size of patch_size
19
+ for ih in range(n_h):
20
+ h_idx = ih * (patch_size[0] - overlapping)
21
+ h_idx = h_idx if h_idx + patch_size[0] <= h else h - patch_size[0]
22
+ for iw in range(n_w):
23
+ w_idx = iw * (patch_size[1] - overlapping)
24
+ w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
25
+
26
+ tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
27
+ sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
28
+
29
+ left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
30
+ left += overlapping//2
31
+ right -= overlapping//2
32
+ top += overlapping//2
33
+ bottom -= overlapping//2
34
+ #processing edge pixels
35
+ if w_idx == 0:
36
+ left -= overlapping//2
37
+ if h_idx == 0:
38
+ top -= overlapping//2
39
+ if h_idx+patch_size[0]>=h:
40
+ bottom += overlapping//2
41
+ if w_idx+patch_size[1]>=w:
42
+ right += overlapping//2
43
+
44
+ #get preditions
45
+ sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
46
+ return sr
data/srdata.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import random
4
+ import pickle
5
+ from data import common
6
+ import imageio
7
+ import torch.utils.data as data
8
+
9
+ class SRData(data.Dataset):
10
+ def __init__(self, args, name='', benchmark=False):
11
+ self.args = args
12
+ self.name = name
13
+ self.split = 'test'
14
+ self.do_eval = True
15
+ self.benchmark = benchmark
16
+ self.input_large = False
17
+ self.scale = args.scale
18
+ self.idx_scale = 0
19
+ self._set_filesystem(args.dir_data)
20
+ list_hr, list_lr = self._scan()
21
+ self.images_hr, self.images_lr = list_hr, list_lr
22
+
23
+ # Below functions as used to prepare images
24
+ def _scan(self):
25
+ names_hr = sorted(
26
+ glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
27
+ )
28
+ names_lr = [[] for _ in self.scale]
29
+ for f in names_hr:
30
+ filename, _ = os.path.splitext(os.path.basename(f))
31
+ for si, s in enumerate(self.scale):
32
+ names_lr[si].append(os.path.join(
33
+ self.dir_lr, 'X{}/{}x{}{}'.format(
34
+ s, filename, s, self.ext[1]
35
+ )
36
+ ))
37
+
38
+ return names_hr, names_lr
39
+
40
+ def _set_filesystem(self, dir_data):
41
+ self.apath = os.path.join(dir_data, self.name)
42
+ self.dir_hr = os.path.join(self.apath, 'HR')
43
+ self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
44
+ self.ext = ('.png', '.png')
45
+
46
+ def __getitem__(self, idx):
47
+ lr, hr, filename = self._load_file(idx)
48
+ pair = self.get_patch(lr, hr)
49
+ pair = common.set_channel(*pair, n_channels=3)
50
+ pair_t = common.np2Tensor(*pair, rgb_range=255)
51
+
52
+ return pair_t[0], pair_t[1], filename
53
+
54
+ def __len__(self):
55
+ return len(self.images_hr)
56
+
57
+ def _get_index(self, idx):
58
+ return idx
59
+
60
+ def _load_file(self, idx):
61
+ idx = self._get_index(idx)
62
+ f_hr = self.images_hr[idx]
63
+ f_lr = self.images_lr[self.idx_scale][idx]
64
+
65
+ filename, _ = os.path.splitext(os.path.basename(f_hr))
66
+ hr = imageio.imread(f_hr)
67
+ lr = imageio.imread(f_lr)
68
+ return lr, hr, filename
69
+
70
+ def get_patch(self, lr, hr):
71
+ scale = self.scale[self.idx_scale]
72
+ ih, iw = lr.shape[:2]
73
+ hr = hr[0:ih * scale, 0:iw * scale]
74
+ return lr, hr
75
+
76
+ def set_scale(self, idx_scale):
77
+ if not self.input_large:
78
+ self.idx_scale = idx_scale
79
+ else:
80
+ self.idx_scale = random.randint(0, len(self.scale) - 1)
81
+
metric.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from scipy import signal
4
+
5
+ def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
6
+ if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
7
+ print("the dimention of sr image is not equal to hr's! ")
8
+ sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
9
+ diff = (sr - hr).data.div(rgb_range)
10
+
11
+ if benchmark:
12
+ shave = scale
13
+ if diff.size(1) > 1:
14
+ convert = diff.new(1, 3, 1, 1)
15
+ convert[0, 0, 0, 0] = 65.738
16
+ convert[0, 1, 0, 0] = 129.057
17
+ convert[0, 2, 0, 0] = 25.064
18
+ diff.mul_(convert).div_(256)
19
+ diff = diff.sum(dim=1, keepdim=True)
20
+ else:
21
+ shave = scale + 6
22
+ valid = diff[:, :, shave:-shave, shave:-shave]
23
+ mse = valid.pow(2).mean()
24
+
25
+ return -10 * math.log10(mse)
26
+
27
+ def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
28
+ """
29
+ 2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
30
+ Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
31
+ """
32
+ m,n = [(ss-1.)/2. for ss in shape]
33
+ y,x = np.ogrid[-m:m+1,-n:n+1]
34
+ h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
35
+ h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
36
+ sumh = h.sum()
37
+ if sumh != 0:
38
+ h /= sumh
39
+ return h
40
+
41
+ def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
42
+ '''
43
+ X : y channel (i.e., luminance) of transformed YCbCr space of X
44
+ Y : y channel (i.e., luminance) of transformed YCbCr space of Y
45
+ Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW2017).
46
+ Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
47
+ The authors of EDSR use MATLAB's ssim as the evaluation tool,
48
+ thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
49
+ '''
50
+ gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
51
+ shave = scale
52
+ if X.size(1) > 1:
53
+ gray_coeffs = [65.738, 129.057, 25.064]
54
+ convert = X.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
55
+ X = X.mul(convert).sum(dim=1)
56
+ Y = Y.mul(convert).sum(dim=1)
57
+
58
+
59
+ X = X[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
60
+ Y = Y[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
61
+
62
+ window = gaussian_filter
63
+
64
+ ux = signal.convolve2d(X, window, mode='same', boundary='symm')
65
+ uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
66
+
67
+ uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
68
+ uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
69
+ uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')
70
+
71
+ vx = uxx - ux * ux
72
+ vy = uyy - uy * uy
73
+ vxy = uxy - ux * uy
74
+
75
+ C1 = (K1 * R) ** 2
76
+ C2 = (K2 * R) ** 2
77
+
78
+ A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
79
+ D = B1 * B2
80
+ S = (A1 * A2) / D
81
+ mssim = S.mean()
82
+
83
+ return mssim
one_image_inference.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import cv2
3
+ import numpy as np
4
+ import sys
5
+ import pathlib
6
+ CURRENT_DIR = pathlib.Path(__file__).parent
7
+ sys.path.append(str(CURRENT_DIR))
8
+ from data.data_tiling import tiling_inference
9
+ import argparse
10
+
11
+
12
+ def main(args):
13
+ if args.ipu:
14
+ providers = ["VitisAIExecutionProvider"]
15
+ provider_options = [{"config_file": args.provider_config}]
16
+ else:
17
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
18
+ provider_options = None
19
+ onnx_file_name = args.onnx_path
20
+ image_path = args.image_path
21
+ output_path = args.output_path
22
+
23
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
24
+ lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
25
+ sr = tiling_inference(ort_session, lr, 8, (56, 56))
26
+ sr = np.clip(sr, 0, 255)
27
+ sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
28
+ cv2.imwrite(output_path, sr)
29
+
30
+
31
+ if __name__ == '__main__':
32
+ parser = argparse.ArgumentParser(description='EDSR and MDSR')
33
+ parser.add_argument('--onnx_path', type=str, default='SESR_int8.onnx',
34
+ help='onnx path')
35
+ parser.add_argument('--image_path', default='test_data/test.png',
36
+ help='path of your image')
37
+ parser.add_argument('--output_path', default='test_data/sr.png',
38
+ help='path of your image')
39
+ parser.add_argument('--ipu', action='store_true',
40
+ help='use ipu')
41
+ parser.add_argument('--provider_config', type=str, default=None,
42
+ help='provider config path')
43
+ args = parser.parse_args()
44
+ main(args)
option.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ parser = argparse.ArgumentParser(description='SESR')
3
+
4
+ # ipu test or cpu, you need to provide onnx path
5
+ parser.add_argument('--ipu', action='store_true',
6
+ help='use ipu')
7
+ parser.add_argument('--onnx_path', type=str, default='SESR_int8.onnx',
8
+ help='onnx path')
9
+ parser.add_argument('--provider_config', type=str, default=None,
10
+ help='provider config path')
11
+ # Data specifications, you can use default
12
+ parser.add_argument('--dir_data', type=str, default='dataset/',
13
+ help='dataset directory')
14
+ parser.add_argument('--data_test', type=str, default='Set5',
15
+ help='test dataset name')
16
+ parser.add_argument('--n_threads', type=int, default=6,
17
+ help='number of threads for data loading')
18
+ parser.add_argument('--scale', type=str, default='2',
19
+ help='super resolution scale')
20
+ args = parser.parse_args()
21
+ args.scale = list(map(lambda x: int(x), args.scale.split('+')))
22
+ args.data_test = args.data_test.split('+')
23
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ numpy>=1.23.5
3
+ scipy>=1.9
4
+ opencv-python
5
+ pandas
6
+ pillow
7
+ scikit-image
8
+ tqdm
test.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import pathlib
4
+ CURRENT_DIR = pathlib.Path(__file__).parent
5
+ sys.path.append(str(CURRENT_DIR))
6
+ from tqdm import tqdm
7
+ import utility
8
+ import data
9
+ from option import args
10
+ import metric
11
+ import onnxruntime
12
+ import cv2
13
+ from data.data_tiling import tiling_inference
14
+
15
+
16
+ def prepare(a, b, device):
17
+ def _prepare(tensor):
18
+ return tensor.to(device)
19
+
20
+ return _prepare(a), _prepare(b)
21
+
22
+
23
+ def test_model(session, loader, device):
24
+ torch.set_grad_enabled(False)
25
+ self_scale = [2]
26
+ for idx_data, d in enumerate(loader.loader_test):
27
+ eval_ssim = 0
28
+ eval_psnr = 0
29
+ for idx_scale, scale in enumerate(self_scale):
30
+ d.dataset.set_scale(idx_scale)
31
+ for lr, hr, filename in tqdm(d, ncols=80):
32
+ lr, hr = prepare(lr, hr, device)
33
+ sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
34
+ sr = torch.from_numpy(sr).to(device)
35
+ sr = utility.quantize(sr, 255)
36
+ eval_psnr += metric.calc_psnr(
37
+ sr, hr, scale, 255, benchmark=d)
38
+ eval_ssim += metric.calc_ssim(
39
+ sr, hr, scale, 255, dataset=d)
40
+ mean_ssim = eval_ssim / len(d)
41
+ mean_psnr = eval_psnr / len(d)
42
+ print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
43
+ return mean_psnr, mean_ssim
44
+
45
+ def main():
46
+ loader = data.Data(args)
47
+ if args.ipu:
48
+ providers = ["VitisAIExecutionProvider"]
49
+ provider_options = [{"config_file": args.provider_config}]
50
+ else:
51
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
52
+ provider_options = None
53
+ onnx_file_name = args.onnx_path
54
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
55
+ test_model(ort_session, loader, device="cpu")
56
+
57
+
58
+ if __name__ == '__main__':
59
+ main()
test_data/test.png ADDED
utility.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import time
4
+ import numpy as np
5
+
6
+ class timer():
7
+ def __init__(self):
8
+ self.acc = 0
9
+ self.tic()
10
+
11
+ def tic(self):
12
+ self.t0 = time.time()
13
+
14
+ def toc(self, restart=False):
15
+ diff = time.time() - self.t0
16
+ if restart: self.t0 = time.time()
17
+ return diff
18
+
19
+ def hold(self):
20
+ self.acc += self.toc()
21
+
22
+ def release(self):
23
+ ret = self.acc
24
+ self.acc = 0
25
+
26
+ return ret
27
+
28
+ def reset(self):
29
+ self.acc = 0
30
+
31
+ def quantize(img, rgb_range):
32
+ pixel_range = 255 / rgb_range
33
+ return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)