# NOTE: "Spaces: Sleeping" is Hugging Face Spaces page-status residue captured
# when this file was copied; kept here as a comment so the module parses.
import torch
import warnings
import gradio as gr
import cv2
import torchvision
from torch import nn
from torchvision.models import mobilenet_v3_small
import numpy as np
from PIL import Image
from torchvision import transforms

# Run on GPU when available; models and tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
warnings.filterwarnings("ignore")  # silence torch/torchvision warning noise in the UI logs
def flip_text(x):
    """Return a reversed copy of the given sequence (helper; not wired to the UI here)."""
    flipped = x[::-1]
    return flipped
def method2_prep(image):
    """Preprocess an uploaded image for the pixelation classifier.

    Center-crops the frame to at most 1920x1080, extracts a Canny edge map,
    replicates it to 3 channels, then resizes/center-crops to the 224x224
    input that MobileNetV3 expects.

    Args:
        image: HxWxC uint8 numpy array (as delivered by gr.Image).

    Returns:
        A float32 tensor of shape (3, 224, 224).
    """
    # Named `prep` so we do not shadow the module-level `transforms` import.
    prep = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop((224, 224))
    ])
    t_lower = 50   # Canny hysteresis thresholds
    t_upper = 150

    # Center-crop to 1920x1080; clamp offsets at 0 so images smaller than the
    # target are passed through whole instead of sliced with negative starts.
    height, width = image.shape[:2]
    x = max((width - 1920) // 2, 0)
    y = max((height - 1080) // 2, 0)
    image = image[y:y + 1080, x:x + 1920]

    # Edge map -> (1, H, W) tensor, stacked into 3 identical channels.
    edges = torch.from_numpy(cv2.Canny(image, t_lower, t_upper)[np.newaxis, ...])
    edges = torch.vstack((edges, edges, edges))
    return prep(edges.type(torch.float32))
def model2_inf(x):
    """Classify an uploaded image as pixelated or not.

    When the classifier predicts "pixelated", the image is additionally
    super-resolved with TinySRGAN.

    Args:
        x: HxWxC uint8 numpy array from the Gradio image input.

    Returns:
        (message, image): a verdict string, plus either None or the
        super-resolved PIL image.
    """
    print("Method 2")
    image = method2_prep(x).unsqueeze(dim=0)

    # Rebuild the classifier head for binary (clean / pixelated) output
    # before loading the fine-tuned weights.
    model = mobilenet_v3_small(weights='DEFAULT')
    model.classifier[3] = nn.Linear(in_features=1024, out_features=2, bias=True)
    model.load_state_dict(torch.load('./weights/method2(0.960).pt',
                                     map_location=torch.device('cpu')))
    model.eval()  # evaluation mode: freezes batchnorm/dropout

    with torch.inference_mode():
        model = model.to(device)
        image = image.to(device)
        output = torch.softmax(model(image), dim=1).detach().cpu()
        prediction = torch.argmax(output, dim=1).item()

    # Free GPU memory held by the throwaway per-call model instance.
    del model
    torch.cuda.empty_cache()

    if prediction == 0:
        return "The image is not pixelated", None
    return "The image is pixelated", translate_image(
        Image.fromarray(x), False, 'TinySRGAN', 'False')
class _conv(nn.Conv2d):
    """Conv2d with 'same' padding and SRGAN-style N(0, 0.02) weight init.

    Note: the `padding` argument is accepted for signature compatibility
    but the layer always pads by kernel_size // 2, so spatial size is
    preserved at stride 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        # Honor the caller's `bias` flag instead of hardcoding bias=True.
        super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
            kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = bias)
        # Weights ~ N(0, 0.02); bias zeroed (only when a bias exists).
        self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
        if self.bias is not None:
            self.bias.data = torch.zeros((out_channels))
        for p in self.parameters():
            p.requires_grad = True
class conv(nn.Module):
    """Convenience block: _conv, optionally followed by BatchNorm and an activation."""

    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        layers = [
            _conv(in_channels = in_channel, out_channels = out_channel,
                  kernel_size = kernel_size, stride = stride,
                  padding = (kernel_size) // 2, bias = True)
        ]
        if BN:
            layers.append(nn.BatchNorm2d(num_features = out_channel))
        if act is not None:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
class ResBlock(nn.Module):
    """Residual block: two conv+BN layers with an identity skip connection."""

    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        self.body = nn.Sequential(
            conv(channels, channels, kernel_size, BN = True, act = act),
            conv(channels, channels, kernel_size, BN = True, act = None),
        )

    def forward(self, x):
        return self.body(x) + x
class BasicBlock(nn.Module):
    """Entry conv followed by a stack of residual blocks with a long skip connection."""

    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()
        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
        stack = [ResBlock(out_channels, kernel_size, act) for _ in range(num_res_block)]
        stack.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))
        self.body = nn.Sequential(*stack)

    def forward(self, x):
        res = self.conv(x)
        return self.body(res) + res
class Upsampler(nn.Module):
    """Pixel-shuffle upsampler: conv -> PixelShuffle(scale) -> optional activation."""

    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        layers = [
            conv(channel, channel * scale * scale, kernel_size),
            nn.PixelShuffle(scale),
        ]
        if act is not None:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
class discrim_block(nn.Module):
    """Discriminator stage: stride-1 conv+BN, then stride-2 conv+BN (halves resolution)."""

    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        self.body = nn.Sequential(
            conv(in_feats, out_feats, kernel_size, BN = True, act = act),
            conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2),
        )

    def forward(self, x):
        return self.body(x)
class TinySRGAN(nn.Module):
    """Compact SRGAN-style generator.

    Head conv -> residual body with a long skip -> pixel-shuffle upsampling
    (x4 realized as two x2 stages, otherwise one stage of the given scale)
    -> tanh output conv. forward() returns (sr_image, pre-upsample features).
    """

    def __init__(self, img_feat = 3, n_feats = 32, kernel_size = 3, num_block = 6, act = nn.PReLU(), scale=4):
        super(TinySRGAN, self).__init__()
        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)
        self.body = nn.Sequential(
            *[ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
        )
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)
        if scale == 4:
            # x4 upscale is done as two successive x2 pixel-shuffle stages.
            tail = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
        else:
            tail = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]
        self.tail = nn.Sequential(*tail)
        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())

    def forward(self, x):
        head = self.conv01(x)
        feat = self.conv02(self.body(head)) + head
        out = self.last_conv(self.tail(feat))
        return out, feat
def build_generator():
    """Build the MobileSR generator (3 -> 3 channels, x4 bilinear upscale).

    Returns:
        A Generator with 6 inverted-residual blocks and 32 base filters.
    """

    class ResidualBlock(nn.Module):
        """MobileNetV2-style inverted residual: expand -> depthwise -> project."""

        def __init__(self, in_channels, out_channels, expansion=6, stride=1, alpha=1.0):
            super(ResidualBlock, self).__init__()
            self.expansion = expansion
            self.stride = stride
            self.in_channels = in_channels
            self.out_channels = int(out_channels * alpha)
            # Round the projection width to a multiple of 8 (MobileNet convention).
            self.pointwise_conv_filters = self._make_divisible(self.out_channels, 8)
            hidden = in_channels * expansion
            self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn1 = nn.BatchNorm2d(hidden)
            # groups == channels makes this a depthwise convolution.
            self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=True)
            self.bn2 = nn.BatchNorm2d(hidden)
            self.conv3 = nn.Conv2d(hidden, self.pointwise_conv_filters, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn3 = nn.BatchNorm2d(self.pointwise_conv_filters)
            self.relu = nn.ReLU(inplace=True)
            # Identity skip only when input/output shapes match.
            self.skip_add = (stride == 1 and in_channels == self.pointwise_conv_filters)

        @staticmethod
        def _make_divisible(v, divisor, min_value=None):
            """Round v to the nearest multiple of divisor, never below 90% of v.

            BUG FIX: this was a plain function (no `self`) called as
            `self._make_divisible(...)`, which passed the instance as `v`
            and broke construction; @staticmethod makes the bound call valid.
            """
            if min_value is None:
                min_value = divisor
            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
            if new_v < 0.9 * v:
                new_v += divisor
            return new_v

        def forward(self, x):
            identity = x
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.relu(self.bn2(self.conv2(out)))
            out = self.bn3(self.conv3(out))
            if self.skip_add:
                out = out + identity
            return out

    class Generator(nn.Module):
        """SR generator: conv head, residual body with long skip, two x2 upsamples, tanh output."""

        def __init__(self, in_channels, num_residual_blocks, gf):
            super(Generator, self).__init__()
            self.num_residual_blocks = num_residual_blocks
            self.gf = gf
            self.conv1 = nn.Conv2d(in_channels, gf, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(gf)
            self.prelu1 = nn.PReLU()
            self.residual_blocks = self.make_layer(ResidualBlock, gf, num_residual_blocks)
            self.conv2 = nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(gf)
            self.upsample1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
                nn.PReLU()
            )
            self.upsample2 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
                nn.PReLU()
            )
            self.conv3 = nn.Conv2d(gf, 3, kernel_size=3, stride=1, padding=1)
            self.tanh = nn.Tanh()

        def make_layer(self, block, out_channels, blocks):
            """Stack `blocks` copies of `block(out_channels, out_channels)`."""
            layers = [block(out_channels, out_channels) for _ in range(blocks)]
            return nn.Sequential(*layers)

        def forward(self, x):
            out1 = self.prelu1(self.bn1(self.conv1(x)))
            out = self.residual_blocks(out1)
            out = self.bn2(self.conv2(out))
            out = out + out1  # long skip over the residual body
            out = self.upsample1(out)
            out = self.upsample2(out)
            out = self.tanh(self.conv3(out))
            return out

    return Generator(3, 6, 32)
def numpify(imgs):
    """Convert a batch of CHW tensors to an NHWC numpy array.

    Args:
        imgs: iterable of (C, H, W) tensors (e.g. an (N, C, H, W) batch).

    Returns:
        (N, H, W, C) numpy array.
    """
    # Explicit .detach().cpu().numpy() instead of handing torch tensors to
    # np.stack: the implicit conversion crashes for CUDA or grad-attached
    # tensors (the original code even carried a "MIGHT CRASH HERE" comment).
    all_images = [img.permute(1, 2, 0).detach().cpu().numpy() for img in imgs]
    return np.stack(all_images, axis=0)
# Preprocessing for the MobileSR branch: PIL image -> float CHW tensor in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor()
])
# Function to translate (super-resolve) the image
def translate_image(image, sharpen, model_name, save):
    """Super-resolve a PIL image with the selected model.

    Args:
        image: input PIL.Image.
        sharpen: if truthy, apply a 3x3 sharpening kernel to the result.
        model_name: one of 'MobileSR', 'MiniSRGAN', 'TinySRGAN'.
        save: the string "True" to also write 'super_resolved_image.png'.

    Returns:
        The super-resolved (optionally sharpened) PIL.Image.

    Raises:
        ValueError: for an unknown model_name (previously this fell through
            and crashed with UnboundLocalError on `sr_img`).
    """
    print('Translating!')

    # Downscale to a fixed 480px width (aspect ratio preserved) so the SR
    # models always see a comparable input size.
    desired_width = 480
    original_width, original_height = image.size
    desired_height = int((original_height / original_width) * desired_width)
    resized_image = image.resize((desired_width, desired_height))

    if model_name == 'MobileSR':
        model = build_generator().to(device)
        model.load_state_dict(torch.load('./weights/mobile_sr.pt', map_location=torch.device('cpu')))
        model.eval()
        low_res = transform(resized_image).unsqueeze(dim=0).to(device)
        with torch.no_grad():
            sr = model(low_res)
        fake_imgs = numpify(sr)
        # Generator outputs tanh values in [-1, 1]; map back to [0, 255].
        sr_img = Image.fromarray((((fake_imgs[0] + 1) / 2) * 255).astype(np.uint8))
    elif model_name == 'MiniSRGAN':
        # NOTE(review): MiniSRGAN is not defined anywhere in this file — this
        # branch raises NameError unless the class is supplied elsewhere.
        model = MiniSRGAN().to(device)
        model.load_state_dict(torch.load('./weights/miniSRGAN.pt', map_location=torch.device('cpu')))
        model.eval()
        sr_img = _srgan_infer(model, resized_image, clip=True)
    elif model_name == 'TinySRGAN':
        model = TinySRGAN().to(device)
        model.load_state_dict(torch.load('./weights/tinySRGAN.pt', map_location=torch.device('cpu')))
        model.eval()
        sr_img = _srgan_infer(model, resized_image, clip=False)
    else:
        raise ValueError(f"Unknown model_name: {model_name!r}")

    if sharpen:
        # Classic 3x3 sharpening kernel, applied in BGR space via OpenCV.
        sr_img_cv = cv2.cvtColor(np.array(sr_img), cv2.COLOR_RGB2BGR)
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
        sharpened_cv = cv2.filter2D(sr_img_cv, -1, kernel)
        sr_img = Image.fromarray(cv2.cvtColor(sharpened_cv, cv2.COLOR_BGR2RGB))

    if save == "True":
        sr_img.save('super_resolved_image.png')
    return sr_img


def _srgan_infer(model, resized_image, clip):
    """Run a (Mini/Tiny)SRGAN generator on a PIL image and return a PIL image.

    Inputs are scaled to [-1, 1]; the generator's first output is mapped back
    to [0, 255]. When clip is True the raw output is clamped to [-1, 1] first.
    """
    inputs = (np.array(resized_image) / 127.5) - 1.0
    inputs = torch.tensor(inputs.transpose(2, 0, 1).astype(np.float32)).to(device)
    with torch.no_grad():
        output, _ = model(torch.unsqueeze(inputs, dim=0))
    output = output[0].cpu().numpy()
    if clip:
        output = np.clip(output, -1.0, 1.0)
    output = ((output + 1.0) / 2.0).transpose(1, 2, 0)
    return Image.fromarray((output * 255.0).astype(np.uint8))
# Gradio interface
# Single-image UI: the classifier verdict goes to the textbox; when the image
# is judged pixelated, the TinySRGAN super-resolved result fills the image output.
interface = gr.Interface(
    fn=model2_inf,
    inputs=gr.Image(type="numpy"),  # deliver the upload as an HxWxC numpy array
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Processed Image")],
    title="DeepClarity",
    description="Upload an image to check if it is pixelated. If the image is pixelated, the processed image will be displayed.",
    allow_flagging='never'
)
interface.launch()