import torch import torch.nn as nn class Quant(torch.autograd.Function): @staticmethod def forward(ctx, input): input = torch.clamp(input, 0, 1) output = (input * 255.).round() / 255. return output @staticmethod def backward(ctx, grad_output): return grad_output class Quantization(nn.Module): def __init__(self): super(Quantization, self).__init__() def forward(self, input): return Quant.apply(input)