# deepclarity / app.py
# Last commit: "Update app.py" (29b1777, verified) by snehilchatterjee
import torch
import warnings
import gradio as gr
import cv2
import torchvision
from torch import nn
from torchvision.models import mobilenet_v3_small
import numpy as np
from PIL import Image
from torchvision import transforms
# Pick the compute device once: CUDA when torch can see a GPU, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Globally silence every warning for the rest of the process.
warnings.filterwarnings("ignore")
def flip_text(x):
    """Return ``x`` in reverse order (works for any sliceable sequence, e.g. str or list)."""
    backwards = slice(None, None, -1)
    return x[backwards]
def method2_prep(image):
    """Turn an image array into the 3-channel edge map fed to the pixelation classifier.

    Steps: center-crop to at most 1920x1080, run Canny edge detection
    (hysteresis thresholds 50/150), replicate the single edge channel three
    times, then resize to 256x256 and center-crop to 224x224.

    Args:
        image: numpy image array indexed ``(H, W[, C])``. ``cv2.Canny`` requires
            8-bit input, so this is expected to be uint8.

    Returns:
        ``torch.float32`` tensor of shape ``(3, 224, 224)`` on the 0-255 scale
        (not normalized).
    """
    # Named so it does not shadow the module-level ``torchvision.transforms`` import.
    resize_and_crop = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop((224, 224)),
    ])
    t_lower, t_upper = 50, 150  # Canny hysteresis thresholds
    crop_w, crop_h = 1920, 1080

    height, width = image.shape[:2]
    # Clamp at 0: when the frame is smaller than the crop window, a negative
    # start index would make numpy count from the END of the axis and keep only
    # a strip of the image. With the clamp, small frames are kept whole.
    x = max(0, (width - crop_w) // 2)
    y = max(0, (height - crop_h) // 2)
    image = image[y:y + crop_h, x:x + crop_w]

    edges = torch.from_numpy(cv2.Canny(image, t_lower, t_upper)[np.newaxis, ...])
    # The classifier takes 3 channels; repeat the single edge map.
    edges = torch.vstack((edges, edges, edges))
    return resize_and_crop(edges.type(torch.float32))
def model2_inf(x):
    """Classify an image as pixelated or not; if pixelated, also super-resolve it.

    Args:
        x: image as a numpy array (from the ``gr.Image(type="numpy")`` input),
            or ``None`` when the user submits without uploading an image.

    Returns:
        A ``(message, image)`` pair for the two Gradio outputs. ``image`` is
        ``None`` unless the image is classified as pixelated. In that case it is
        the TinySRGAN output of ``translate_image`` (a PIL image).
    """
    print("Method 2")
    if x is None:
        # Gradio passes None when nothing was uploaded; method2_prep would crash on it.
        return "Please upload an image.", None

    image = method2_prep(x).unsqueeze(dim=0)

    # MobileNetV3-Small with a 2-way head. weights=None: the strict
    # load_state_dict below overwrites every parameter and buffer, so fetching
    # and loading the ImageNet weights first was wasted work.
    model = mobilenet_v3_small(weights=None)
    model.classifier[3] = nn.Linear(in_features=1024, out_features=2, bias=True)
    model.load_state_dict(torch.load('./weights/method2(0.960).pt', map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()

    with torch.inference_mode():
        output = torch.softmax(model(image.to(device)), dim=1).cpu()
    prediction = torch.argmax(output, dim=1).item()

    # Release the model (and cached GPU memory, if any) between requests.
    del model
    torch.cuda.empty_cache()

    if prediction == 0:
        return "The image is not pixelated", None
    return "The image is pixelated", translate_image(Image.fromarray(x), False, 'TinySRGAN', 'False')
class _conv(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)
self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
self.bias.data = torch.zeros((out_channels))
for p in self.parameters():
p.requires_grad = True
class conv(nn.Module):
    """``_conv`` followed by an optional ``BatchNorm2d`` and an optional activation.

    Args:
        in_channel: input channel count.
        out_channel: output channel count.
        kernel_size: square kernel size. Padding is ``kernel_size // 2``, so odd
            kernels keep H and W unchanged at stride 1.
        BN: if True, append ``nn.BatchNorm2d(out_channel)`` after the convolution.
        act: activation module appended last, or ``None``. The instance is
            registered as-is (not copied), so passing one instance to several
            layers shares it, including any parameters it has.
        stride: convolution stride.
        bias: whether the convolution has a bias term. Previously this argument
            was ignored and ``True`` was always used; the default is still ``True``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, BN=False, act=None, stride=1, bias=True):
        super(conv, self).__init__()
        layers = [_conv(in_channels=in_channel, out_channels=out_channel,
                        kernel_size=kernel_size, stride=stride,
                        padding=kernel_size // 2, bias=bias)]
        if BN:
            layers.append(nn.BatchNorm2d(num_features=out_channel))
        if act is not None:
            layers.append(act)
        # Keep everything under ``self.body`` so state_dict keys stay body.0 / body.1 / ...
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
class ResBlock(nn.Module):
    """Residual unit: conv-BN-act, then conv-BN, plus the identity shortcut.

    Both convolutions keep the channel count (``channels -> channels``), and the
    shortcut adds the input to the result. ``bias`` is accepted but unused.
    """

    def __init__(self, channels, kernel_size, act=nn.ReLU(inplace=True), bias=True):
        super(ResBlock, self).__init__()
        activated = conv(channels, channels, kernel_size, BN=True, act=act)
        linear = conv(channels, channels, kernel_size, BN=True, act=None)
        self.body = nn.Sequential(activated, linear)

    def forward(self, x):
        out = self.body(x)
        out += x
        return out
class BasicBlock(nn.Module):
    """Head conv + ``num_res_block`` ResBlocks + conv-BN, with a skip around the body.

    ``self.conv`` maps ``in_channels -> out_channels`` (no BN, with ``act``). Its
    output goes through the residual stack and is then added back to the result.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act=nn.ReLU(inplace=True)):
        super(BasicBlock, self).__init__()
        self.conv = conv(in_channels, out_channels, kernel_size, BN=False, act=act)
        stages = [ResBlock(out_channels, kernel_size, act) for _ in range(num_res_block)]
        stages.append(conv(out_channels, out_channels, kernel_size, BN=True, act=None))
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        head = self.conv(x)
        out = self.body(head)
        out += head
        return out
class Upsampler(nn.Module):
    """Sub-pixel upsampling stage.

    A conv expands to ``channel * scale * scale`` maps, ``nn.PixelShuffle(scale)``
    rearranges them into a ``scale``-times larger image with ``channel`` maps,
    and ``act`` (if not ``None``) is applied last.
    """

    def __init__(self, channel, kernel_size, scale, act=nn.ReLU(inplace=True)):
        super(Upsampler, self).__init__()
        layers = [conv(channel, channel * scale * scale, kernel_size), nn.PixelShuffle(scale)]
        if act is not None:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
class discrim_block(nn.Module):
    """Discriminator stage: conv-BN-act at stride 1, then conv-BN-act at stride 2.

    The same ``act`` instance is used by both convolutions.
    """

    def __init__(self, in_feats, out_feats, kernel_size, act=nn.LeakyReLU(inplace=True)):
        super(discrim_block, self).__init__()
        self.body = nn.Sequential(
            conv(in_feats, out_feats, kernel_size, BN=True, act=act),
            conv(out_feats, out_feats, kernel_size, BN=True, act=act, stride=2),
        )

    def forward(self, x):
        return self.body(x)
class TinySRGAN(nn.Module):
    """Small super-resolution generator.

    Layout: 9x9 conv head, ``num_block`` ResBlocks, conv-BN, a long skip from
    the head, sub-pixel upsampling, and a final 3x3 conv with tanh.

    Args:
        img_feat: image channels (input and output).
        n_feats: feature width of the body.
        kernel_size: currently unused; the body convolutions are hard-coded to 3x3.
        num_block: number of ResBlocks.
        act: activation module. The same instance is reused by every activated
            layer of one model, so a PReLU shares its weight across those layers.
            When ``None`` (the default), a fresh ``nn.PReLU()`` is created per
            model. Previously a single ``nn.PReLU()`` built at import time was
            shared by every TinySRGAN instance.
        scale: upscaling factor. 4 uses two x2 Upsamplers; any other value uses
            one Upsampler with that factor.

    ``forward`` returns ``(sr, feat)``: the tanh output and the features before
    upsampling.
    """

    def __init__(self, img_feat=3, n_feats=32, kernel_size=3, num_block=6, act=None, scale=4):
        super(TinySRGAN, self).__init__()
        if act is None:
            # Fresh per model (not a shared default-argument instance). It is
            # still shared by all layers within this model, as before.
            act = nn.PReLU()
        self.conv01 = conv(in_channel=img_feat, out_channel=n_feats, kernel_size=9, BN=False, act=act)
        self.body = nn.Sequential(*[ResBlock(channels=n_feats, kernel_size=3, act=act) for _ in range(num_block)])
        self.conv02 = conv(in_channel=n_feats, out_channel=n_feats, kernel_size=3, BN=True, act=None)
        if scale == 4:
            # x4 is done as two successive x2 sub-pixel stages.
            upsample_blocks = [Upsampler(channel=n_feats, kernel_size=3, scale=2, act=act) for _ in range(2)]
        else:
            upsample_blocks = [Upsampler(channel=n_feats, kernel_size=3, scale=scale, act=act)]
        self.tail = nn.Sequential(*upsample_blocks)
        self.last_conv = conv(in_channel=n_feats, out_channel=img_feat, kernel_size=3, BN=False, act=nn.Tanh())

    def forward(self, x):
        x = self.conv01(x)
        skip = x
        x = self.conv02(self.body(x))
        feat = x + skip
        x = self.last_conv(self.tail(feat))
        return x, feat
def build_generator():
    """Build the MobileSR generator: 3 input channels, 6 residual blocks, 32 features.

    The network upsamples by 4x (two bilinear x2 stages) and ends in tanh.
    Submodule attribute names here determine the state_dict keys, so they must
    stay in sync with the checkpoint loaded for ``'MobileSR'`` in
    ``translate_image``.
    """

    class ResidualBlock(nn.Module):
        """Inverted residual: 1x1 expand -> 3x3 depthwise -> 1x1 project (+ identity).

        The identity is added only when ``stride == 1`` and the projected width
        equals ``in_channels``.
        """

        def __init__(self, in_channels, out_channels, expansion=6, stride=1, alpha=1.0):
            super(ResidualBlock, self).__init__()
            self.expansion = expansion
            self.stride = stride
            self.in_channels = in_channels
            self.out_channels = int(out_channels * alpha)
            self.pointwise_conv_filters = self._make_divisible(self.out_channels, 8)
            hidden = in_channels * expansion
            self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn1 = nn.BatchNorm2d(hidden)
            self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=True)
            self.bn2 = nn.BatchNorm2d(hidden)
            self.conv3 = nn.Conv2d(hidden, self.pointwise_conv_filters, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn3 = nn.BatchNorm2d(self.pointwise_conv_filters)
            self.relu = nn.ReLU(inplace=True)
            self.skip_add = stride == 1 and in_channels == self.pointwise_conv_filters

        def forward(self, x):
            expanded = self.relu(self.bn1(self.conv1(x)))
            filtered = self.relu(self.bn2(self.conv2(expanded)))
            projected = self.bn3(self.conv3(filtered))
            return projected + x if self.skip_add else projected

        @staticmethod
        def _make_divisible(v, divisor, min_value=None):
            # Round v to the nearest multiple of divisor (at least min_value)...
            floor = divisor if min_value is None else min_value
            rounded = max(floor, int(v + divisor / 2) // divisor * divisor)
            # ...but never round down by more than 10%.
            return rounded + divisor if rounded < 0.9 * v else rounded

    class Generator(nn.Module):
        """conv-BN-PReLU head, residual stack, conv-BN with long skip, 2x upsample x2, conv + tanh."""

        def __init__(self, in_channels, num_residual_blocks, gf):
            super(Generator, self).__init__()
            self.num_residual_blocks = num_residual_blocks
            self.gf = gf
            self.conv1 = nn.Conv2d(in_channels, gf, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(gf)
            self.prelu1 = nn.PReLU()
            self.residual_blocks = self.make_layer(ResidualBlock, gf, num_residual_blocks)
            self.conv2 = nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(gf)
            self.upsample1 = self._upsample_stage(gf)
            self.upsample2 = self._upsample_stage(gf)
            self.conv3 = nn.Conv2d(gf, 3, kernel_size=3, stride=1, padding=1)
            self.tanh = nn.Tanh()

        @staticmethod
        def _upsample_stage(channels):
            """Bilinear x2 upsample, 3x3 conv, PReLU."""
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
                nn.PReLU(),
            )

        def make_layer(self, block, out_channels, blocks):
            return nn.Sequential(*[block(out_channels, out_channels) for _ in range(blocks)])

        def forward(self, x):
            shallow = self.prelu1(self.bn1(self.conv1(x)))
            deep = self.bn2(self.conv2(self.residual_blocks(shallow)))
            upsampled = self.upsample2(self.upsample1(deep + shallow))
            return self.tanh(self.conv3(upsampled))

    return Generator(3, 6, 32)
def numpify(imgs):
    """Convert a batch of CHW image tensors into one NHWC numpy array.

    Args:
        imgs: iterable of ``(C, H, W)`` tensors (e.g. an ``(N, C, H, W)`` batch)
            on any device.

    Returns:
        ``np.ndarray`` of shape ``(N, H, W, C)``.

    Raises:
        ValueError: from ``np.stack`` if ``imgs`` is empty.
    """
    # detach() so tensors that still track gradients convert instead of raising;
    # cpu() before numpy() because numpy cannot view device memory.
    return np.stack([img.detach().permute(1, 2, 0).cpu().numpy() for img in imgs], axis=0)
# Shared PIL -> tensor conversion (torchvision ToTensor: HWC uint8 -> CHW float in [0, 1]).
transform = transforms.Compose(
    [transforms.ToTensor()]
)
def translate_image(image, sharpen, model_name, save):
    """Super-resolve ``image`` with one of the generators defined in this file.

    The input is first resized to 480 px wide (aspect ratio kept).

    Args:
        image: PIL image. It is converted to RGB first, so grayscale/RGBA/palette
            inputs still yield the 3-channel array the generators expect.
        sharpen: if truthy, apply a 3x3 sharpening kernel to the result.
        model_name: ``'MobileSR'``, ``'MiniSRGAN'`` or ``'TinySRGAN'``.
        save: the *string* ``"True"`` to also write ``super_resolved_image.png``
            to the working directory; any other value skips saving.

    Returns:
        The super-resolved (optionally sharpened) PIL image.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    print('Translating!')
    desired_width = 480
    original_width, original_height = image.size
    # Keep the aspect ratio; clamp so extremely wide inputs never get a 0-px height.
    desired_height = max(1, int((original_height / original_width) * desired_width))
    resized_image = image.convert('RGB').resize((desired_width, desired_height))

    if model_name == 'MobileSR':
        model = build_generator().to(device)
        model.load_state_dict(torch.load('./weights/mobile_sr.pt', map_location=torch.device('cpu')))
        model.eval()
        low_res = transform(resized_image).unsqueeze(dim=0).to(device)
        with torch.no_grad():
            sr = model(low_res)
        fake_imgs = numpify(sr)
        # Map the generator's [-1, 1] output back to 0-255.
        sr_img = Image.fromarray((((fake_imgs[0] + 1) / 2) * 255).astype(np.uint8))
    elif model_name in ('MiniSRGAN', 'TinySRGAN'):
        if model_name == 'MiniSRGAN':
            # NOTE(review): MiniSRGAN is not defined anywhere in this file, so this
            # branch raises NameError unless the class is provided elsewhere.
            model = MiniSRGAN().to(device)
            weights_path = './weights/miniSRGAN.pt'
        else:
            model = TinySRGAN().to(device)
            weights_path = './weights/tinySRGAN.pt'
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
        model.eval()
        # Scale pixels from [0, 255] to [-1, 1] and reorder HWC -> CHW.
        inputs = np.array(resized_image) / 127.5 - 1.0
        inputs = torch.tensor(inputs.transpose(2, 0, 1).astype(np.float32)).to(device)
        with torch.no_grad():
            output, _ = model(torch.unsqueeze(inputs, dim=0))
        output = output[0].cpu().numpy()
        # Clip guards the uint8 cast; for a tanh output it is a no-op.
        output = (np.clip(output, -1.0, 1.0) + 1.0) / 2.0
        sr_img = Image.fromarray((output.transpose(1, 2, 0) * 255.0).astype(np.uint8))
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'MobileSR', 'MiniSRGAN' or 'TinySRGAN'"
        )

    if sharpen:
        # filter2D applies the same kernel to every channel independently, so
        # filtering the RGB array directly matches the old RGB->BGR->RGB round trip.
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
        sr_img = Image.fromarray(cv2.filter2D(np.array(sr_img), -1, kernel))

    if save == "True":
        sr_img.save('super_resolved_image.png')
    return sr_img
# Gradio UI: a numpy image in; a text verdict and the (possibly super-resolved) image out.
image_input = gr.Image(type="numpy")
verdict_output = gr.Textbox(label="Result")
image_output = gr.Image(label="Processed Image")

interface = gr.Interface(
    fn=model2_inf,
    inputs=image_input,
    outputs=[verdict_output, image_output],
    title="DeepClarity",
    description="Upload an image to check if it is pixelated. If the image is pixelated, the processed image will be displayed.",
    allow_flagging='never',
)
interface.launch()