diff --git "a/hit_sng_arch.py" "b/hit_sng_arch.py" --- "a/hit_sng_arch.py" +++ "b/hit_sng_arch.py" @@ -1,1132 +1,1130 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_, _assert -from torchvision.transforms import functional as TF -from timm.models.fx_features import register_notrace_function - -import numpy as np -from einops import rearrange -from huggingface_hub import PyTorchModelHubMixin -from utils import FileClient, imfrombytes, img2tensor, tensor2img - -class DFE(nn.Module): - """ Dual Feature Extraction - Args: - in_features (int): Number of input channels. - out_features (int): Number of output channels. - """ - def __init__(self, in_features, out_features): - super().__init__() - - self.out_features = out_features - - self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(in_features // 5, out_features, 1, 1, 0)) - - self.linear = nn.Conv2d(in_features, out_features,1,1,0) - - def forward(self, x, x_size): - - B, L, C = x.shape - H, W = x_size - x = x.permute(0, 2, 1).contiguous().view(B, C, H, W) - x = self.conv(x) * self.linear(x) - x = x.view(B, -1, H*W).permute(0,2,1).contiguous() - - return x - -class Mlp(nn.Module): - """ MLP-based Feed-Forward Network - Args: - in_features (int): Number of input channels. - hidden_features (int | None): Number of hidden channels. Default: None - out_features (int | None): Number of output channels. Default: None - act_layer (nn.Module): Activation layer. Default: nn.GELU - drop (float): Dropout rate. Default: 0.0 - """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - wh, ww = H//window_size, W//window_size - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows, (wh, ww) - -@register_notrace_function # reason: int argument is a Proxy -def window_unpartition(windows, num_windows): - """ - Args: - windows: [B*wh*ww, WH, WW, D] - num_windows (tuple[int]): The height and width of the window. 
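# --- Editor's sketch (not part of the original diff) --------------------------
# window_partition / window_unpartition round trip: partition a (B, H, W, C)
# feature map into non-overlapping windows and recover it from the returned
# window-grid size. Assumes hit_sng_arch is importable and H, W are divisible
# by the window size.
import torch
from hit_sng_arch import window_partition, window_unpartition

x = torch.randn(2, 8, 8, 16)                 # (B, H, W, C)
windows, grid = window_partition(x, 4)       # (2*2*2, 4, 4, 16), grid == (2, 2)
x_back = window_unpartition(windows, grid)   # (2, 8, 8, 16)
assert torch.equal(x, x_back)
# ------------------------------------------------------------------------------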
- Returns: - x: [B, ph, pw, D] - """ - x = rearrange(windows, '(p h w) wh ww c -> p (h wh) (w ww) c', h=num_windows[0], w=num_windows[1]) - return x.contiguous() - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (tuple): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W)) - x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class DynamicPosBias(nn.Module): - # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py - """ Dynamic Relative Position Bias. - Args: - dim (int): Number of input channels. - num_heads (int): Number of heads for spatial self-correlation. - residual (bool): If True, use residual strage to connect conv. - """ - def __init__(self, dim, num_heads, residual): - super().__init__() - self.residual = residual - self.num_heads = num_heads - self.pos_dim = dim // 4 - self.pos_proj = nn.Linear(2, self.pos_dim) - self.pos1 = nn.Sequential( - nn.LayerNorm(self.pos_dim), - nn.ReLU(inplace=True), - nn.Linear(self.pos_dim, self.pos_dim), - ) - self.pos2 = nn.Sequential( - nn.LayerNorm(self.pos_dim), - nn.ReLU(inplace=True), - nn.Linear(self.pos_dim, self.pos_dim) - ) - self.pos3 = nn.Sequential( - nn.LayerNorm(self.pos_dim), - nn.ReLU(inplace=True), - nn.Linear(self.pos_dim, self.num_heads) - ) - def forward(self, biases): - if self.residual: - pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads - pos = pos + self.pos1(pos) - pos = pos + self.pos2(pos) - pos = self.pos3(pos) - else: - pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) - return pos - -class SCC(nn.Module): - """ Spatial-Channel Correlation. - Args: - dim (int): Number of input channels. - base_win_size (tuple[int]): The height and width of the base window. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of heads for spatial self-correlation. - value_drop (float, optional): Dropout ratio of value. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
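# --- Editor's sketch (not part of the original diff) --------------------------
# DynamicPosBias maps 2-D coordinate offsets to per-head bias values through a
# small MLP, as used by SCC.spatial_self_correlation. A minimal shape check,
# mirroring the bias-grid construction in that method (8x8 window, 6 heads,
# dim//4 = 15 as SCC passes it):
import torch
from hit_sng_arch import DynamicPosBias

H_sp, W_sp, num_heads = 8, 8, 6
dpb = DynamicPosBias(dim=60 // 4, num_heads=num_heads, residual=False)
pos_h = torch.arange(1 - H_sp, H_sp)
pos_w = torch.arange(1 - W_sp, W_sp)
biases = torch.stack(torch.meshgrid([pos_h, pos_w]))   # (2, 2*H_sp-1, 2*W_sp-1)
rpe = biases.flatten(1).transpose(0, 1).float()        # ((2H-1)*(2W-1), 2)
print(dpb(rpe).shape)                                  # torch.Size([225, 6])
# ------------------------------------------------------------------------------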
Default: 0.0 - """ - - def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.): - - super().__init__() - # parameters - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - - # feature projection - head_dim = dim // (2*num_heads) - if dim % (2*num_heads) > 0: - head_dim = head_dim + 1 - self.attn_dim = head_dim * 2 * num_heads - self.qv = DFE(dim, self.attn_dim) - self.proj = nn.Linear(self.attn_dim, dim) - - # dropout - self.value_drop = nn.Dropout(value_drop) - self.proj_drop = nn.Dropout(proj_drop) - - # base window size - min_h = min(self.window_size[0], base_win_size[0]) - min_w = min(self.window_size[1], base_win_size[1]) - self.base_win_size = (min_h, min_w) - - # normalization factor and spatial linear layer for S-SC - self.scale = head_dim - self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1) - - # NGram window partition without shifting - self.ngram_window_partition = NGramWindowPartition(dim, window_size, 2, num_heads, shift_size=0) - - # define a parameter table of relative position bias - self.H_sp, self.W_sp = self.window_size - self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) - - def spatial_linear_projection(self, x): - B, num_h, L, C = x.shape - H, W = self.window_size - map_H, map_W = self.base_win_size - - x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1) - x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C) - return x - - def spatial_self_correlation(self, q, v): - - B, num_head, L, C = q.shape - - # spatial projection - v = self.spatial_linear_projection(v) - - # compute correlation map - corr_map = (q @ v.transpose(-2,-1)) / self.scale - - # add relative position bias - position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device) - position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device) - biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) - rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float() - pos = self.pos(rpe_biases) - - # select position bias - coords_h = torch.arange(self.H_sp, device=v.device) - coords_w = torch.arange(self.W_sp, device=v.device) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.H_sp - 1 - relative_coords[:, :, 1] += self.W_sp - 1 - relative_coords[:, :, 0] *= 2 * self.W_sp - 1 - relative_position_index = relative_coords.sum(-1) - relative_position_bias = pos[relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view( - self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - corr_map = corr_map + relative_position_bias.unsqueeze(0) - - # transformation - v_drop = self.value_drop(v) - x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1) - - return x - - def channel_self_correlation(self, q, 
v): - - B, num_head, L, C = q.shape - - # apply single head strategy - q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C) - v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C) - - # compute correlation map - corr_map = (q.transpose(-2,-1) @ v) / L - - # transformation - v_drop = self.value_drop(v) - x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1) - - return x - - def forward(self, x): - """ - Args: - x: input features with shape of (B, H, W, C) - """ - xB,xH,xW,xC = x.shape - qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC) - - # window partition - qv = self.ngram_window_partition(qv) - qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC) - - # qv splitting - B, L, C = qv.shape - qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous() - q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads - - # spatial self-correlation (S-SC) - x_spatial = self.spatial_self_correlation(q, v) - x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2) - x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC - - # channel self-correlation (C-SC) - x_channel = self.channel_self_correlation(q, v) - x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2) - x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC - - # spatial-channel information fusion - x = torch.cat([x_spatial, x_channel], -1) - x = self.proj_drop(self.proj(x)) - - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - -class NGramWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias for NGram attention. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
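# --- Editor's sketch (not part of the original diff) --------------------------
# SCC consumes a (B, H, W, C) map, runs the N-Gram window partition internally,
# and fuses spatial (S-SC) and channel (C-SC) self-correlation back to the same
# shape. A minimal shape check with dim divisible by 2*num_heads and H, W
# divisible by the window size:
import torch
from hit_sng_arch import SCC

scc = SCC(dim=60, base_win_size=(8, 8), window_size=(8, 8), num_heads=6)
x = torch.randn(1, 16, 16, 60)               # (B, H, W, C)
print(scc(x).shape)                          # torch.Size([1, 16, 16, 60])
# ------------------------------------------------------------------------------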
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - -class NGramContext(nn.Module): - ''' - Args: - dim (int): Number of input channels. - window_size (int or tuple[int]): The height and width of the window. - ngram (int): How much windows(or patches) to see. - ngram_num_heads (int): - padding_mode (str, optional): How to pad. 
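# --- Editor's sketch (not part of the original diff) --------------------------
# NGramWindowAttention is ordinary windowed multi-head self-attention with a
# learned relative position bias; NGramContext applies it to ngram x ngram
# groups of unigram features. Shape check for the configuration NGramContext
# builds (dim//2 channels, 2x2 windows):
import torch
from hit_sng_arch import NGramWindowAttention

attn = NGramWindowAttention(dim=30, window_size=(2, 2), num_heads=6)
tokens = torch.randn(8, 4, 30)               # (num_windows*B, Wh*Ww, C)
print(attn(tokens).shape)                    # torch.Size([8, 4, 30])
# ------------------------------------------------------------------------------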
Default: seq_refl_win_pad - Options: ['seq_refl_win_pad', 'zero_pad'] - ''' - def __init__(self, dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad'): - super(NGramContext, self).__init__() - _assert(padding_mode in ['seq_refl_win_pad', 'zero_pad'], "padding mode should be 'seq_refl_win_pad' or 'zero_pad'!!") - - self.dim = dim - self.window_size = to_2tuple(window_size) - self.ngram = ngram - self.padding_mode = padding_mode - - # to alleviate parameter expansion with window sizes - self.unigram_embed = nn.Conv2d(2, 1, - kernel_size=(self.window_size[0], self.window_size[1]), - stride=self.window_size, padding=0, groups=1) - - self.ngram_attn = NGramWindowAttention(dim=dim//2, num_heads=ngram_num_heads, window_size=(ngram, ngram)) - self.avg_pool = nn.AvgPool2d(ngram) - self.merge = nn.Conv2d(dim, dim, 1, 1, 0) - - def seq_refl_win_pad(self, x, back=False): - if self.ngram == 1: return x - x = TF.pad(x, (0,0,self.ngram-1,self.ngram-1)) if not back else TF.pad(x, (self.ngram-1,self.ngram-1,0,0)) - if self.padding_mode == 'zero_pad': - return x - if not back: - (start_h, start_w), (end_h, end_w) = to_2tuple(-2*self.ngram+1), to_2tuple(-self.ngram) - # pad lower - x[:,:,-(self.ngram-1):,:] = x[:,:,start_h:end_h,:] - # pad right - x[:,:,:,-(self.ngram-1):] = x[:,:,:,start_w:end_w] - else: - (start_h, start_w), (end_h, end_w) = to_2tuple(self.ngram), to_2tuple(2*self.ngram-1) - # pad upper - x[:,:,:self.ngram-1,:] = x[:,:,start_h:end_h,:] - # pad left - x[:,:,:,:self.ngram-1] = x[:,:,:,start_w:end_w] - - return x - - def sliding_window_attention(self, unigram): - slide = unigram.unfold(3, self.ngram, 1).unfold(2, self.ngram, 1) - slide = rearrange(slide, 'b c h w ww hh -> b (h hh) (w ww) c') # [B, 2(wh+ngram-2), 2(ww+ngram-2), D/2] - slide, num_windows = window_partition(slide, self.ngram) # [B*wh*ww, ngram, ngram, D/2], (wh, ww) - slide = slide.view(-1, self.ngram*self.ngram, self.dim//2) # [B*wh*ww, ngram*ngram, D/2] - - context = self.ngram_attn(slide).view(-1, self.ngram, self.ngram, self.dim//2) # [B*wh*ww, ngram, ngram, D/2] - - context = window_unpartition(context, num_windows) # [B, wh*ngram, ww*ngram, D/2] - context = rearrange(context, 'b h w d -> b d h w') # [B, D/2, wh*ngram, ww*ngram] - context = self.avg_pool(context) # [B, D/2, wh, ww] - return context - - def forward(self, x): - B, ph, pw, D = x.size() - x = rearrange(x, 'b ph pw d -> b d ph pw') # [B, D, ph, pw] - x = x.contiguous().view(B*(D//2),2,ph,pw) - unigram = self.unigram_embed(x).view(B, D//2, ph//self.window_size[0], pw//self.window_size[1]) - - unigram_forward_pad = self.seq_refl_win_pad(unigram, False) # [B, D/2, wh+ngram-1, ww+ngram-1] - unigram_backward_pad = self.seq_refl_win_pad(unigram, True) # [B, D/2, wh+ngram-1, ww+ngram-1] - - context_forward = self.sliding_window_attention(unigram_forward_pad) # [B, D/2, wh, ww] - context_backward = self.sliding_window_attention(unigram_backward_pad) # [B, D/2, wh, ww] - - context_bidirect = torch.cat([context_forward, context_backward], dim=1) # [B, D, wh, ww] - context_bidirect = self.merge(context_bidirect) # [B, D, wh, ww] - context_bidirect = rearrange(context_bidirect, 'b d h w -> b h w d') # [B, wh, ww, D] - - return context_bidirect.unsqueeze(-2).unsqueeze(-2).contiguous() # [B, wh, ww, 1, 1, D] - -class NGramWindowPartition(nn.Module): - """ - Args: - dim (int): Number of input channels. - window_size (int): The height and width of the window. - ngram (int): How much windows to see as context. 
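# --- Editor's sketch (not part of the original diff) --------------------------
# NGramContext embeds each window into a "unigram" token, pads the token grid
# forward and backward, attends over ngram x ngram neighbourhoods, and returns
# one context vector per window shaped [B, wh, ww, 1, 1, D] so it broadcasts
# over the window partition. Minimal shape check (window 8, ngram 2):
import torch
from hit_sng_arch import NGramContext

ctx = NGramContext(dim=60, window_size=8, ngram=2, ngram_num_heads=6)
x = torch.randn(1, 16, 16, 60)               # (B, ph, pw, D), ph/pw multiples of 8
print(ctx(x).shape)                          # torch.Size([1, 2, 2, 1, 1, 60])
# ------------------------------------------------------------------------------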
- ngram_num_heads (int): - shift_size (int, optional): Shift size for SW-MSA. Default: 0 - """ - def __init__(self, dim, window_size, ngram, ngram_num_heads, shift_size=0): - super(NGramWindowPartition, self).__init__() - self.window_size = window_size[0] - self.ngram = ngram - self.shift_size = shift_size - - self.ngram_context = NGramContext(dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad') - - def forward(self, x): - B, ph, pw, D = x.size() - wh, ww = ph//self.window_size, pw//self.window_size # number of windows (height, width) - _assert(0 not in [wh, ww], "feature map size should be larger than window size!") - - context = self.ngram_context(x) # [B, wh, ww, 1, 1, D] - - windows = rearrange(x, 'b (h wh) (w ww) c -> b h w wh ww c', - wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. semi window partitioning - windows+=context # [B, wh, ww, WH, WW, D]. inject context - - # Cyclic Shift - if self.shift_size>0: - x = rearrange(windows, 'b h w wh ww c -> b (h wh) (w ww) c').contiguous() # [B, ph, pw, D]. re-patchfying - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # [B, ph, pw, D]. cyclic shift - windows = rearrange(shifted_x, 'b (h wh) (w ww) c -> b h w wh ww c', - wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. re-semi window partitioning - windows = rearrange(windows, 'b h w wh ww c -> (b h w) wh ww c').contiguous() # [B*wh*ww, WH, WW, D]. window partitioning - - return windows - - -class HierarchicalTransformerBlock(nn.Module): - """ Hierarchical Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of heads for spatial self-correlation. - base_win_size (tuple[int]): The height and width of the base window. - window_size (tuple[int]): The height and width of the window. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - value_drop (float, optional): Dropout ratio of value. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size, - mlp_ratio=4., drop=0., value_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.mlp_ratio = mlp_ratio - - # check window size - if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]): - assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size" - assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size" - - - self.norm1 = norm_layer(dim) - self.correlation = SCC( - dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads, - value_drop=value_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def check_image_size(self, x, win_size): - x = x.permute(0,3,1,2).contiguous() - _, _, h, w = x.size() - mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0] - mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1] - - if mod_pad_h >= h or mod_pad_w >= w: - pad_h, pad_w = h-1, w-1 - x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect') - else: - pad_h, pad_w = 0, 0 - - mod_pad_h = mod_pad_h - pad_h - mod_pad_w = mod_pad_w - pad_w - - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - x = x.permute(0,2,3,1).contiguous() - return x - - def forward(self, x, x_size, win_size): - H, W = x_size - B, L, C = x.shape - - shortcut = x - x = x.view(B, H, W, C) - - # padding - x = self.check_image_size(x, (win_size[0]*2, win_size[1]*2)) - _, H_pad, W_pad, _ = x.shape # shape after padding - - x = self.correlation(x) - - # unpad - x = x[:, :H, :W, :].contiguous() - - # norm - x = x.view(B, H * W, C) - x = self.norm1(x) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" - - -class PatchMerging(nn.Module): - """ Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - -class BasicLayer(nn.Module): - """ A basic Hierarchical Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of heads for spatial self-correlation. - base_win_size (tuple[int]): The height and width of the base window. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - value_drop (float, optional): Dropout ratio of value. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
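# --- Editor's sketch (not part of the original diff) --------------------------
# HierarchicalTransformerBlock works on flattened tokens: it reshapes (B, L, C)
# to (B, H, W, C), reflect-pads to a multiple of twice the per-block window,
# applies SCC, crops back, and finishes with the post-norm residual + MLP. A
# minimal call with base and block window both 8x8:
import torch
from hit_sng_arch import HierarchicalTransformerBlock

blk = HierarchicalTransformerBlock(dim=60, input_resolution=(16, 16), num_heads=6,
                                   base_win_size=(8, 8), window_size=(8, 8),
                                   mlp_ratio=2.)
tokens = torch.randn(1, 16 * 16, 60)         # (B, H*W, C)
print(blk(tokens, (16, 16), (8, 8)).shape)   # torch.Size([1, 256, 60])
# ------------------------------------------------------------------------------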
- hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, base_win_size, - mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, - downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios] - self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios] - - # build blocks - self.blocks = nn.ModuleList([ - HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, - base_win_size=base_win_size, - window_size=(self.win_hs[i], self.win_ws[i]), - mlp_ratio=mlp_ratio, - drop=drop, value_drop=value_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - - i = 0 - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i])) - else: - x = blk(x, x_size, (self.win_hs[i], self.win_ws[i])) - i = i + 1 - - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - -class RHTB(nn.Module): - """Residual Hierarchical Transformer Block (RHTB). - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of heads for spatial self-correlation. - base_win_size (tuple[int]): The height and width of the base window. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - value_drop (float, optional): Dropout ratio of value. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. 
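# --- Editor's sketch (not part of the original diff) --------------------------
# BasicLayer derives one window size per block from hier_win_ratios, so a depth
# of 6 pairs with the six default ratios. With the default base window (8, 8):
base_win_size = (8, 8)
hier_win_ratios = [0.5, 1, 2, 4, 6, 8]
win_hs = [int(base_win_size[0] * r) for r in hier_win_ratios]
win_ws = [int(base_win_size[1] * r) for r in hier_win_ratios]
print(list(zip(win_hs, win_ws)))  # [(4, 4), (8, 8), (16, 16), (32, 32), (48, 48), (64, 64)]
# ------------------------------------------------------------------------------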
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, base_win_size, - mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm, - downsample=None, use_checkpoint=False, img_size=224, patch_size=4, - resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]): - super(RHTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - base_win_size=base_win_size, - mlp_ratio=mlp_ratio, - drop=drop, value_drop=value_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint, - hier_win_ratios=hier_win_ratios) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - -class HiT_SNG(nn.Module, PyTorchModelHubMixin): - """ HiT-SNG network. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Transformer block. - num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers. - base_win_size (tuple[int]): The height and width of the base window. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - drop_rate (float): Dropout rate. Default: 0 - value_drop_rate (float): Dropout ratio of value. Default: 0.0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range (float): Image range. 1. or 255. - upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv' - hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. 
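# --- Editor's sketch (not part of the original diff) --------------------------
# UpsampleOneStep is a single 3x3 conv to scale**2 * num_out_ch channels
# followed by one PixelShuffle, which is why the 'pixelshuffledirect' path is
# the parameter-light option. Shape check for x4 SR with embed_dim=60:
import torch
from hit_sng_arch import UpsampleOneStep

up = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
print(up(torch.randn(1, 60, 16, 16)).shape)  # torch.Size([1, 3, 64, 64])
# ------------------------------------------------------------------------------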
- """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], - base_win_size=[8,8], mlp_ratio=2., - drop_rate=0., value_drop_rate=0., drop_path_rate=0., - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv', - hier_win_ratios=[0.5,1,2,4,6,8], - **kwargs): - super(HiT_SNG, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.base_win_size = base_win_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Hierarchical Transformer blocks (RHTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RHTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - base_win_size=base_win_size, - mlp_ratio=self.mlp_ratio, - drop=drop_rate, value_drop=value_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection, - hier_win_ratios=hier_win_ratios - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, 
embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def infer_image(self, image_path, cuda=True): - - io_backend_opt = {'type':'disk'} - self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt) - - # load lq image - lq_path = image_path - img_bytes = self.file_client.get(lq_path, 'lq') - img_lq = imfrombytes(img_bytes, float32=True) - - # BGR to RGB, HWC to CHW, numpy to tensor - x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...] 
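# --- Editor's sketch (not part of the original diff) --------------------------
# End-to-end use of the lightweight configuration (mirrors the __main__ block
# below): a (B, 3, H, W) low-resolution tensor in, a (B, 3, 4H, 4W) tensor out.
# Runs on CPU but takes a moment, since the deeper blocks use large windows.
import torch
from hit_sng_arch import HiT_SNG

model = HiT_SNG(upscale=4, img_size=64, base_win_size=[8, 8], img_range=1.,
                depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                mlp_ratio=2, upsampler='pixelshuffledirect').eval()
with torch.no_grad():
    sr = model(torch.randn(1, 3, 64, 64))
print(sr.shape)                              # torch.Size([1, 3, 256, 256])
# ------------------------------------------------------------------------------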
- - if cuda: - x= x.cuda() - - out = self(x) - - if cuda: - out = out.cpu() - - out = tensor2img(out) - - return out - - def forward(self, x): - H, W = x.shape[2:] - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - -if __name__ == '__main__': - upscale = 4 - base_win_size = [8, 8] - height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0] - width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1] - - ## HiT-SIR - model = HiT_SNG(upscale=4, img_size=(height, width), - base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - - params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) - print("params: ", params_num) - +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_, _assert +from torchvision.transforms import functional as TF +from timm.models.fx_features import register_notrace_function + +import numpy as np +from einops import rearrange +from huggingface_hub import PyTorchModelHubMixin +from utils import FileClient, imfrombytes, img2tensor, tensor2img + +class DFE(nn.Module): + """ Dual Feature Extraction + Args: + in_features (int): Number of input channels. + out_features (int): Number of output channels. + """ + def __init__(self, in_features, out_features): + super().__init__() + + self.out_features = out_features + + self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(in_features // 5, out_features, 1, 1, 0)) + + self.linear = nn.Conv2d(in_features, out_features,1,1,0) + + def forward(self, x, x_size): + + B, L, C = x.shape + H, W = x_size + x = x.permute(0, 2, 1).contiguous().view(B, C, H, W) + x = self.conv(x) * self.linear(x) + x = x.view(B, -1, H*W).permute(0,2,1).contiguous() + + return x + +class Mlp(nn.Module): + """ MLP-based Feed-Forward Network + Args: + in_features (int): Number of input channels. + hidden_features (int | None): Number of hidden channels. Default: None + out_features (int | None): Number of output channels. Default: None + act_layer (nn.Module): Activation layer. 
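# --- Editor's sketch (not part of the original diff) --------------------------
# DFE (Dual Feature Extraction) projects tokens through a small conv branch and
# a 1x1 "linear" branch and multiplies the two, i.e. a gated projection on the
# re-assembled (B, C, H, W) map. Shape check on (B, L, C) tokens:
import torch
from hit_sng_arch import DFE

dfe = DFE(in_features=60, out_features=60)
tokens = torch.randn(2, 16 * 16, 60)         # (B, H*W, C)
print(dfe(tokens, (16, 16)).shape)           # torch.Size([2, 256, 60])
# ------------------------------------------------------------------------------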
Default: nn.GELU + drop (float): Dropout rate. Default: 0.0 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + wh, ww = H//window_size, W//window_size + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (wh, ww) + +@register_notrace_function # reason: int argument is a Proxy +def window_unpartition(windows, num_windows): + """ + Args: + windows: [B*wh*ww, WH, WW, D] + num_windows (tuple[int]): The height and width of the window. + Returns: + x: [B, ph, pw, D] + """ + x = rearrange(windows, '(p h w) wh ww c -> p (h wh) (w ww) c', h=num_windows[0], w=num_windows[1]) + return x.contiguous() + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (tuple): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W)) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class DynamicPosBias(nn.Module): + # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py + """ Dynamic Relative Position Bias. + Args: + dim (int): Number of input channels. + num_heads (int): Number of heads for spatial self-correlation. + residual (bool): If True, use residual strage to connect conv. + """ + def __init__(self, dim, num_heads, residual): + super().__init__() + self.residual = residual + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + def forward(self, biases): + if self.residual: + pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads + pos = pos + self.pos1(pos) + pos = pos + self.pos2(pos) + pos = self.pos3(pos) + else: + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + +class SCC(nn.Module): + """ Spatial-Channel Correlation. + Args: + dim (int): Number of input channels. + base_win_size (tuple[int]): The height and width of the base window. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of heads for spatial self-correlation. + value_drop (float, optional): Dropout ratio of value. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.): + + super().__init__() + # parameters + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + + # feature projection + head_dim = dim // (2*num_heads) + if dim % (2*num_heads) > 0: + head_dim = head_dim + 1 + self.attn_dim = head_dim * 2 * num_heads + self.qv = DFE(dim, self.attn_dim) + self.proj = nn.Linear(self.attn_dim, dim) + + # dropout + self.value_drop = nn.Dropout(value_drop) + self.proj_drop = nn.Dropout(proj_drop) + + # base window size + min_h = min(self.window_size[0], base_win_size[0]) + min_w = min(self.window_size[1], base_win_size[1]) + self.base_win_size = (min_h, min_w) + + # normalization factor and spatial linear layer for S-SC + self.scale = head_dim + self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1) + + # NGram window partition without shifting + self.ngram_window_partition = NGramWindowPartition(dim, window_size, 2, num_heads, shift_size=0) + + # define a parameter table of relative position bias + self.H_sp, self.W_sp = self.window_size + self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) + + def spatial_linear_projection(self, x): + B, num_h, L, C = x.shape + H, W = self.window_size + map_H, map_W = self.base_win_size + + x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1) + x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C) + return x + + def spatial_self_correlation(self, q, v): + + B, num_head, L, C = q.shape + + # spatial projection + v = self.spatial_linear_projection(v) + + # compute correlation map + corr_map = (q @ v.transpose(-2,-1)) / self.scale + + # add relative position bias + position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device) + position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) + rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float() + pos = self.pos(rpe_biases) + + # select position bias + coords_h = torch.arange(self.H_sp, device=v.device) + coords_w = torch.arange(self.W_sp, device=v.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.H_sp - 1 + relative_coords[:, :, 1] += self.W_sp - 1 + relative_coords[:, :, 0] *= 2 * self.W_sp - 1 + relative_position_index = relative_coords.sum(-1) + relative_position_bias = pos[relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view( + self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + corr_map = corr_map + relative_position_bias.unsqueeze(0) + + # transformation + v_drop = self.value_drop(v) + x = (corr_map @ 
v_drop).permute(0,2,1,3).contiguous().view(B, L, -1) + + return x + + def channel_self_correlation(self, q, v): + + B, num_head, L, C = q.shape + + # apply single head strategy + q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C) + v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C) + + # compute correlation map + corr_map = (q.transpose(-2,-1) @ v) / L + + # transformation + v_drop = self.value_drop(v) + x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1) + + return x + + def forward(self, x): + """ + Args: + x: input features with shape of (B, H, W, C) + """ + xB,xH,xW,xC = x.shape + qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC) + + # window partition + qv = self.ngram_window_partition(qv) + qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC) + + # qv splitting + B, L, C = qv.shape + qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous() + q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads + + # spatial self-correlation (S-SC) + x_spatial = self.spatial_self_correlation(q, v) + x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2) + x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC + + # channel self-correlation (C-SC) + x_channel = self.channel_self_correlation(q, v) + x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2) + x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC + + # spatial-channel information fusion + x = torch.cat([x_spatial, x_channel], -1) + x = self.proj_drop(self.proj(x)) + + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class NGramWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias for NGram attention. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class NGramContext(nn.Module): + ''' + Args: + dim (int): Number of input channels. + window_size (int or tuple[int]): The height and width of the window. + ngram (int): How much windows(or patches) to see. + ngram_num_heads (int): + padding_mode (str, optional): How to pad. 
Default: seq_refl_win_pad + Options: ['seq_refl_win_pad', 'zero_pad'] + ''' + def __init__(self, dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad'): + super(NGramContext, self).__init__() + _assert(padding_mode in ['seq_refl_win_pad', 'zero_pad'], "padding mode should be 'seq_refl_win_pad' or 'zero_pad'!!") + + self.dim = dim + self.window_size = to_2tuple(window_size) + self.ngram = ngram + self.padding_mode = padding_mode + + # to alleviate parameter expansion with window sizes + self.unigram_embed = nn.Conv2d(2, 1, + kernel_size=(self.window_size[0], self.window_size[1]), + stride=self.window_size, padding=0, groups=1) + + self.ngram_attn = NGramWindowAttention(dim=dim//2, num_heads=ngram_num_heads, window_size=(ngram, ngram)) + self.avg_pool = nn.AvgPool2d(ngram) + self.merge = nn.Conv2d(dim, dim, 1, 1, 0) + + def seq_refl_win_pad(self, x, back=False): + if self.ngram == 1: return x + x = TF.pad(x, (0,0,self.ngram-1,self.ngram-1)) if not back else TF.pad(x, (self.ngram-1,self.ngram-1,0,0)) + if self.padding_mode == 'zero_pad': + return x + if not back: + (start_h, start_w), (end_h, end_w) = to_2tuple(-2*self.ngram+1), to_2tuple(-self.ngram) + # pad lower + x[:,:,-(self.ngram-1):,:] = x[:,:,start_h:end_h,:] + # pad right + x[:,:,:,-(self.ngram-1):] = x[:,:,:,start_w:end_w] + else: + (start_h, start_w), (end_h, end_w) = to_2tuple(self.ngram), to_2tuple(2*self.ngram-1) + # pad upper + x[:,:,:self.ngram-1,:] = x[:,:,start_h:end_h,:] + # pad left + x[:,:,:,:self.ngram-1] = x[:,:,:,start_w:end_w] + + return x + + def sliding_window_attention(self, unigram): + slide = unigram.unfold(3, self.ngram, 1).unfold(2, self.ngram, 1) + slide = rearrange(slide, 'b c h w ww hh -> b (h hh) (w ww) c') # [B, 2(wh+ngram-2), 2(ww+ngram-2), D/2] + slide, num_windows = window_partition(slide, self.ngram) # [B*wh*ww, ngram, ngram, D/2], (wh, ww) + slide = slide.view(-1, self.ngram*self.ngram, self.dim//2) # [B*wh*ww, ngram*ngram, D/2] + + context = self.ngram_attn(slide).view(-1, self.ngram, self.ngram, self.dim//2) # [B*wh*ww, ngram, ngram, D/2] + + context = window_unpartition(context, num_windows) # [B, wh*ngram, ww*ngram, D/2] + context = rearrange(context, 'b h w d -> b d h w') # [B, D/2, wh*ngram, ww*ngram] + context = self.avg_pool(context) # [B, D/2, wh, ww] + return context + + def forward(self, x): + B, ph, pw, D = x.size() + x = rearrange(x, 'b ph pw d -> b d ph pw') # [B, D, ph, pw] + x = x.contiguous().view(B*(D//2),2,ph,pw) + unigram = self.unigram_embed(x).view(B, D//2, ph//self.window_size[0], pw//self.window_size[1]) + + unigram_forward_pad = self.seq_refl_win_pad(unigram, False) # [B, D/2, wh+ngram-1, ww+ngram-1] + unigram_backward_pad = self.seq_refl_win_pad(unigram, True) # [B, D/2, wh+ngram-1, ww+ngram-1] + + context_forward = self.sliding_window_attention(unigram_forward_pad) # [B, D/2, wh, ww] + context_backward = self.sliding_window_attention(unigram_backward_pad) # [B, D/2, wh, ww] + + context_bidirect = torch.cat([context_forward, context_backward], dim=1) # [B, D, wh, ww] + context_bidirect = self.merge(context_bidirect) # [B, D, wh, ww] + context_bidirect = rearrange(context_bidirect, 'b d h w -> b h w d') # [B, wh, ww, D] + + return context_bidirect.unsqueeze(-2).unsqueeze(-2).contiguous() # [B, wh, ww, 1, 1, D] + +class NGramWindowPartition(nn.Module): + """ + Args: + dim (int): Number of input channels. + window_size (int): The height and width of the window. + ngram (int): How much windows to see as context. 
+ ngram_num_heads (int): + shift_size (int, optional): Shift size for SW-MSA. Default: 0 + """ + def __init__(self, dim, window_size, ngram, ngram_num_heads, shift_size=0): + super(NGramWindowPartition, self).__init__() + self.window_size = window_size[0] + self.ngram = ngram + self.shift_size = shift_size + + self.ngram_context = NGramContext(dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad') + + def forward(self, x): + B, ph, pw, D = x.size() + wh, ww = ph//self.window_size, pw//self.window_size # number of windows (height, width) + _assert(0 not in [wh, ww], "feature map size should be larger than window size!") + + context = self.ngram_context(x) # [B, wh, ww, 1, 1, D] + + windows = rearrange(x, 'b (h wh) (w ww) c -> b h w wh ww c', + wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. semi window partitioning + windows+=context # [B, wh, ww, WH, WW, D]. inject context + + # Cyclic Shift + if self.shift_size>0: + x = rearrange(windows, 'b h w wh ww c -> b (h wh) (w ww) c').contiguous() # [B, ph, pw, D]. re-patchfying + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # [B, ph, pw, D]. cyclic shift + windows = rearrange(shifted_x, 'b (h wh) (w ww) c -> b h w wh ww c', + wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. re-semi window partitioning + windows = rearrange(windows, 'b h w wh ww c -> (b h w) wh ww c').contiguous() # [B*wh*ww, WH, WW, D]. window partitioning + + return windows + + +class HierarchicalTransformerBlock(nn.Module): + """ Hierarchical Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of heads for spatial self-correlation. + base_win_size (tuple[int]): The height and width of the base window. + window_size (tuple[int]): The height and width of the window. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + value_drop (float, optional): Dropout ratio of value. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size, + mlp_ratio=4., drop=0., value_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + # check window size + if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]): + assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size" + assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size" + + + self.norm1 = norm_layer(dim) + self.correlation = SCC( + dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads, + value_drop=value_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def check_image_size(self, x, win_size): + x = x.permute(0,3,1,2).contiguous() + _, _, h, w = x.size() + mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0] + mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1] + + if mod_pad_h >= h or mod_pad_w >= w: + pad_h, pad_w = h-1, w-1 + x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect') + else: + pad_h, pad_w = 0, 0 + + mod_pad_h = mod_pad_h - pad_h + mod_pad_w = mod_pad_w - pad_w + + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = x.permute(0,2,3,1).contiguous() + return x + + def forward(self, x, x_size, win_size): + H, W = x_size + B, L, C = x.shape + + shortcut = x + x = x.view(B, H, W, C) + + # padding + x = self.check_image_size(x, (win_size[0]*2, win_size[1]*2)) + _, H_pad, W_pad, _ = x.shape # shape after padding + + x = self.correlation(x) + + # unpad + x = x[:, :H, :W, :].contiguous() + + # norm + x = x.view(B, H * W, C) + x = self.norm1(x) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class PatchMerging(nn.Module): + """ Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """ A basic Hierarchical Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of heads for spatial self-correlation. + base_win_size (tuple[int]): The height and width of the base window. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + value_drop (float, optional): Dropout ratio of value. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, base_win_size, + mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, + downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios] + self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios] + + # build blocks + self.blocks = nn.ModuleList([ + HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, + base_win_size=base_win_size, + window_size=(self.win_hs[i], self.win_ws[i]), + mlp_ratio=mlp_ratio, + drop=drop, value_drop=value_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + + i = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i])) + else: + x = blk(x, x_size, (self.win_hs[i], self.win_ws[i])) + i = i + 1 + + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RHTB(nn.Module): + """Residual Hierarchical Transformer Block (RHTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of heads for spatial self-correlation. + base_win_size (tuple[int]): The height and width of the base window. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + value_drop (float, optional): Dropout ratio of value. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, base_win_size, + mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm, + downsample=None, use_checkpoint=False, img_size=224, patch_size=4, + resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]): + super(RHTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + base_win_size=base_win_size, + mlp_ratio=mlp_ratio, + drop=drop, value_drop=value_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + hier_win_ratios=hier_win_ratios) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + +class HiT_SNG(nn.Module, PyTorchModelHubMixin): + """ HiT-SNG network. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Transformer block. + num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers. + base_win_size (tuple[int]): The height and width of the base window. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + drop_rate (float): Dropout rate. Default: 0 + value_drop_rate (float): Dropout ratio of value. Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range (float): Image range. 1. or 255. + upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv' + hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8]. 
+ """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + base_win_size=[8,8], mlp_ratio=2., + drop_rate=0., value_drop_rate=0., drop_path_rate=0., + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv', + hier_win_ratios=[0.5,1,2,4,6,8], + **kwargs): + super(HiT_SNG, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.base_win_size = base_win_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Hierarchical Transformer blocks (RHTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RHTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + base_win_size=base_win_size, + mlp_ratio=self.mlp_ratio, + drop=drop_rate, value_drop=value_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + hier_win_ratios=hier_win_ratios + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, 
embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def infer_image(self, image_path, device): + + io_backend_opt = {'type':'disk'} + self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt) + + # load lq image + lq_path = image_path + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # BGR to RGB, HWC to CHW, numpy to tensor + x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...] 
+
+        x = x.to(device)
+
+        out = self(x)
+
+        out = out.cpu()
+
+        out = tensor2img(out)
+
+        return out
+
+    def forward(self, x):
+        H, W = x.shape[2:]
+
+        self.mean = self.mean.type_as(x)
+        x = (x - self.mean) * self.img_range
+
+        if self.upsampler == 'pixelshuffle':
+            # for classical SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.conv_before_upsample(x)
+            x = self.conv_last(self.upsample(x))
+        elif self.upsampler == 'pixelshuffledirect':
+            # for lightweight SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.upsample(x)
+        elif self.upsampler == 'nearest+conv':
+            # for real-world SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.conv_before_upsample(x)
+            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+            x = self.conv_last(self.lrelu(self.conv_hr(x)))
+        else:
+            # for image denoising and JPEG compression artifact reduction
+            x_first = self.conv_first(x)
+            res = self.conv_after_body(self.forward_features(x_first)) + x_first
+            x = x + self.conv_last(res)
+
+        x = x / self.img_range + self.mean
+
+        return x[:, :, :H*self.upscale, :W*self.upscale]
+
+
+if __name__ == '__main__':
+    upscale = 4
+    base_win_size = [8, 8]
+    height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
+    width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
+
+    ## HiT-SNG
+    model = HiT_SNG(upscale=4, img_size=(height, width),
+                    base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
+                    embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+
+    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print("params: ", params_num)
+
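+
+    # Illustrative smoke test (a sketch added for clarity, not part of the original demo; the
+    # names lr/sr and the 64x64 input size are arbitrary). Runs a random low-resolution batch
+    # through the x4 lightweight model; the network pads internally per hierarchical window size.
+    lr = torch.randn(1, 3, 64, 64)
+    with torch.no_grad():
+        sr = model(lr)
+    print("input:", tuple(lr.shape), "-> output:", tuple(sr.shape))  # expected: (1, 3, 256, 256)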
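+
+    # Sketch of the window attention that NGramContext instantiates as NGramWindowAttention:
+    # it acts on flattened windows of shape (num_windows*B, Wh*Ww, C) and preserves that shape.
+    # The settings below (dim=32, 3x3 window, 4 heads, 16 windows) are arbitrary example values.
+    win_attn = NGramWindowAttention(dim=32, window_size=(3, 3), num_heads=4)
+    win_tokens = torch.randn(16, 9, 32)      # 16 windows, 3*3 tokens each, 32 channels
+    print(win_attn(win_tokens).shape)        # -> torch.Size([16, 9, 32])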
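+
+    # Sketch of NGramWindowPartition: it injects the bi-directional n-gram context into a
+    # (B, H, W, C) feature map and returns window-partitioned tokens of shape (B*wh*ww, WH, WW, C).
+    # Example values below are arbitrary; dim must be even and H, W divisible by the window size.
+    ngram_part = NGramWindowPartition(dim=60, window_size=(8, 8), ngram=2, ngram_num_heads=2, shift_size=0)
+    feat = torch.randn(1, 64, 64, 60)        # an 8x8 grid of 8x8 windows
+    print(ngram_part(feat).shape)            # -> torch.Size([64, 8, 8, 60])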
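+
+    # Sketch of PatchEmbed / PatchUnEmbed: their forward() only reshapes between the (B, C, H, W)
+    # image layout and the (B, H*W, C) token layout used by the transformer blocks (no projection
+    # is applied here). Example sizes are arbitrary.
+    embed, unembed = PatchEmbed(embed_dim=60), PatchUnEmbed(embed_dim=60)
+    fmap = torch.randn(1, 60, 32, 32)
+    tokens = embed(fmap)                     # -> (1, 1024, 60)
+    print(unembed(tokens, (32, 32)).shape)   # -> torch.Size([1, 60, 32, 32]) (round trip)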
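+
+    # Sketch of the two reconstruction heads: UpsampleOneStep ('pixelshuffledirect') is a single
+    # conv + PixelShuffle straight to the output channels, while Upsample ('pixelshuffle') upsamples
+    # in x2 (or x3) steps and keeps num_feat channels for the later conv_last. Values are arbitrary.
+    light_head = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
+    print(light_head(torch.randn(1, 60, 16, 16)).shape)    # -> torch.Size([1, 3, 64, 64])
+    classic_head = Upsample(scale=4, num_feat=64)
+    print(classic_head(torch.randn(1, 64, 16, 16)).shape)  # -> torch.Size([1, 64, 64, 64])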