"""RAFTNET: a recurrent image-to-image network refined by a convolutional GRU.

The module exposes ``make_model`` as its factory and also defines two GRU
cell variants (``ConvGRU`` and ``SepConvGRU``).
"""

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from lambda_networks import LambdaLayer

from model import attention
from model import common


class ConvGRU(nn.Module):
    """Convolutional GRU cell whose gates are 3x3 convolutions.

    Args:
        hidden_dim: number of channels of the hidden state ``h``.
        input_dim: number of channels of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        gate_in = hidden_dim + input_dim
        self.convz = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the updated hidden state for hidden state ``h`` and input ``x``.

        ``h`` and ``x`` are concatenated along ``dim=1``, so they must agree
        in every other dimension.
        """
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))  # candidate
        return (1 - z) * h + z * q


class SepConvGRU(nn.Module):
    """Convolutional GRU cell applied as two separable passes.

    The first pass uses (1, 5) kernels and the second uses (5, 1) kernels;
    each pass is a complete GRU update with its own set of gate weights.

    Args:
        hidden_dim: number of channels of the hidden state ``h``.
        input_dim: number of channels of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        gate_in = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        """Return the hidden state after the horizontal then vertical update."""
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        return h


def make_model(args, parent=False):
    """Build a ``RAFTNET`` from ``args``; ``parent`` is accepted but unused."""
    return RAFTNET(args)


class RAFTNET(nn.Module):
    """Recurrent network: head -> (ConvGRU -> residual body) x N -> tail.

    The head and tail are shared across all recurrence steps, and ``forward``
    returns one output per step.

    Args:
        args: configuration object. The attributes read here are
            ``n_resblocks``, ``n_feats``, ``rgb_range``, ``n_colors``,
            ``res_scale``, ``recurrence``, ``detach`` and ``amp``.
        conv: factory called as ``conv(in_channels, out_channels,
            kernel_size)`` to build the convolution layers.
    """

    def __init__(self, args, conv=common.default_conv):
        super(RAFTNET, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3

        def res_block():
            # Every block gets its own PReLU instance, as in the original
            # layer-by-layer construction.
            return common.ResBlock(
                conv, n_feats, kernel_size, nn.PReLU(),
                res_scale=args.res_scale
            )

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        # NOTE(review): sub_mean / add_mean are registered but never called in
        # forward(); they are kept so the module layout (and presumably the
        # state_dict keys of existing checkpoints) stays the same -- confirm.
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]
        # perhaps a shallow network here?
        m_head.extend(res_block() for _ in range(5))

        # convert feature to image, shared
        m_tail = [res_block() for _ in range(5)]
        m_tail.append(conv(n_feats, args.n_colors, kernel_size))

        # middle recurrent part
        layer = LambdaLayer(
            dim=n_feats,
            dim_out=n_feats,
            r=23,  # the receptive field for relative positional encoding (23 x 23)
            dim_k=16,
            heads=4,
            dim_u=4
        )

        # define body module: half the residual blocks, the lambda layer,
        # the other half, then a closing convolution.
        m_body = [res_block() for _ in range(n_resblocks // 2)]
        m_body.append(layer)
        m_body.extend(res_block() for _ in range(n_resblocks // 2))
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

        # The GRU's hidden state and input are both head features, so size it
        # from n_feats. The previous hard-coded 64 only worked when
        # n_feats == 64 and failed with a channel mismatch otherwise.
        self.gru = ConvGRU(hidden_dim=n_feats, input_dim=n_feats)

        self.recurrence = args.recurrence
        self.detach = args.detach  # stored but not read anywhere in this class
        self.amp = args.amp

    def forward(self, x):
        """Run ``self.recurrence`` refinement steps and return every output.

        Args:
            x: input tensor; it is mapped with ``(x - 0.5) / 0.5`` before the
                head (presumably expects values in [0, 1] -- confirm with the
                data pipeline).

        Returns:
            A list of length ``self.recurrence``; entry ``i`` is the tail
            output after step ``i``, mapped back with ``* 0.5 + 0.5``.
        """
        with amp.autocast(self.amp):
            x = (x - 0.5) / 0.5
            x = self.head(x)

            # The head features initialise the hidden state and are also fed
            # to the GRU as its input at every step.
            hidden = x.clone()
            output_lst = [None] * self.recurrence
            for i in range(self.recurrence):
                gru_out = self.gru(hidden, x)
                res = self.body(gru_out)
                gru_out = res + gru_out
                hidden = gru_out
                output_lst[i] = self.tail(gru_out) * 0.5 + 0.5
        return output_lst

    def load_state_dict(self, state_dict, strict=True):
        """Copy matching entries of ``state_dict`` into this model in place.

        Entries whose name contains ``'tail'`` are best-effort: a failed copy
        or an unknown key is silently ignored for them. Keys of this model
        that are absent from ``state_dict`` are left untouched and are not
        reported, even when ``strict`` is true.

        Args:
            state_dict: mapping from parameter/buffer name to tensor.
            strict: when true, a non-``tail`` key that this model does not
                have raises ``KeyError``.

        Raises:
            RuntimeError: a non-``tail`` entry could not be copied.
            KeyError: ``strict`` is true and a non-``tail`` key is unknown.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            is_tail = 'tail' in name
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as err:
                    if not is_tail:
                        # Chain the original error so its cause is not lost.
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size())) from err
            elif strict:
                if not is_tail:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))