Gokuleshwaran commited on
Commit
6221b96
1 Parent(s): 2f12290

First model version

Browse files
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image, ImageOps, ImageEnhance, ImageFilter
6
+ import numpy as np
7
+ import time
8
+ import io
9
+
10
+
11
+ # Metrics imports
12
+ from skimage.metrics import peak_signal_noise_ratio as psnr
13
+ from skimage.metrics import structural_similarity as ssim
14
+
15
+ # Model imports
16
+ from models.srcnn import SRCNN
17
+ from models.vdsr import VDSR
18
+ from models.edsr import EDSR
19
+
20
+ # Cache for loaded models
21
+ model_cache = {}
22
+
23
+ def load_model(model_name):
24
+ """
25
+ Load super-resolution model with optional scale factor
26
+
27
+ Args:
28
+ model_name (str): Name of the model (SRCNN, VDSR, EDSR)
29
+ scale_factor (int): Upscaling factor (2, 3, or 4)
30
+
31
+ Returns:
32
+ torch.nn.Module: Loaded model
33
+ """
34
+ try:
35
+ # Check if model is already in the cache
36
+ if model_name in model_cache:
37
+ return model_cache[model_name]
38
+
39
+ if model_name == 'SRCNN':
40
+ model = SRCNN()
41
+ elif model_name == 'VDSR':
42
+ model = VDSR()
43
+ else:
44
+ model = EDSR()
45
+
46
+ # Load pre-trained weights if available
47
+ weight_path = f'checkpoints/{model_name.lower()}_best.pth'
48
+ if os.path.exists(weight_path):
49
+ model.load_state_dict(torch.load(weight_path, map_location=torch.device('cpu'), weights_only=True))
50
+ else:
51
+ st.warning(f"No pre-trained weights found for the {model_name} model. Using randomly initialized weights.")
52
+
53
+ model.eval()
54
+
55
+ # Cache the loaded model
56
+ model_cache[model_name] = model
57
+ return model
58
+ except Exception as e:
59
+ st.error(f"Error loading {model_name} model: {e}")
60
+ return None
61
+
62
+ def process_image(image, model):
63
+ # Convert to YCbCr and extract Y channel
64
+ ycbcr = image.convert('YCbCr')
65
+ y, cb, cr = ycbcr.split()
66
+
67
+ # Transform Y channel
68
+ transform = transforms.Compose([
69
+ transforms.ToTensor()
70
+ ])
71
+
72
+ input_tensor = transform(y).unsqueeze(0)
73
+
74
+ # Process through model
75
+ with torch.no_grad():
76
+ output = model(input_tensor)
77
+
78
+ # Post-process output
79
+ output = output.squeeze().clamp(0, 1).numpy()
80
+ output_y = Image.fromarray((output * 255).astype(np.uint8))
81
+
82
+ # Merge channels back
83
+ output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
84
+ output_rgb = output_ycbcr.convert('RGB')
85
+
86
+ return output_rgb
87
+
88
+ def calculate_image_metrics(original, enhanced):
89
+ """
90
+ Calculate image quality metrics
91
+
92
+ Args:
93
+ original (np.ndarray): Original image
94
+ enhanced (np.ndarray): Enhanced image
95
+
96
+ Returns:
97
+ dict: Quality metrics
98
+ """
99
+ try:
100
+ # Ensure images are the same size
101
+ min_height = min(original.shape[0], enhanced.shape[0])
102
+ min_width = min(original.shape[1], enhanced.shape[1])
103
+
104
+ # Resize images to the smallest common size
105
+ original = original[:min_height, :min_width]
106
+ enhanced = enhanced[:min_height, :min_width]
107
+
108
+ # Calculate SSIM with an explicit window size
109
+ win_size = min(7, min(min_height, min_width))
110
+ if win_size % 2 == 0:
111
+ win_size -= 1 # Ensure odd window size
112
+
113
+ return {
114
+ 'PSNR': psnr(original, enhanced),
115
+ 'SSIM': ssim(original, enhanced, multichannel=True, win_size=win_size, channel_axis=-1)
116
+ }
117
+ except Exception as e:
118
+ st.error(f"Error calculating metrics: {e}")
119
+ return {'PSNR': 0, 'SSIM': 0}
120
+
121
+ def main():
122
+ st.set_page_config(
123
+ page_title="Super Resolution Comparison",
124
+ page_icon="🖼️",
125
+ layout="wide"
126
+ )
127
+
128
+ st.title("🚀 Super Resolution Model Comparison")
129
+ st.write("Upload a low-resolution image and compare different super-resolution models.")
130
+
131
+
132
+
133
+
134
+ # File Upload
135
+ uploaded_file = st.file_uploader(
136
+ "Choose an image",
137
+ type=['png', 'jpg', 'jpeg'],
138
+ help="Upload a low-resolution image for enhancement"
139
+ )
140
+
141
+ if uploaded_file is not None:
142
+ # Load input image
143
+ input_image = Image.open(uploaded_file)
144
+ input_array = np.array(input_image)
145
+
146
+ st.subheader("📸 Original Image")
147
+ st.image(input_image, caption="Low-Resolution Input", use_column_width=True)
148
+
149
+ # Model Names
150
+ model_names = ['SRCNN', 'VDSR', 'EDSR']
151
+
152
+ # Performance and Quality Storage
153
+ processing_times = {}
154
+ quality_metrics = {}
155
+ enhanced_images = {}
156
+
157
+ # Process images
158
+ columns = st.columns(len(model_names))
159
+ for i, model_name in enumerate(model_names):
160
+ with columns[i]:
161
+ st.subheader(f"{model_name} Model")
162
+
163
+ # Load model
164
+ model = load_model(model_name)
165
+
166
+ if model:
167
+ # Time the processing
168
+ start_time = time.time()
169
+ enhanced_image = process_image(input_image, model)
170
+ processing_time = time.time() - start_time
171
+
172
+ if enhanced_image:
173
+ # Display enhanced image
174
+ st.image(enhanced_image, caption=f"{model_name} Output", use_column_width=True)
175
+
176
+ # Calculate metrics
177
+ enhanced_array = np.array(enhanced_image)
178
+ metrics = calculate_image_metrics(input_array, enhanced_array)
179
+
180
+ # Store results
181
+ processing_times[model_name] = processing_time
182
+ quality_metrics[model_name] = metrics
183
+ enhanced_images[model_name] = enhanced_image
184
+
185
+ # Performance Metrics Section
186
+ st.subheader("📊 Performance Metrics")
187
+ metric_cols = st.columns(len(model_names))
188
+
189
+ for i, (model, time_val) in enumerate(processing_times.items()):
190
+ with metric_cols[i]:
191
+ st.metric(f"{model} Processing Time", f"{time_val:.4f} seconds")
192
+
193
+ # Quality Metrics Section
194
+ st.subheader("🔍 Image Quality Assessment")
195
+ quality_cols = st.columns(len(model_names))
196
+
197
+ for i, (model, metrics) in enumerate(quality_metrics.items()):
198
+ with quality_cols[i]:
199
+ st.metric(f"{model} PSNR", f"{metrics['PSNR']:.2f} dB")
200
+ st.metric(f"{model} SSIM", f"{metrics['SSIM']:.4f}")
201
+
202
+ # Download Section
203
+ st.subheader("💾 Download Enhanced Images")
204
+ download_cols = st.columns(len(model_names))
205
+
206
+ for i, (model, image) in enumerate(enhanced_images.items()):
207
+ with download_cols[i]:
208
+ buffered = io.BytesIO()
209
+ image.save(buffered, format="PNG")
210
+ st.download_button(
211
+ label=f"Download {model} Image",
212
+ data=buffered.getvalue(),
213
+ file_name=f"{model}_enhanced.png",
214
+ mime="image/png"
215
+ )
216
+
217
+ if __name__ == "__main__":
218
+ main()
checkpoints/edsr_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3f1a4bd666d2537fdd371c2eac1a33c41fa890a577401f8a05ee9c1c73c6ea9
3
+ size 4904536
checkpoints/srcnn_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b15f698a761b03d40decbf66ee0ac3a2496b61f396427b8e181c236c132c356d
3
+ size 35238
checkpoints/vdsr_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec5a3872a48aea983e0d5f0349d81d3b44bd28b5bd0ba29cfa169478c663c325
3
+ size 2667978
inference.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import os
3
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
4
+ import streamlit as st
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ import numpy as np
9
+ from models.srcnn import SRCNN
10
+ from models.vdsr import VDSR
11
+ from models.edsr import EDSR
12
+
13
+ def load_model(model_name):
14
+ if model_name == 'SRCNN':
15
+ model = SRCNN()
16
+ elif model_name == 'VDSR':
17
+ model = VDSR()
18
+ else:
19
+ model = EDSR()
20
+
21
+ model.load_state_dict(torch.load(f'checkpoints/{model_name.lower()}_best.pth', map_location=torch.device('cpu')))
22
+ model.eval()
23
+ return model
24
+
25
+ def process_image(image, model):
26
+ # Convert to YCbCr and extract Y channel
27
+ ycbcr = image.convert('YCbCr')
28
+ y, cb, cr = ycbcr.split()
29
+
30
+ # Transform Y channel
31
+ transform = transforms.Compose([
32
+ transforms.ToTensor()
33
+ ])
34
+
35
+ input_tensor = transform(y).unsqueeze(0)
36
+
37
+ # Process through model
38
+ with torch.no_grad():
39
+ output = model(input_tensor)
40
+
41
+ # Post-process output
42
+ output = output.squeeze().clamp(0, 1).numpy()
43
+ output_y = Image.fromarray((output * 255).astype(np.uint8))
44
+
45
+ # Merge channels back
46
+ output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
47
+ output_rgb = output_ycbcr.convert('RGB')
48
+
49
+ return output_rgb
50
+
51
+ def main():
52
+ st.title("Super Resolution Model Comparison")
53
+ st.write("Upload a low-resolution image to compare SRCNN, VDSR, and EDSR models")
54
+
55
+ # File uploader
56
+ uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
57
+
58
+ if uploaded_file is not None:
59
+ # Load and display input image
60
+ input_image = Image.open(uploaded_file)
61
+ st.subheader("Input Image")
62
+ st.image(input_image, caption="Original Image")
63
+
64
+ # Process with each model
65
+ col1, col2, col3 = st.columns(3)
66
+
67
+ with col1:
68
+ st.subheader("SRCNN")
69
+ model = load_model('SRCNN')
70
+ srcnn_output = process_image(input_image, model)
71
+ st.image(srcnn_output, caption="SRCNN Output")
72
+
73
+ with col2:
74
+ st.subheader("VDSR")
75
+ model = load_model('VDSR')
76
+ vdsr_output = process_image(input_image, model)
77
+ st.image(vdsr_output, caption="VDSR Output")
78
+
79
+ with col3:
80
+ st.subheader("EDSR")
81
+ model = load_model('EDSR')
82
+ edsr_output = process_image(input_image, model)
83
+ st.image(edsr_output, caption="EDSR Output")
84
+
85
+ if __name__ == "__main__":
86
+ main()
metrics.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # metrics.py
2
+ import os
3
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ from models.srcnn import SRCNN
10
+ from models.vdsr import VDSR
11
+ from models.edsr import EDSR
12
+ import math
13
+ from skimage.metrics import structural_similarity as ssim
14
+ import matplotlib.pyplot as plt
15
+
16
+
17
+ def calculate_psnr(img1, img2):
18
+ mse = torch.mean((img1 - img2) ** 2)
19
+ if mse == 0:
20
+ return float('inf')
21
+ return 20 * math.log10(1.0 / math.sqrt(mse.item()))
22
+
23
+ def process_image(model, lr_img):
24
+ with torch.no_grad():
25
+ # Convert to YCbCr and extract Y channel
26
+ ycbcr = lr_img.convert('YCbCr')
27
+ y, cb, cr = ycbcr.split()
28
+
29
+ # Transform Y channel
30
+ transform = transforms.Compose([transforms.ToTensor()])
31
+ input_tensor = transform(y).unsqueeze(0)
32
+
33
+ # Process through model
34
+ output = model(input_tensor)
35
+
36
+ # Post-process output
37
+ output = output.squeeze().clamp(0, 1).numpy()
38
+ output_y = Image.fromarray((output * 255).astype(np.uint8))
39
+
40
+ # Merge channels back
41
+ output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
42
+ return output_ycbcr.convert('RGB')
43
+
44
+ def calculate_ssim(img1, img2):
45
+ # Move channel axis to the end for SSIM calculation
46
+ img1_np = img1.cpu().numpy().transpose(1, 2, 0)
47
+ img2_np = img2.cpu().numpy().transpose(1, 2, 0)
48
+ return ssim(img1_np, img2_np, data_range=1.0, channel_axis=2, win_size=7)
49
+
50
+ def evaluate_models(test_image_path):
51
+ # Load models
52
+ models = {
53
+ 'SRCNN': SRCNN(),
54
+ 'VDSR': VDSR(),
55
+ 'EDSR': EDSR()
56
+ }
57
+
58
+ # Load weights
59
+ for name, model in models.items():
60
+ model.load_state_dict(torch.load(f'checkpoints/{name.lower()}_best.pth', weights_only=True))
61
+ model.eval()
62
+
63
+ # Load test image
64
+ lr_img = Image.open(test_image_path)
65
+ hr_img = Image.open(test_image_path) # Using same image as reference
66
+
67
+ # Results dictionary
68
+ results = {model_name: {} for model_name in models.keys()}
69
+
70
+ # Process image with each model and calculate metrics
71
+ for name, model in models.items():
72
+ # Generate SR image
73
+ sr_img = process_image(model, lr_img)
74
+
75
+ # Convert images to tensors for metric calculation
76
+ transform = transforms.Compose([
77
+ transforms.Resize((256, 256)), # Ensure minimum size for SSIM
78
+ transforms.ToTensor()
79
+ ])
80
+
81
+ sr_tensor = transform(sr_img)
82
+ hr_tensor = transform(hr_img)
83
+
84
+ # Calculate metrics
85
+ results[name]['PSNR'] = calculate_psnr(sr_tensor, hr_tensor)
86
+ results[name]['SSIM'] = calculate_ssim(sr_tensor, hr_tensor)
87
+
88
+ # Save output images
89
+ sr_img.save(f'results/{name.lower()}_output.png')
90
+
91
+ # Display results
92
+ print("\nModel Performance Metrics:")
93
+ print("-" * 50)
94
+ print(f"{'Model':<10} {'PSNR (dB)':<15} {'SSIM':<15}")
95
+ print("-" * 50)
96
+
97
+ for model_name, metrics in results.items():
98
+ print(f"{model_name:<10} {metrics['PSNR']:<15.2f} {metrics['SSIM']:<15.4f}")
99
+
100
+ # Plot results
101
+ plt.figure(figsize=(12, 6))
102
+
103
+ # PSNR comparison
104
+ plt.subplot(1, 2, 1)
105
+ plt.bar(results.keys(), [m['PSNR'] for m in results.values()])
106
+ plt.title('PSNR Comparison')
107
+ plt.ylabel('PSNR (dB)')
108
+
109
+ # SSIM comparison
110
+ plt.subplot(1, 2, 2)
111
+ plt.bar(results.keys(), [m['SSIM'] for m in results.values()])
112
+ plt.title('SSIM Comparison')
113
+ plt.ylabel('SSIM')
114
+
115
+ plt.tight_layout()
116
+ plt.savefig('results/metrics_comparison.png')
117
+ plt.close()
118
+
119
+ if __name__ == "__main__":
120
+ import os
121
+ os.makedirs('results', exist_ok=True)
122
+ test_image_path = r"data\DIV2K_train_LR_bicubic_X4\DIV2K_train_LR_bicubic\X4\0001x4.png" # Replace with your test image path
123
+ evaluate_models(test_image_path)
models/__pycache__/edsr.cpython-312.pyc ADDED
Binary file (2.9 kB). View file
 
models/__pycache__/srcnn.cpython-312.pyc ADDED
Binary file (1.45 kB). View file
 
models/__pycache__/vdsr.cpython-312.pyc ADDED
Binary file (3.12 kB). View file
 
models/edsr.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/edsr.py
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class ResBlock(nn.Module):
6
+ def __init__(self, n_feats, kernel_size, bias=True, res_scale=1):
7
+ super(ResBlock, self).__init__()
8
+ m = []
9
+ for i in range(2):
10
+ m.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=bias))
11
+ if i == 0:
12
+ m.append(nn.ReLU(True))
13
+ self.body = nn.Sequential(*m)
14
+ self.res_scale = res_scale
15
+
16
+ def forward(self, x):
17
+ res = self.body(x).mul(self.res_scale)
18
+ res += x
19
+ return res
20
+
21
+ class EDSR(nn.Module):
22
+ def __init__(self, n_resblocks=16, n_feats=64, scale=4):
23
+ super(EDSR, self).__init__()
24
+ kernel_size = 3
25
+ self.scale = scale
26
+
27
+ # Define head module
28
+ m_head = [nn.Conv2d(1, n_feats, kernel_size, padding=(kernel_size//2))]
29
+
30
+ # Define body module
31
+ m_body = [
32
+ ResBlock(n_feats, kernel_size) \
33
+ for _ in range(n_resblocks)
34
+ ]
35
+ m_body.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2)))
36
+
37
+ # Define tail module
38
+ m_tail = [
39
+ nn.Conv2d(n_feats, 1, kernel_size, padding=(kernel_size//2))
40
+ ]
41
+
42
+ self.head = nn.Sequential(*m_head)
43
+ self.body = nn.Sequential(*m_body)
44
+ self.tail = nn.Sequential(*m_tail)
45
+
46
+ def forward(self, x):
47
+ x = self.head(x)
48
+ res = self.body(x)
49
+ res += x
50
+ x = self.tail(res)
51
+ return x
models/srcnn.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/srcnn.py
2
+ import torch.nn as nn
3
+
4
+ class SRCNN(nn.Module):
5
+ def __init__(self):
6
+ super(SRCNN, self).__init__()
7
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
8
+ self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
9
+ self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)
10
+ self.relu = nn.ReLU(inplace=True)
11
+
12
+ def forward(self, x):
13
+ x = self.relu(self.conv1(x))
14
+ x = self.relu(self.conv2(x))
15
+ x = self.conv3(x)
16
+ return x
models/vdsr.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/vdsr.py
2
+ import torch.nn as nn
3
+ from math import sqrt
4
+
5
+ class Conv_ReLU_Block(nn.Module):
6
+ def __init__(self):
7
+ super(Conv_ReLU_Block, self).__init__()
8
+ self.conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
9
+ self.relu = nn.ReLU(inplace=True)
10
+
11
+ def forward(self, x):
12
+ return self.relu(self.conv(x))
13
+
14
+ class VDSR(nn.Module):
15
+ def __init__(self):
16
+ super(VDSR, self).__init__()
17
+ self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
18
+ self.input = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
19
+ self.output = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=False)
20
+ self.relu = nn.ReLU(inplace=True)
21
+
22
+ # Initialize weights
23
+ for m in self.modules():
24
+ if isinstance(m, nn.Conv2d):
25
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
26
+ m.weight.data.normal_(0, sqrt(2. / n))
27
+
28
+ def make_layer(self, block, num_of_layer):
29
+ layers = []
30
+ for _ in range(num_of_layer):
31
+ layers.append(block())
32
+ return nn.Sequential(*layers)
33
+
34
+ def forward(self, x):
35
+ residual = x
36
+ out = self.relu(self.input(x))
37
+ out = self.residual_layer(out)
38
+ out = self.output(out)
39
+ return out + residual
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ numpy
4
+ torchvision
5
+
6
+ scikit-image
results/edsr_output.png ADDED
results/metrics_comparison.png ADDED
results/srcnn_output.png ADDED
results/vdsr_output.png ADDED
train.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ import os
3
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader
8
+ from utils.dataset import DIV2KDataset
9
+ from models.srcnn import SRCNN
10
+ from models.vdsr import VDSR
11
+ from models.edsr import EDSR
12
+ import math
13
+ import numpy as np
14
+
15
+ class EarlyStopping:
16
+ def __init__(self, patience=7, min_delta=0.01, min_psnr_improvement=0.1):
17
+ self.patience = patience
18
+ self.min_delta = min_delta
19
+ self.min_psnr_improvement = min_psnr_improvement
20
+ self.counter = 0
21
+ self.best_loss = None
22
+ self.best_psnr = None
23
+ self.early_stop = False
24
+
25
+ def __call__(self, loss, psnr):
26
+ if self.best_loss is None:
27
+ self.best_loss = loss
28
+ self.best_psnr = psnr
29
+ elif (loss > self.best_loss - self.min_delta) and (psnr < self.best_psnr + self.min_psnr_improvement):
30
+ self.counter += 1
31
+ print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
32
+ if self.counter >= self.patience:
33
+ self.early_stop = True
34
+ else:
35
+ self.best_loss = min(loss, self.best_loss)
36
+ self.best_psnr = max(psnr, self.best_psnr)
37
+ self.counter = 0
38
+
39
+ def calculate_psnr(img1, img2):
40
+ mse = torch.mean((img1 - img2) ** 2)
41
+ if mse == 0:
42
+ return float('inf')
43
+ return 20 * math.log10(1.0 / math.sqrt(mse.item()))
44
+
45
+ def train_model(model_name, train_loader, val_loader, device, num_epochs=100):
46
+ # Initialize model
47
+ if model_name == 'srcnn':
48
+ model = SRCNN()
49
+ elif model_name == 'vdsr':
50
+ model = VDSR()
51
+ else:
52
+ model = EDSR()
53
+
54
+ model = model.to(device)
55
+ criterion = nn.MSELoss()
56
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
57
+
58
+ # Initialize early stopping
59
+ early_stopping = EarlyStopping(patience=10, min_delta=0.00001, min_psnr_improvement=0.1)
60
+ best_psnr = 0
61
+
62
+ for epoch in range(num_epochs):
63
+ # Training
64
+ model.train()
65
+ train_loss = 0
66
+ num_batches = 0
67
+
68
+ for batch_idx, (lr_img, hr_img) in enumerate(train_loader):
69
+ lr_img, hr_img = lr_img.to(device), hr_img.to(device)
70
+
71
+ optimizer.zero_grad()
72
+ output = model(lr_img)
73
+ loss = criterion(output, hr_img)
74
+ loss.backward()
75
+ optimizer.step()
76
+
77
+ train_loss += loss.item()
78
+ num_batches += 1
79
+
80
+ if batch_idx % 100 == 0:
81
+ print(f'Train Epoch: {epoch} [{batch_idx}/{len(train_loader)}]\tLoss: {loss.item():.6f}')
82
+
83
+ avg_train_loss = train_loss / num_batches
84
+
85
+ # Validation
86
+ model.eval()
87
+ val_psnr = 0
88
+ with torch.no_grad():
89
+ for lr_img, hr_img in val_loader:
90
+ lr_img, hr_img = lr_img.to(device), hr_img.to(device)
91
+ output = model(lr_img)
92
+ val_psnr += calculate_psnr(output, hr_img)
93
+
94
+ val_psnr /= len(val_loader)
95
+ print(f'Epoch: {epoch}, Average Loss: {avg_train_loss:.6f}, Average PSNR: {val_psnr:.2f}dB')
96
+
97
+ # Early stopping check
98
+ early_stopping(avg_train_loss, val_psnr)
99
+ if early_stopping.early_stop:
100
+ print(f"Early stopping triggered at epoch {epoch}")
101
+ break
102
+
103
+ # Save best model
104
+ if val_psnr > best_psnr:
105
+ best_psnr = val_psnr
106
+ torch.save(model.state_dict(), f'checkpoints/{model_name}_best.pth')
107
+ print(f'Saved new best model with PSNR: {best_psnr:.2f}dB')
108
+
109
+ def main():
110
+ # Setup
111
+ device = torch.device('cpu')
112
+
113
+ # Data paths
114
+ train_hr_dir = 'data/DIV2K_train_HR/DIV2K_train_HR/'
115
+ train_lr_dir = 'data/DIV2K_train_LR_bicubic_X4/DIV2K_train_LR_bicubic/X4'
116
+ val_hr_dir = 'data/DIV2K_valid_HR/DIV2K_valid_HR'
117
+ val_lr_dir = 'data/DIV2K_valid_LR_bicubic_X4/DIV2K_valid_LR_bicubic/X4'
118
+
119
+ # Create datasets
120
+ train_dataset = DIV2KDataset(train_hr_dir, train_lr_dir, patch_size=48)
121
+ val_dataset = DIV2KDataset(val_hr_dir, val_lr_dir, patch_size=48)
122
+
123
+ # Create dataloaders
124
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
125
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
126
+
127
+ # Create checkpoints directory
128
+ os.makedirs('checkpoints', exist_ok=True)
129
+
130
+ # Train models
131
+ models = ['edsr']
132
+ for model_name in models:
133
+ print(f'Training {model_name.upper()}...')
134
+ train_model(model_name, train_loader, val_loader, device)
135
+
136
+ if __name__ == '__main__':
137
+ main()
utils/dataset.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/dataset.py
2
+ from torch.utils.data import Dataset
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import os
6
+
7
+ class DIV2KDataset(Dataset):
8
+ def __init__(self, hr_dir, lr_dir, patch_size=96, upscale_factor=4):
9
+ self.hr_files = sorted(os.listdir(hr_dir))
10
+ self.lr_files = sorted(os.listdir(lr_dir))
11
+ self.hr_dir = hr_dir
12
+ self.lr_dir = lr_dir
13
+ self.patch_size = patch_size
14
+ self.upscale_factor = upscale_factor
15
+
16
+ # LR transform
17
+ self.lr_transform = transforms.Compose([
18
+ transforms.Resize((patch_size//upscale_factor, patch_size//upscale_factor),
19
+ interpolation=transforms.InterpolationMode.BICUBIC),
20
+ transforms.ToTensor()
21
+ ])
22
+
23
+ # HR transform
24
+ self.hr_transform = transforms.Compose([
25
+ transforms.Resize((patch_size, patch_size),
26
+ interpolation=transforms.InterpolationMode.BICUBIC),
27
+ transforms.ToTensor()
28
+ ])
29
+
30
+ # Upscale LR images to match HR size for SRCNN
31
+ self.lr_upscale = transforms.Compose([
32
+ transforms.Resize((patch_size, patch_size),
33
+ interpolation=transforms.InterpolationMode.BICUBIC),
34
+ transforms.ToTensor()
35
+ ])
36
+
37
+ def __getitem__(self, idx):
38
+ # Load images and convert to YCbCr
39
+ hr_img = Image.open(os.path.join(self.hr_dir, self.hr_files[idx])).convert('YCbCr')
40
+ lr_img = Image.open(os.path.join(self.lr_dir, self.lr_files[idx])).convert('YCbCr')
41
+
42
+ # Extract Y channel
43
+ hr_y, _, _ = hr_img.split()
44
+ lr_y, _, _ = lr_img.split()
45
+
46
+ # For SRCNN, we need to upscale LR images first
47
+ lr_y_upscaled = self.lr_upscale(lr_y)
48
+ hr_y_tensor = self.hr_transform(hr_y)
49
+
50
+ return lr_y_upscaled, hr_y_tensor
51
+
52
+ def __len__(self):
53
+ return len(self.hr_files)