# Standard libraries
import math

import numpy as np
# PyTorch
import torch
import torch.nn as nn

# 8x8 table for the Y channel, stored transposed and wrapped as a Parameter.
# NOTE(review): the values look like the example luminance quantization table
# from the JPEG standard -- confirm against the reference if this matters.
y_table = np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61],
     [12, 12, 14, 19, 26, 58, 60, 55],
     [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62],
     [18, 22, 37, 56, 68, 109, 103, 77],
     [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101],
     [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))

# 8x8 table for the chroma channels: 99 everywhere except the top-left 4x4
# corner, which is overwritten with the (transposed) values below.
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47],
                            [18, 21, 26, 66],
                            [24, 26, 56, 99],
                            [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))


def diff_round_back(x):
    """ Differentiable rounding function

    Returns round(x) plus the cubed rounding residual, so the result stays
    close to round(x) while the cubic term keeps a gradient w.r.t. x.

    Input:
        x(tensor)
    Output:
        x(tensor)
    """
    rounded = torch.round(x)
    return rounded + (x - rounded) ** 3


def diff_round(input_tensor):
    """ Smooth rounding via a truncated Fourier series

    Computes x - (1/pi) * sum_{n=1..9} (-1)^(n+1) / n * sin(2*pi*n*x).
    The sum is the first nine terms of the Fourier series of the sawtooth
    x - round(x), so the result approximates round(x) while staying smooth.

    Input:
        input_tensor(tensor)
    Output:
        final_tensor(tensor): same shape as input_tensor
    """
    series = 0
    for n in range(1, 10):
        # Alternating sign: +1 for odd n, -1 for even n.
        coeff = (1.0 if n % 2 else -1.0) / n
        series += coeff * torch.sin(2 * math.pi * n * input_tensor)
    return input_tensor - 1 / math.pi * series


class Quant(torch.autograd.Function):
    """ 8-bit quantization with a straight-through gradient

    Forward clamps to [0, 1] and snaps to the nearest multiple of 1/255.
    Backward returns the incoming gradient unchanged, including for values
    that were clamped in the forward pass.
    """

    @staticmethod
    def forward(ctx, input):
        clamped = torch.clamp(input, 0, 1)
        return (clamped * 255.).round() / 255.

    @staticmethod
    def backward(ctx, grad_output):
        # Rounding has zero gradient almost everywhere; pass it straight through.
        return grad_output


class Quantization(nn.Module):
    """ Module wrapper that applies Quant to its input """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return Quant.apply(input)


def quality_to_factor(quality):
    """ Calculate factor corresponding to quality

    For quality < 50 the factor is 50 / quality; otherwise it is
    2 - quality / 50, reaching 0 at quality == 100.

    Input:
        quality(float): Quality for jpeg compression, in the range (0, 100]
    Output:
        factor(float): Compression factor
    Raises:
        ValueError: if quality is outside (0, 100]. Zero would divide by
            zero, and other out-of-range values would give a negative factor.
    """
    if not 0 < quality <= 100:
        raise ValueError(
            f"quality must be in the range (0, 100], got {quality!r}")
    if quality < 50:
        scaled = 5000. / quality
    else:
        scaled = 200. - quality * 2
    return scaled / 100.