Gokuleshwaran
commited on
Commit
•
6221b96
1
Parent(s):
2f12290
First model version
Browse files- app.py +218 -0
- checkpoints/edsr_best.pth +3 -0
- checkpoints/srcnn_best.pth +3 -0
- checkpoints/vdsr_best.pth +3 -0
- inference.py +86 -0
- metrics.py +123 -0
- models/__pycache__/edsr.cpython-312.pyc +0 -0
- models/__pycache__/srcnn.cpython-312.pyc +0 -0
- models/__pycache__/vdsr.cpython-312.pyc +0 -0
- models/edsr.py +51 -0
- models/srcnn.py +16 -0
- models/vdsr.py +39 -0
- requirements.txt +6 -0
- results/edsr_output.png +0 -0
- results/metrics_comparison.png +0 -0
- results/srcnn_output.png +0 -0
- results/vdsr_output.png +0 -0
- train.py +137 -0
- utils/dataset.py +53 -0
app.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
import io
|
9 |
+
|
10 |
+
|
11 |
+
# Metrics imports
|
12 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr
|
13 |
+
from skimage.metrics import structural_similarity as ssim
|
14 |
+
|
15 |
+
# Model imports
|
16 |
+
from models.srcnn import SRCNN
|
17 |
+
from models.vdsr import VDSR
|
18 |
+
from models.edsr import EDSR
|
19 |
+
|
20 |
+
# Cache for loaded models
|
21 |
+
model_cache = {}
|
22 |
+
|
23 |
+
def load_model(model_name):
|
24 |
+
"""
|
25 |
+
Load super-resolution model with optional scale factor
|
26 |
+
|
27 |
+
Args:
|
28 |
+
model_name (str): Name of the model (SRCNN, VDSR, EDSR)
|
29 |
+
scale_factor (int): Upscaling factor (2, 3, or 4)
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
torch.nn.Module: Loaded model
|
33 |
+
"""
|
34 |
+
try:
|
35 |
+
# Check if model is already in the cache
|
36 |
+
if model_name in model_cache:
|
37 |
+
return model_cache[model_name]
|
38 |
+
|
39 |
+
if model_name == 'SRCNN':
|
40 |
+
model = SRCNN()
|
41 |
+
elif model_name == 'VDSR':
|
42 |
+
model = VDSR()
|
43 |
+
else:
|
44 |
+
model = EDSR()
|
45 |
+
|
46 |
+
# Load pre-trained weights if available
|
47 |
+
weight_path = f'checkpoints/{model_name.lower()}_best.pth'
|
48 |
+
if os.path.exists(weight_path):
|
49 |
+
model.load_state_dict(torch.load(weight_path, map_location=torch.device('cpu'), weights_only=True))
|
50 |
+
else:
|
51 |
+
st.warning(f"No pre-trained weights found for the {model_name} model. Using randomly initialized weights.")
|
52 |
+
|
53 |
+
model.eval()
|
54 |
+
|
55 |
+
# Cache the loaded model
|
56 |
+
model_cache[model_name] = model
|
57 |
+
return model
|
58 |
+
except Exception as e:
|
59 |
+
st.error(f"Error loading {model_name} model: {e}")
|
60 |
+
return None
|
61 |
+
|
62 |
+
def process_image(image, model):
|
63 |
+
# Convert to YCbCr and extract Y channel
|
64 |
+
ycbcr = image.convert('YCbCr')
|
65 |
+
y, cb, cr = ycbcr.split()
|
66 |
+
|
67 |
+
# Transform Y channel
|
68 |
+
transform = transforms.Compose([
|
69 |
+
transforms.ToTensor()
|
70 |
+
])
|
71 |
+
|
72 |
+
input_tensor = transform(y).unsqueeze(0)
|
73 |
+
|
74 |
+
# Process through model
|
75 |
+
with torch.no_grad():
|
76 |
+
output = model(input_tensor)
|
77 |
+
|
78 |
+
# Post-process output
|
79 |
+
output = output.squeeze().clamp(0, 1).numpy()
|
80 |
+
output_y = Image.fromarray((output * 255).astype(np.uint8))
|
81 |
+
|
82 |
+
# Merge channels back
|
83 |
+
output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
|
84 |
+
output_rgb = output_ycbcr.convert('RGB')
|
85 |
+
|
86 |
+
return output_rgb
|
87 |
+
|
88 |
+
def calculate_image_metrics(original, enhanced):
|
89 |
+
"""
|
90 |
+
Calculate image quality metrics
|
91 |
+
|
92 |
+
Args:
|
93 |
+
original (np.ndarray): Original image
|
94 |
+
enhanced (np.ndarray): Enhanced image
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
dict: Quality metrics
|
98 |
+
"""
|
99 |
+
try:
|
100 |
+
# Ensure images are the same size
|
101 |
+
min_height = min(original.shape[0], enhanced.shape[0])
|
102 |
+
min_width = min(original.shape[1], enhanced.shape[1])
|
103 |
+
|
104 |
+
# Resize images to the smallest common size
|
105 |
+
original = original[:min_height, :min_width]
|
106 |
+
enhanced = enhanced[:min_height, :min_width]
|
107 |
+
|
108 |
+
# Calculate SSIM with an explicit window size
|
109 |
+
win_size = min(7, min(min_height, min_width))
|
110 |
+
if win_size % 2 == 0:
|
111 |
+
win_size -= 1 # Ensure odd window size
|
112 |
+
|
113 |
+
return {
|
114 |
+
'PSNR': psnr(original, enhanced),
|
115 |
+
'SSIM': ssim(original, enhanced, multichannel=True, win_size=win_size, channel_axis=-1)
|
116 |
+
}
|
117 |
+
except Exception as e:
|
118 |
+
st.error(f"Error calculating metrics: {e}")
|
119 |
+
return {'PSNR': 0, 'SSIM': 0}
|
120 |
+
|
121 |
+
def main():
|
122 |
+
st.set_page_config(
|
123 |
+
page_title="Super Resolution Comparison",
|
124 |
+
page_icon="🖼️",
|
125 |
+
layout="wide"
|
126 |
+
)
|
127 |
+
|
128 |
+
st.title("🚀 Super Resolution Model Comparison")
|
129 |
+
st.write("Upload a low-resolution image and compare different super-resolution models.")
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
# File Upload
|
135 |
+
uploaded_file = st.file_uploader(
|
136 |
+
"Choose an image",
|
137 |
+
type=['png', 'jpg', 'jpeg'],
|
138 |
+
help="Upload a low-resolution image for enhancement"
|
139 |
+
)
|
140 |
+
|
141 |
+
if uploaded_file is not None:
|
142 |
+
# Load input image
|
143 |
+
input_image = Image.open(uploaded_file)
|
144 |
+
input_array = np.array(input_image)
|
145 |
+
|
146 |
+
st.subheader("📸 Original Image")
|
147 |
+
st.image(input_image, caption="Low-Resolution Input", use_column_width=True)
|
148 |
+
|
149 |
+
# Model Names
|
150 |
+
model_names = ['SRCNN', 'VDSR', 'EDSR']
|
151 |
+
|
152 |
+
# Performance and Quality Storage
|
153 |
+
processing_times = {}
|
154 |
+
quality_metrics = {}
|
155 |
+
enhanced_images = {}
|
156 |
+
|
157 |
+
# Process images
|
158 |
+
columns = st.columns(len(model_names))
|
159 |
+
for i, model_name in enumerate(model_names):
|
160 |
+
with columns[i]:
|
161 |
+
st.subheader(f"{model_name} Model")
|
162 |
+
|
163 |
+
# Load model
|
164 |
+
model = load_model(model_name)
|
165 |
+
|
166 |
+
if model:
|
167 |
+
# Time the processing
|
168 |
+
start_time = time.time()
|
169 |
+
enhanced_image = process_image(input_image, model)
|
170 |
+
processing_time = time.time() - start_time
|
171 |
+
|
172 |
+
if enhanced_image:
|
173 |
+
# Display enhanced image
|
174 |
+
st.image(enhanced_image, caption=f"{model_name} Output", use_column_width=True)
|
175 |
+
|
176 |
+
# Calculate metrics
|
177 |
+
enhanced_array = np.array(enhanced_image)
|
178 |
+
metrics = calculate_image_metrics(input_array, enhanced_array)
|
179 |
+
|
180 |
+
# Store results
|
181 |
+
processing_times[model_name] = processing_time
|
182 |
+
quality_metrics[model_name] = metrics
|
183 |
+
enhanced_images[model_name] = enhanced_image
|
184 |
+
|
185 |
+
# Performance Metrics Section
|
186 |
+
st.subheader("📊 Performance Metrics")
|
187 |
+
metric_cols = st.columns(len(model_names))
|
188 |
+
|
189 |
+
for i, (model, time_val) in enumerate(processing_times.items()):
|
190 |
+
with metric_cols[i]:
|
191 |
+
st.metric(f"{model} Processing Time", f"{time_val:.4f} seconds")
|
192 |
+
|
193 |
+
# Quality Metrics Section
|
194 |
+
st.subheader("🔍 Image Quality Assessment")
|
195 |
+
quality_cols = st.columns(len(model_names))
|
196 |
+
|
197 |
+
for i, (model, metrics) in enumerate(quality_metrics.items()):
|
198 |
+
with quality_cols[i]:
|
199 |
+
st.metric(f"{model} PSNR", f"{metrics['PSNR']:.2f} dB")
|
200 |
+
st.metric(f"{model} SSIM", f"{metrics['SSIM']:.4f}")
|
201 |
+
|
202 |
+
# Download Section
|
203 |
+
st.subheader("💾 Download Enhanced Images")
|
204 |
+
download_cols = st.columns(len(model_names))
|
205 |
+
|
206 |
+
for i, (model, image) in enumerate(enhanced_images.items()):
|
207 |
+
with download_cols[i]:
|
208 |
+
buffered = io.BytesIO()
|
209 |
+
image.save(buffered, format="PNG")
|
210 |
+
st.download_button(
|
211 |
+
label=f"Download {model} Image",
|
212 |
+
data=buffered.getvalue(),
|
213 |
+
file_name=f"{model}_enhanced.png",
|
214 |
+
mime="image/png"
|
215 |
+
)
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
main()
|
checkpoints/edsr_best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c3f1a4bd666d2537fdd371c2eac1a33c41fa890a577401f8a05ee9c1c73c6ea9
|
3 |
+
size 4904536
|
checkpoints/srcnn_best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b15f698a761b03d40decbf66ee0ac3a2496b61f396427b8e181c236c132c356d
|
3 |
+
size 35238
|
checkpoints/vdsr_best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec5a3872a48aea983e0d5f0349d81d3b44bd28b5bd0ba29cfa169478c663c325
|
3 |
+
size 2667978
|
inference.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# inference.py
|
2 |
+
import os
|
3 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
4 |
+
import streamlit as st
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
from models.srcnn import SRCNN
|
10 |
+
from models.vdsr import VDSR
|
11 |
+
from models.edsr import EDSR
|
12 |
+
|
13 |
+
def load_model(model_name):
|
14 |
+
if model_name == 'SRCNN':
|
15 |
+
model = SRCNN()
|
16 |
+
elif model_name == 'VDSR':
|
17 |
+
model = VDSR()
|
18 |
+
else:
|
19 |
+
model = EDSR()
|
20 |
+
|
21 |
+
model.load_state_dict(torch.load(f'checkpoints/{model_name.lower()}_best.pth', map_location=torch.device('cpu')))
|
22 |
+
model.eval()
|
23 |
+
return model
|
24 |
+
|
25 |
+
def process_image(image, model):
|
26 |
+
# Convert to YCbCr and extract Y channel
|
27 |
+
ycbcr = image.convert('YCbCr')
|
28 |
+
y, cb, cr = ycbcr.split()
|
29 |
+
|
30 |
+
# Transform Y channel
|
31 |
+
transform = transforms.Compose([
|
32 |
+
transforms.ToTensor()
|
33 |
+
])
|
34 |
+
|
35 |
+
input_tensor = transform(y).unsqueeze(0)
|
36 |
+
|
37 |
+
# Process through model
|
38 |
+
with torch.no_grad():
|
39 |
+
output = model(input_tensor)
|
40 |
+
|
41 |
+
# Post-process output
|
42 |
+
output = output.squeeze().clamp(0, 1).numpy()
|
43 |
+
output_y = Image.fromarray((output * 255).astype(np.uint8))
|
44 |
+
|
45 |
+
# Merge channels back
|
46 |
+
output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
|
47 |
+
output_rgb = output_ycbcr.convert('RGB')
|
48 |
+
|
49 |
+
return output_rgb
|
50 |
+
|
51 |
+
def main():
|
52 |
+
st.title("Super Resolution Model Comparison")
|
53 |
+
st.write("Upload a low-resolution image to compare SRCNN, VDSR, and EDSR models")
|
54 |
+
|
55 |
+
# File uploader
|
56 |
+
uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
|
57 |
+
|
58 |
+
if uploaded_file is not None:
|
59 |
+
# Load and display input image
|
60 |
+
input_image = Image.open(uploaded_file)
|
61 |
+
st.subheader("Input Image")
|
62 |
+
st.image(input_image, caption="Original Image")
|
63 |
+
|
64 |
+
# Process with each model
|
65 |
+
col1, col2, col3 = st.columns(3)
|
66 |
+
|
67 |
+
with col1:
|
68 |
+
st.subheader("SRCNN")
|
69 |
+
model = load_model('SRCNN')
|
70 |
+
srcnn_output = process_image(input_image, model)
|
71 |
+
st.image(srcnn_output, caption="SRCNN Output")
|
72 |
+
|
73 |
+
with col2:
|
74 |
+
st.subheader("VDSR")
|
75 |
+
model = load_model('VDSR')
|
76 |
+
vdsr_output = process_image(input_image, model)
|
77 |
+
st.image(vdsr_output, caption="VDSR Output")
|
78 |
+
|
79 |
+
with col3:
|
80 |
+
st.subheader("EDSR")
|
81 |
+
model = load_model('EDSR')
|
82 |
+
edsr_output = process_image(input_image, model)
|
83 |
+
st.image(edsr_output, caption="EDSR Output")
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
main()
|
metrics.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# metrics.py
|
2 |
+
import os
|
3 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from models.srcnn import SRCNN
|
10 |
+
from models.vdsr import VDSR
|
11 |
+
from models.edsr import EDSR
|
12 |
+
import math
|
13 |
+
from skimage.metrics import structural_similarity as ssim
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
|
16 |
+
|
17 |
+
def calculate_psnr(img1, img2):
|
18 |
+
mse = torch.mean((img1 - img2) ** 2)
|
19 |
+
if mse == 0:
|
20 |
+
return float('inf')
|
21 |
+
return 20 * math.log10(1.0 / math.sqrt(mse.item()))
|
22 |
+
|
23 |
+
def process_image(model, lr_img):
|
24 |
+
with torch.no_grad():
|
25 |
+
# Convert to YCbCr and extract Y channel
|
26 |
+
ycbcr = lr_img.convert('YCbCr')
|
27 |
+
y, cb, cr = ycbcr.split()
|
28 |
+
|
29 |
+
# Transform Y channel
|
30 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
31 |
+
input_tensor = transform(y).unsqueeze(0)
|
32 |
+
|
33 |
+
# Process through model
|
34 |
+
output = model(input_tensor)
|
35 |
+
|
36 |
+
# Post-process output
|
37 |
+
output = output.squeeze().clamp(0, 1).numpy()
|
38 |
+
output_y = Image.fromarray((output * 255).astype(np.uint8))
|
39 |
+
|
40 |
+
# Merge channels back
|
41 |
+
output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
|
42 |
+
return output_ycbcr.convert('RGB')
|
43 |
+
|
44 |
+
def calculate_ssim(img1, img2):
|
45 |
+
# Move channel axis to the end for SSIM calculation
|
46 |
+
img1_np = img1.cpu().numpy().transpose(1, 2, 0)
|
47 |
+
img2_np = img2.cpu().numpy().transpose(1, 2, 0)
|
48 |
+
return ssim(img1_np, img2_np, data_range=1.0, channel_axis=2, win_size=7)
|
49 |
+
|
50 |
+
def evaluate_models(test_image_path):
|
51 |
+
# Load models
|
52 |
+
models = {
|
53 |
+
'SRCNN': SRCNN(),
|
54 |
+
'VDSR': VDSR(),
|
55 |
+
'EDSR': EDSR()
|
56 |
+
}
|
57 |
+
|
58 |
+
# Load weights
|
59 |
+
for name, model in models.items():
|
60 |
+
model.load_state_dict(torch.load(f'checkpoints/{name.lower()}_best.pth', weights_only=True))
|
61 |
+
model.eval()
|
62 |
+
|
63 |
+
# Load test image
|
64 |
+
lr_img = Image.open(test_image_path)
|
65 |
+
hr_img = Image.open(test_image_path) # Using same image as reference
|
66 |
+
|
67 |
+
# Results dictionary
|
68 |
+
results = {model_name: {} for model_name in models.keys()}
|
69 |
+
|
70 |
+
# Process image with each model and calculate metrics
|
71 |
+
for name, model in models.items():
|
72 |
+
# Generate SR image
|
73 |
+
sr_img = process_image(model, lr_img)
|
74 |
+
|
75 |
+
# Convert images to tensors for metric calculation
|
76 |
+
transform = transforms.Compose([
|
77 |
+
transforms.Resize((256, 256)), # Ensure minimum size for SSIM
|
78 |
+
transforms.ToTensor()
|
79 |
+
])
|
80 |
+
|
81 |
+
sr_tensor = transform(sr_img)
|
82 |
+
hr_tensor = transform(hr_img)
|
83 |
+
|
84 |
+
# Calculate metrics
|
85 |
+
results[name]['PSNR'] = calculate_psnr(sr_tensor, hr_tensor)
|
86 |
+
results[name]['SSIM'] = calculate_ssim(sr_tensor, hr_tensor)
|
87 |
+
|
88 |
+
# Save output images
|
89 |
+
sr_img.save(f'results/{name.lower()}_output.png')
|
90 |
+
|
91 |
+
# Display results
|
92 |
+
print("\nModel Performance Metrics:")
|
93 |
+
print("-" * 50)
|
94 |
+
print(f"{'Model':<10} {'PSNR (dB)':<15} {'SSIM':<15}")
|
95 |
+
print("-" * 50)
|
96 |
+
|
97 |
+
for model_name, metrics in results.items():
|
98 |
+
print(f"{model_name:<10} {metrics['PSNR']:<15.2f} {metrics['SSIM']:<15.4f}")
|
99 |
+
|
100 |
+
# Plot results
|
101 |
+
plt.figure(figsize=(12, 6))
|
102 |
+
|
103 |
+
# PSNR comparison
|
104 |
+
plt.subplot(1, 2, 1)
|
105 |
+
plt.bar(results.keys(), [m['PSNR'] for m in results.values()])
|
106 |
+
plt.title('PSNR Comparison')
|
107 |
+
plt.ylabel('PSNR (dB)')
|
108 |
+
|
109 |
+
# SSIM comparison
|
110 |
+
plt.subplot(1, 2, 2)
|
111 |
+
plt.bar(results.keys(), [m['SSIM'] for m in results.values()])
|
112 |
+
plt.title('SSIM Comparison')
|
113 |
+
plt.ylabel('SSIM')
|
114 |
+
|
115 |
+
plt.tight_layout()
|
116 |
+
plt.savefig('results/metrics_comparison.png')
|
117 |
+
plt.close()
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
import os
|
121 |
+
os.makedirs('results', exist_ok=True)
|
122 |
+
test_image_path = r"data\DIV2K_train_LR_bicubic_X4\DIV2K_train_LR_bicubic\X4\0001x4.png" # Replace with your test image path
|
123 |
+
evaluate_models(test_image_path)
|
models/__pycache__/edsr.cpython-312.pyc
ADDED
Binary file (2.9 kB). View file
|
|
models/__pycache__/srcnn.cpython-312.pyc
ADDED
Binary file (1.45 kB). View file
|
|
models/__pycache__/vdsr.cpython-312.pyc
ADDED
Binary file (3.12 kB). View file
|
|
models/edsr.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# models/edsr.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
class ResBlock(nn.Module):
|
6 |
+
def __init__(self, n_feats, kernel_size, bias=True, res_scale=1):
|
7 |
+
super(ResBlock, self).__init__()
|
8 |
+
m = []
|
9 |
+
for i in range(2):
|
10 |
+
m.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=bias))
|
11 |
+
if i == 0:
|
12 |
+
m.append(nn.ReLU(True))
|
13 |
+
self.body = nn.Sequential(*m)
|
14 |
+
self.res_scale = res_scale
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
res = self.body(x).mul(self.res_scale)
|
18 |
+
res += x
|
19 |
+
return res
|
20 |
+
|
21 |
+
class EDSR(nn.Module):
|
22 |
+
def __init__(self, n_resblocks=16, n_feats=64, scale=4):
|
23 |
+
super(EDSR, self).__init__()
|
24 |
+
kernel_size = 3
|
25 |
+
self.scale = scale
|
26 |
+
|
27 |
+
# Define head module
|
28 |
+
m_head = [nn.Conv2d(1, n_feats, kernel_size, padding=(kernel_size//2))]
|
29 |
+
|
30 |
+
# Define body module
|
31 |
+
m_body = [
|
32 |
+
ResBlock(n_feats, kernel_size) \
|
33 |
+
for _ in range(n_resblocks)
|
34 |
+
]
|
35 |
+
m_body.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2)))
|
36 |
+
|
37 |
+
# Define tail module
|
38 |
+
m_tail = [
|
39 |
+
nn.Conv2d(n_feats, 1, kernel_size, padding=(kernel_size//2))
|
40 |
+
]
|
41 |
+
|
42 |
+
self.head = nn.Sequential(*m_head)
|
43 |
+
self.body = nn.Sequential(*m_body)
|
44 |
+
self.tail = nn.Sequential(*m_tail)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.head(x)
|
48 |
+
res = self.body(x)
|
49 |
+
res += x
|
50 |
+
x = self.tail(res)
|
51 |
+
return x
|
models/srcnn.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# models/srcnn.py
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class SRCNN(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super(SRCNN, self).__init__()
|
7 |
+
self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
|
8 |
+
self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
|
9 |
+
self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)
|
10 |
+
self.relu = nn.ReLU(inplace=True)
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
x = self.relu(self.conv1(x))
|
14 |
+
x = self.relu(self.conv2(x))
|
15 |
+
x = self.conv3(x)
|
16 |
+
return x
|
models/vdsr.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# models/vdsr.py
|
2 |
+
import torch.nn as nn
|
3 |
+
from math import sqrt
|
4 |
+
|
5 |
+
class Conv_ReLU_Block(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(Conv_ReLU_Block, self).__init__()
|
8 |
+
self.conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
|
9 |
+
self.relu = nn.ReLU(inplace=True)
|
10 |
+
|
11 |
+
def forward(self, x):
|
12 |
+
return self.relu(self.conv(x))
|
13 |
+
|
14 |
+
class VDSR(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(VDSR, self).__init__()
|
17 |
+
self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
|
18 |
+
self.input = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
19 |
+
self.output = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=False)
|
20 |
+
self.relu = nn.ReLU(inplace=True)
|
21 |
+
|
22 |
+
# Initialize weights
|
23 |
+
for m in self.modules():
|
24 |
+
if isinstance(m, nn.Conv2d):
|
25 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
26 |
+
m.weight.data.normal_(0, sqrt(2. / n))
|
27 |
+
|
28 |
+
def make_layer(self, block, num_of_layer):
|
29 |
+
layers = []
|
30 |
+
for _ in range(num_of_layer):
|
31 |
+
layers.append(block())
|
32 |
+
return nn.Sequential(*layers)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
residual = x
|
36 |
+
out = self.relu(self.input(x))
|
37 |
+
out = self.residual_layer(out)
|
38 |
+
out = self.output(out)
|
39 |
+
return out + residual
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
numpy
|
4 |
+
torchvision
|
5 |
+
|
6 |
+
scikit-image
|
results/edsr_output.png
ADDED
results/metrics_comparison.png
ADDED
results/srcnn_output.png
ADDED
results/vdsr_output.png
ADDED
train.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# train.py
|
2 |
+
import os
|
3 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.optim as optim
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from utils.dataset import DIV2KDataset
|
9 |
+
from models.srcnn import SRCNN
|
10 |
+
from models.vdsr import VDSR
|
11 |
+
from models.edsr import EDSR
|
12 |
+
import math
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
class EarlyStopping:
|
16 |
+
def __init__(self, patience=7, min_delta=0.01, min_psnr_improvement=0.1):
|
17 |
+
self.patience = patience
|
18 |
+
self.min_delta = min_delta
|
19 |
+
self.min_psnr_improvement = min_psnr_improvement
|
20 |
+
self.counter = 0
|
21 |
+
self.best_loss = None
|
22 |
+
self.best_psnr = None
|
23 |
+
self.early_stop = False
|
24 |
+
|
25 |
+
def __call__(self, loss, psnr):
|
26 |
+
if self.best_loss is None:
|
27 |
+
self.best_loss = loss
|
28 |
+
self.best_psnr = psnr
|
29 |
+
elif (loss > self.best_loss - self.min_delta) and (psnr < self.best_psnr + self.min_psnr_improvement):
|
30 |
+
self.counter += 1
|
31 |
+
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
32 |
+
if self.counter >= self.patience:
|
33 |
+
self.early_stop = True
|
34 |
+
else:
|
35 |
+
self.best_loss = min(loss, self.best_loss)
|
36 |
+
self.best_psnr = max(psnr, self.best_psnr)
|
37 |
+
self.counter = 0
|
38 |
+
|
39 |
+
def calculate_psnr(img1, img2):
|
40 |
+
mse = torch.mean((img1 - img2) ** 2)
|
41 |
+
if mse == 0:
|
42 |
+
return float('inf')
|
43 |
+
return 20 * math.log10(1.0 / math.sqrt(mse.item()))
|
44 |
+
|
45 |
+
def train_model(model_name, train_loader, val_loader, device, num_epochs=100):
|
46 |
+
# Initialize model
|
47 |
+
if model_name == 'srcnn':
|
48 |
+
model = SRCNN()
|
49 |
+
elif model_name == 'vdsr':
|
50 |
+
model = VDSR()
|
51 |
+
else:
|
52 |
+
model = EDSR()
|
53 |
+
|
54 |
+
model = model.to(device)
|
55 |
+
criterion = nn.MSELoss()
|
56 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
57 |
+
|
58 |
+
# Initialize early stopping
|
59 |
+
early_stopping = EarlyStopping(patience=10, min_delta=0.00001, min_psnr_improvement=0.1)
|
60 |
+
best_psnr = 0
|
61 |
+
|
62 |
+
for epoch in range(num_epochs):
|
63 |
+
# Training
|
64 |
+
model.train()
|
65 |
+
train_loss = 0
|
66 |
+
num_batches = 0
|
67 |
+
|
68 |
+
for batch_idx, (lr_img, hr_img) in enumerate(train_loader):
|
69 |
+
lr_img, hr_img = lr_img.to(device), hr_img.to(device)
|
70 |
+
|
71 |
+
optimizer.zero_grad()
|
72 |
+
output = model(lr_img)
|
73 |
+
loss = criterion(output, hr_img)
|
74 |
+
loss.backward()
|
75 |
+
optimizer.step()
|
76 |
+
|
77 |
+
train_loss += loss.item()
|
78 |
+
num_batches += 1
|
79 |
+
|
80 |
+
if batch_idx % 100 == 0:
|
81 |
+
print(f'Train Epoch: {epoch} [{batch_idx}/{len(train_loader)}]\tLoss: {loss.item():.6f}')
|
82 |
+
|
83 |
+
avg_train_loss = train_loss / num_batches
|
84 |
+
|
85 |
+
# Validation
|
86 |
+
model.eval()
|
87 |
+
val_psnr = 0
|
88 |
+
with torch.no_grad():
|
89 |
+
for lr_img, hr_img in val_loader:
|
90 |
+
lr_img, hr_img = lr_img.to(device), hr_img.to(device)
|
91 |
+
output = model(lr_img)
|
92 |
+
val_psnr += calculate_psnr(output, hr_img)
|
93 |
+
|
94 |
+
val_psnr /= len(val_loader)
|
95 |
+
print(f'Epoch: {epoch}, Average Loss: {avg_train_loss:.6f}, Average PSNR: {val_psnr:.2f}dB')
|
96 |
+
|
97 |
+
# Early stopping check
|
98 |
+
early_stopping(avg_train_loss, val_psnr)
|
99 |
+
if early_stopping.early_stop:
|
100 |
+
print(f"Early stopping triggered at epoch {epoch}")
|
101 |
+
break
|
102 |
+
|
103 |
+
# Save best model
|
104 |
+
if val_psnr > best_psnr:
|
105 |
+
best_psnr = val_psnr
|
106 |
+
torch.save(model.state_dict(), f'checkpoints/{model_name}_best.pth')
|
107 |
+
print(f'Saved new best model with PSNR: {best_psnr:.2f}dB')
|
108 |
+
|
109 |
+
def main():
|
110 |
+
# Setup
|
111 |
+
device = torch.device('cpu')
|
112 |
+
|
113 |
+
# Data paths
|
114 |
+
train_hr_dir = 'data/DIV2K_train_HR/DIV2K_train_HR/'
|
115 |
+
train_lr_dir = 'data/DIV2K_train_LR_bicubic_X4/DIV2K_train_LR_bicubic/X4'
|
116 |
+
val_hr_dir = 'data/DIV2K_valid_HR/DIV2K_valid_HR'
|
117 |
+
val_lr_dir = 'data/DIV2K_valid_LR_bicubic_X4/DIV2K_valid_LR_bicubic/X4'
|
118 |
+
|
119 |
+
# Create datasets
|
120 |
+
train_dataset = DIV2KDataset(train_hr_dir, train_lr_dir, patch_size=48)
|
121 |
+
val_dataset = DIV2KDataset(val_hr_dir, val_lr_dir, patch_size=48)
|
122 |
+
|
123 |
+
# Create dataloaders
|
124 |
+
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
|
125 |
+
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
|
126 |
+
|
127 |
+
# Create checkpoints directory
|
128 |
+
os.makedirs('checkpoints', exist_ok=True)
|
129 |
+
|
130 |
+
# Train models
|
131 |
+
models = ['edsr']
|
132 |
+
for model_name in models:
|
133 |
+
print(f'Training {model_name.upper()}...')
|
134 |
+
train_model(model_name, train_loader, val_loader, device)
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
main()
|
utils/dataset.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# utils/dataset.py
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from PIL import Image
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import os
|
6 |
+
|
7 |
+
class DIV2KDataset(Dataset):
|
8 |
+
def __init__(self, hr_dir, lr_dir, patch_size=96, upscale_factor=4):
|
9 |
+
self.hr_files = sorted(os.listdir(hr_dir))
|
10 |
+
self.lr_files = sorted(os.listdir(lr_dir))
|
11 |
+
self.hr_dir = hr_dir
|
12 |
+
self.lr_dir = lr_dir
|
13 |
+
self.patch_size = patch_size
|
14 |
+
self.upscale_factor = upscale_factor
|
15 |
+
|
16 |
+
# LR transform
|
17 |
+
self.lr_transform = transforms.Compose([
|
18 |
+
transforms.Resize((patch_size//upscale_factor, patch_size//upscale_factor),
|
19 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
20 |
+
transforms.ToTensor()
|
21 |
+
])
|
22 |
+
|
23 |
+
# HR transform
|
24 |
+
self.hr_transform = transforms.Compose([
|
25 |
+
transforms.Resize((patch_size, patch_size),
|
26 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
27 |
+
transforms.ToTensor()
|
28 |
+
])
|
29 |
+
|
30 |
+
# Upscale LR images to match HR size for SRCNN
|
31 |
+
self.lr_upscale = transforms.Compose([
|
32 |
+
transforms.Resize((patch_size, patch_size),
|
33 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
34 |
+
transforms.ToTensor()
|
35 |
+
])
|
36 |
+
|
37 |
+
def __getitem__(self, idx):
|
38 |
+
# Load images and convert to YCbCr
|
39 |
+
hr_img = Image.open(os.path.join(self.hr_dir, self.hr_files[idx])).convert('YCbCr')
|
40 |
+
lr_img = Image.open(os.path.join(self.lr_dir, self.lr_files[idx])).convert('YCbCr')
|
41 |
+
|
42 |
+
# Extract Y channel
|
43 |
+
hr_y, _, _ = hr_img.split()
|
44 |
+
lr_y, _, _ = lr_img.split()
|
45 |
+
|
46 |
+
# For SRCNN, we need to upscale LR images first
|
47 |
+
lr_y_upscaled = self.lr_upscale(lr_y)
|
48 |
+
hr_y_tensor = self.hr_transform(hr_y)
|
49 |
+
|
50 |
+
return lr_y_upscaled, hr_y_tensor
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return len(self.hr_files)
|