# File size: 7,986 Bytes
# 314a753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import gradio as gr
import os.path
import numpy as np
from collections import OrderedDict
import torch
import cv2
from PIL import Image, ImageOps
import utils_image as util
from network_fbcnn import FBCNN as net
import requests
import datetime
from gradio_imageslider import ImageSlider

# Most recent restored image produced by inference(); zoom_image() falls back
# to it when it is called without an explicit output image.
current_output = None

# Fetch both FBCNN checkpoints at start-up unless they are already on disk.
for model_path in ['fbcnn_gray.pth', 'fbcnn_color.pth']:
    if os.path.exists(model_path):
        print(f'{model_path} exists.')
        continue

    print("downloading model")
    url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
    r = requests.get(url, allow_redirects=True, timeout=60)
    # Fail loudly on an HTTP error instead of saving the error page as a
    # checkpoint: such a file would pass every later os.path.exists() check
    # and never be downloaded again.
    r.raise_for_status()
    # The context manager closes the file even if the write fails.
    with open(model_path, 'wb') as f:
        f.write(r.content)

def _fit_within(width, height, limit=1080):
    """Return the size that fits ``width`` x ``height`` inside a ``limit`` square.

    The aspect ratio is kept and every side stays at least one pixel long.
    A size that already fits is returned unchanged.
    """
    if width <= limit and height <= limit:
        return (width, height)
    ratio = min(limit / width, limit / height)
    return (max(1, int(width * ratio)), max(1, int(height * ratio)))


def _ensure_model_file(model_path):
    """Download the FBCNN checkpoint ``model_path`` unless it already exists."""
    if os.path.exists(model_path):
        print(f'{model_path} already exists.')
        return

    print("downloading model")
    # os.path.dirname() is '' for a bare file name, and os.makedirs('') raises,
    # so only create the directory when the path actually has one.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
    r = requests.get(url, allow_redirects=True, timeout=60)
    # Never store an HTTP error page as if it were a checkpoint.
    r.raise_for_status()
    with open(model_path, 'wb') as f:
        f.write(r.content)


def inference(input_img, is_gray, input_quality, zoom, x_shift, y_shift):
    """Run FBCNN on ``input_img`` and build the before/after preview.

    Args:
        input_img: Image array accepted by ``Image.fromarray``. Images with a
            side longer than 1080 px are downscaled (bicubic) first.
        is_gray: If true, use the single-channel model ``fbcnn_gray.pth``;
            otherwise the three-channel model ``fbcnn_color.pth``.
        input_quality: Value of the "Intensity" slider; ``input_quality / 100``
            (computed as ``1 - (100 - input_quality) / 100``) is passed to the
            model as its quality-factor input.
        zoom: Zoom percentage forwarded to ``zoom_image``.
        x_shift: Horizontal shift percentage forwarded to ``zoom_image``.
        y_shift: Vertical shift percentage forwarded to ``zoom_image``.

    Returns:
        ``(restored, (before, after))``: the restored image array and the
        pair of preview images returned by ``zoom_image``.

    Raises:
        gr.Error: If ``input_img`` is ``None``.
        requests.HTTPError: If a missing checkpoint cannot be downloaded.
    """
    global current_output

    if input_img is None:
        raise gr.Error("Please provide an input image before running.")

    # utcnow() is deprecated; use an explicit, timezone-aware UTC timestamp.
    print("datetime:", datetime.datetime.now(datetime.timezone.utc))
    source = Image.fromarray(input_img)
    input_img_width, input_img_height = source.size
    print("img size:", (input_img_width, input_img_height))

    target_size = _fit_within(input_img_width, input_img_height)
    if target_size != (input_img_width, input_img_height):
        resized_input = source.resize(target_size, resample=Image.BICUBIC)
        input_img = np.array(resized_input)
        print("input image resized to:", resized_input.size)

    if is_gray:
        n_channels = 1
        model_path = 'fbcnn_gray.pth'
    else:
        n_channels = 3
        model_path = 'fbcnn_color.pth'
    nc = [64, 128, 256, 512]
    nb = 4

    input_quality = 100 - input_quality

    _ensure_model_file(model_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device:", device)

    print(f'loading model from {model_path}')

    model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
    print("#model.load_state_dict(torch.load(model_path), strict=True)")
    # map_location keeps loading independent of the device the checkpoint was
    # saved from; the model is moved to `device` below.
    # NOTE(review): torch.load unpickles a file downloaded from the network;
    # consider weights_only=True if the installed torch version supports it.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    print("#model.eval()")
    model.eval()
    print("#for k, v in model.named_parameters()")
    for param in model.parameters():
        param.requires_grad = False
    print("#model.to(device)")
    model = model.to(device)
    print("Model loaded.")

    print("#if n_channels")
    if n_channels == 1:
        open_cv_image = np.array(ImageOps.grayscale(Image.fromarray(input_img)))
    else:
        open_cv_image = np.array(input_img)
        if open_cv_image.ndim == 2:
            open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB)
        else:
            # Reverses the channel order; the flip after inference undoes it.
            # NOTE(review): gr.Image presumably delivers RGB, in which case the
            # network is fed BGR data - confirm this matches the checkpoint.
            open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)

    print("#util.uint2tensor4(open_cv_image)")
    img_L = util.uint2tensor4(open_cv_image)

    print("#img_L.to(device)")
    img_L = img_L.to(device)

    # Only the pass with the user-selected quality factor is run. An earlier
    # extra pass without qf_input had its outputs overwritten before use.
    print("#torch.tensor([[1-input_quality/100]]).cuda() || torch.tensor([[1-input_quality/100]])")
    qf_input = torch.tensor([[1-input_quality/100]]).to(device)
    print("#model(img_L, qf_input)")
    with torch.no_grad():
        img_E, _ = model(img_L, qf_input)

    print("#util.tensor2single(img_E)")
    img_E = util.tensor2single(img_E)
    print("#util.single2uint(img_E)")
    img_E = util.single2uint(img_E)

    if img_E.ndim == 3:
        # Reverse the channel order of the three-channel result.
        img_E = img_E[:, :, [2, 1, 0]]

    # Remember the result so the sliders can re-render the preview later.
    current_output = img_E.copy()
    print("--inference finished")

    (in_img, out_img) = zoom_image(zoom, x_shift, y_shift, input_img, img_E)
    print("--generating preview finished")

    return img_E, (in_img, out_img)

def _zoom_crop_box(img_w, img_h, zoom, x_shift, y_shift):
    """Return the ``(left, upper, right, lower)`` crop box for a zoomed view.

    ``zoom`` is the percentage of each side that is cut away; ``x_shift`` and
    ``y_shift`` (0-100) place the window between the left/top edge and the
    right/bottom edge. The window is never smaller than one pixel, so a zoom
    of 100 still yields a non-empty box.
    """
    zoom_factor = (100 - zoom) / 100
    zoom_w = max(1, int(img_w * zoom_factor))
    zoom_h = max(1, int(img_h * zoom_factor))
    x_offset = int((img_w - zoom_w) * (x_shift / 100))
    y_offset = int((img_h - zoom_h) * (y_shift / 100))
    return (x_offset, y_offset, x_offset + zoom_w, y_offset + zoom_h)


def zoom_image(zoom, x_shift, y_shift, input_img, output_img = None):
    """Crop the same zoomed window out of the input and the restored image.

    Args:
        zoom: Percentage of each side to cut away (see ``_zoom_crop_box``).
        x_shift: Horizontal position of the window, 0-100.
        y_shift: Vertical position of the window, 0-100.
        input_img: Input image array, or ``None`` when no image is set.
        output_img: Restored image array. When omitted, the module-level
            ``current_output`` from the last ``inference`` call is used.

    Returns:
        ``(before, after)`` PIL images of identical size, or ``None`` when
        there is no input image or no result to show yet.
    """
    if input_img is None:
        return None
    if output_img is None:
        # NOTE: current_output is module-level state; it holds whatever the
        # most recent inference() call produced.
        if current_output is None:
            return None
        output_img = current_output

    img = Image.fromarray(input_img)
    out_img = Image.fromarray(output_img)

    # inference() downscales large inputs before restoring them, so a stored
    # result can be smaller than the image passed in here. Bring the input to
    # the result's size so one crop box addresses the same region in both.
    if img.size != out_img.size:
        img = img.resize(out_img.size, resample=Image.BICUBIC)

    img_w, img_h = img.size
    crop_box = _zoom_crop_box(img_w, img_h, zoom, x_shift, y_shift)

    # Scale both crops back up to the full frame for the before/after slider.
    img = img.crop(crop_box).resize((img_w, img_h), Image.BILINEAR)
    out_img = out_img.crop(crop_box).resize((img_w, img_h), Image.BILINEAR)

    return (img, out_img)
    
# Gradio page: input/result images, the controls, a before/after slider and
# a short description. Components appear in the order they are created.
with gr.Blocks() as demo:
    gr.Markdown("# JPEG Artifacts Removal [FBCNN]")

    with gr.Row():
        input_img = gr.Image(label="Input Image")
        output_img = gr.Image(label="Result")

    is_gray = gr.Checkbox(label="Grayscale (Check this if your image is grayscale)")
    input_quality = gr.Slider(minimum=1, maximum=100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)")
    zoom = gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Zoom Percentage (0 = original size)")
    x_shift = gr.Slider(minimum=0, maximum=100, step=1, label="Horizontal shift Percentage (Before/After)")
    y_shift = gr.Slider(minimum=0, maximum=100, step=1, label="Vertical shift Percentage (Before/After)")

    run = gr.Button("Run")

    with gr.Row():
        before_after = ImageSlider(label="Before/After", type="pil", value=None)

    # Components handed to inference(), in its parameter order.
    inference_inputs = [input_img, is_gray, input_quality, zoom, x_shift, y_shift]
    # Components handed to zoom_image() when only the preview is refreshed.
    preview_inputs = [zoom, x_shift, y_shift, input_img]

    run.click(
        fn=inference,
        inputs=inference_inputs,
        outputs=[output_img, before_after],
    )

    # One row per example, in the same order as inference_inputs.
    example_rows = [
        ["doraemon.jpg", False, 60, 58, 50, 50],
        ["tomandjerry.jpg", False, 60, 60, 57, 44],
        ["somepanda.jpg", True, 100, 70, 8, 24],
        ["cemetry.jpg", False, 70, 80, 76, 62],
        ["michelangelo_david.jpg", True, 30, 88, 53, 27],
        ["elon_musk.jpg", False, 45, 75, 33, 30],
        ["text.jpg", True, 70, 50, 11, 29],
    ]
    gr.Examples(examples=example_rows, inputs=inference_inputs)

    # Releasing any of the three preview sliders re-renders the before/after view.
    for preview_slider in (zoom, x_shift, y_shift):
        preview_slider.release(fn=zoom_image, inputs=preview_inputs, outputs=[before_after])

    gr.Markdown("""

    JPEG Artifacts are noticeable distortions of images caused by JPEG lossy compression.

    Note that this is not an AI Upscaler, but just a JPEG Compression Artifact Remover.



    [Original Demo](https://huggingface.co/spaces/danielsapit/JPEG_Artifacts_Removal)

    [FBCNN GitHub Repo](https://github.com/jiaxi-jiang/FBCNN)  

    [Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)](https://arxiv.org/abs/2109.14573)  

    [Jiaxi Jiang](https://jiaxi-jiang.github.io/),  

    [Kai Zhang](https://cszn.github.io/),  

    [Radu Timofte](http://people.ee.ethz.ch/~timofter/)

    """)

# Start the web server for the page defined above.
demo.launch()