|
|
|
import itertools |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from . import JPEG_utils as utils |
|
|
|
|
|
class y_dequantize(nn.Module): |
|
""" Dequantize Y channel |
|
Inputs: |
|
image(tensor): batch x height x width |
|
factor(float): compression factor |
|
Outputs: |
|
image(tensor): batch x height x width |
|
""" |
|
def __init__(self, factor=1): |
|
super(y_dequantize, self).__init__() |
|
self.y_table = utils.y_table |
|
self.factor = factor |
|
|
|
def forward(self, image): |
|
return image * (self.y_table * self.factor) |
|
|
|
|
|
class c_dequantize(nn.Module): |
|
""" Dequantize CbCr channel |
|
Inputs: |
|
image(tensor): batch x height x width |
|
factor(float): compression factor |
|
Outputs: |
|
image(tensor): batch x height x width |
|
""" |
|
def __init__(self, factor=1): |
|
super(c_dequantize, self).__init__() |
|
self.factor = factor |
|
self.c_table = utils.c_table |
|
|
|
def forward(self, image): |
|
return image * (self.c_table * self.factor) |
|
|
|
|
|
class idct_8x8(nn.Module): |
|
""" Inverse discrete Cosine Transformation |
|
Input: |
|
dcp(tensor): batch x height x width |
|
Output: |
|
image(tensor): batch x height x width |
|
""" |
|
def __init__(self): |
|
super(idct_8x8, self).__init__() |
|
alpha = np.array([1. / np.sqrt(2)] + [1] * 7) |
|
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) |
|
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) |
|
for x, y, u, v in itertools.product(range(8), repeat=4): |
|
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos( |
|
(2 * v + 1) * y * np.pi / 16) |
|
self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) |
|
|
|
def forward(self, image): |
|
|
|
image = image * self.alpha |
|
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 |
|
result.view(image.shape) |
|
return result |
|
|
|
|
|
class block_merging(nn.Module): |
|
""" Merge pathces into image |
|
Inputs: |
|
patches(tensor) batch x height*width/64, height x width |
|
height(int) |
|
width(int) |
|
Output: |
|
image(tensor): batch x height x width |
|
""" |
|
def __init__(self): |
|
super(block_merging, self).__init__() |
|
|
|
def forward(self, patches, height, width): |
|
k = 8 |
|
batch_size = patches.shape[0] |
|
|
|
image_reshaped = patches.view(batch_size, height//k, width//k, k, k) |
|
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) |
|
return image_transposed.contiguous().view(batch_size, height, width) |
|
|
|
|
|
class chroma_upsampling(nn.Module): |
|
""" Upsample chroma layers |
|
Input: |
|
y(tensor): y channel image |
|
cb(tensor): cb channel |
|
cr(tensor): cr channel |
|
Ouput: |
|
image(tensor): batch x height x width x 3 |
|
""" |
|
def __init__(self): |
|
super(chroma_upsampling, self).__init__() |
|
|
|
def forward(self, y, cb, cr): |
|
def repeat(x, k=2): |
|
height, width = x.shape[1:3] |
|
x = x.unsqueeze(-1) |
|
x = x.repeat(1, 1, k, k) |
|
x = x.view(-1, height * k, width * k) |
|
return x |
|
|
|
cb = repeat(cb) |
|
cr = repeat(cr) |
|
|
|
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) |
|
|
|
|
|
class ycbcr_to_rgb_jpeg(nn.Module): |
|
""" Converts YCbCr image to RGB JPEG |
|
Input: |
|
image(tensor): batch x height x width x 3 |
|
Outpput: |
|
result(tensor): batch x 3 x height x width |
|
""" |
|
def __init__(self): |
|
super(ycbcr_to_rgb_jpeg, self).__init__() |
|
|
|
matrix = np.array( |
|
[[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], |
|
dtype=np.float32).T |
|
self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) |
|
self.matrix = nn.Parameter(torch.from_numpy(matrix)) |
|
|
|
def forward(self, image): |
|
result = torch.tensordot(image + self.shift, self.matrix, dims=1) |
|
|
|
result.view(image.shape) |
|
return result.permute(0, 3, 1, 2) |
|
|
|
|
|
class decompress_jpeg(nn.Module): |
|
""" Full JPEG decompression algortihm |
|
Input: |
|
compressed(dict(tensor)): batch x h*w/64 x 8 x 8 |
|
rounding(function): rounding function to use |
|
factor(float): Compression factor |
|
Ouput: |
|
image(tensor): batch x 3 x height x width |
|
""" |
|
|
|
def __init__(self, rounding=torch.round, factor=1): |
|
super(decompress_jpeg, self).__init__() |
|
self.c_dequantize = c_dequantize(factor=factor) |
|
self.y_dequantize = y_dequantize(factor=factor) |
|
self.idct = idct_8x8() |
|
self.merging = block_merging() |
|
|
|
self.chroma = chroma_upsampling() |
|
self.colors = ycbcr_to_rgb_jpeg() |
|
|
|
|
|
|
|
def forward(self, y, cb, cr, height, width): |
|
components = {'y': y, 'cb': cb, 'cr': cr} |
|
|
|
|
|
self.height = height |
|
self.width = width |
|
for k in components.keys(): |
|
if k in ('cb', 'cr'): |
|
comp = self.c_dequantize(components[k]) |
|
|
|
height, width = int(self.height/2), int(self.width/2) |
|
|
|
|
|
else: |
|
comp = self.y_dequantize(components[k]) |
|
|
|
height, width = self.height, self.width |
|
comp = self.idct(comp) |
|
components[k] = self.merging(comp, height, width) |
|
|
|
|
|
image = self.chroma(components['y'], components['cb'], components['cr']) |
|
|
|
image = self.colors(image) |
|
|
|
image = torch.min(255*torch.ones_like(image), |
|
torch.max(torch.zeros_like(image), image)) |
|
return image/255 |
|
|
|
|