# models/components.py — DECO model components (source: ac5113's repo, commit 99a05f0)
import torch
import torchvision
import torch.nn as nn
import numpy as np
from utils.hrnet import hrnet_w32
class Encoder(nn.Module):
    """Image feature extractor wrapping either a torchvision Swin-B or an HRNet-W32 backbone."""

    def __init__(self, encoder='hrnet', pretrained=True):
        """
        Args:
            encoder: backbone identifier, 'swin' or 'hrnet'.
            pretrained: forwarded to the HRNet constructor; the Swin branch
                always loads torchvision's default weights regardless.

        Raises:
            NotImplementedError: for any other ``encoder`` value.
        """
        super(Encoder, self).__init__()
        if encoder == 'swin':
            # Swin Transformer backbone; the classification head is replaced by
            # a GELU so the module emits features rather than class logits.
            self.encoder = torchvision.models.swin_b(weights='DEFAULT')
            self.encoder.head = nn.GELU()
        elif encoder == 'hrnet':
            # HRNet-W32 backbone (project-local implementation from utils.hrnet).
            self.encoder = hrnet_w32(pretrained=pretrained)
        else:
            raise NotImplementedError('Encoder not implemented')

    def forward(self, x):
        """Run the selected backbone on a batch of images and return its output."""
        return self.encoder(x)
class Self_Attn(nn.Module):
    """Single-head attention over 3-D feature tensors.

    Query/key/value are each projected with a kernel-size-1 Conv1d, attention
    weights come from a softmax over the last axis, and the attended values
    are rescaled by 1/sqrt(in_dim) at the end (after the softmax, not on the
    logits as in standard scaled dot-product attention).

    NOTE(review): each projection runs over the *last* input axis (it is moved
    into the Conv1d channel slot), so the final axis of every input must equal
    ``in_dim``; the closing ``view`` additionally requires ``out_dim == in_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        # Kernel-size-1 convolutions acting as per-position linear projections.
        self.query_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend over ``v`` using queries ``q`` and keys ``k``.

        Args:
            q, k, v: tensors of shape (B, C, D) whose last axis D equals in_dim.

        Returns:
            Tensor of shape (B, C, D): softmax(QK^T) V, divided by sqrt(in_dim).
        """
        n_batch, dim1, dim2 = q.size()
        # Move the last axis into the channel slot before each projection.
        queries = self.query_conv(q.permute(0, 2, 1))    # (B, out_dim, C)
        keys = self.key_conv(k.permute(0, 2, 1))         # (B, out_dim, C)
        # Pairwise similarity of query rows and key rows: (B, out_dim, out_dim).
        scores = torch.bmm(queries, keys.permute(0, 2, 1))
        weights = self.softmax(scores)
        values = self.value_conv(v.permute(0, 2, 1))     # (B, out_dim, C)
        attended = torch.bmm(weights, values)            # (B, out_dim, C)
        # Restore the input layout, then apply the 1/sqrt(in_dim) scaling.
        attended = attended.view(n_batch, dim1, dim2)
        return attended / np.sqrt(self.channel_in)
class Cross_Att(nn.Module):
    """Bidirectional cross-attention fusion of two feature streams.

    One attention block attends the ``sem_seg`` stream to the ``part_seg``
    stream and a second block attends in the opposite direction; the two
    results are fused by elementwise product and layer-normalized.
    """

    def __init__(self, in_dim, out_dim):
        super(Cross_Att, self).__init__()
        self.cross_attn_1 = Self_Attn(in_dim, out_dim)
        self.cross_attn_2 = Self_Attn(in_dim, out_dim)
        # Normalizes over a trailing (1, in_dim) shape, so inputs are expected
        # to be (B, 1, in_dim) tensors — TODO confirm against callers.
        self.layer_norm = nn.LayerNorm([1, in_dim])

    def forward(self, sem_seg, part_seg):
        """Fuse the two streams; returns a layer-normalized (B, 1, in_dim) tensor."""
        cross1 = self.cross_attn_1(sem_seg, part_seg, part_seg)
        # BUGFIX: this direction previously reused self.cross_attn_1, leaving
        # self.cross_attn_2 as dead (never-exercised) parameters. NOTE(review):
        # checkpoints trained with the old forward will produce different
        # outputs after this fix — verify against the released weights.
        cross2 = self.cross_attn_2(part_seg, sem_seg, sem_seg)
        out = cross1 * cross2
        out = self.layer_norm(out)
        return out
class Decoder(nn.Module):
def __init__(self, in_dim, out_dim, encoder='hrnet'):
super(Decoder, self).__init__()
self.out_dim = out_dim
if encoder == 'swin':
self.upsample = nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(out_dim),
nn.ReLU(),
nn.ConvTranspose2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(out_dim),
nn.ReLU(),
nn.ConvTranspose2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(out_dim),
nn.Softmax(1)
)
elif encoder == 'hrnet':
self.upsample = nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(out_dim),
nn.ReLU(),
nn.ConvTranspose2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(out_dim),
# nn.ReLU(),
# nn.ConvTranspose2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
# nn.BatchNorm2d(out_dim),
nn.Softmax(1)
)
else:
raise NotImplementedError('Decoder not implemented')
def forward(self, x):
out = self.upsample(x)
return out
class Classifier(nn.Module):
    """Two-layer MLP head mapping per-sample features to ``out_dim`` sigmoid scores.

    The default ``out_dim`` of 6890 presumably corresponds to the SMPL mesh
    vertex count (per-vertex scores) — TODO confirm against callers. Output is
    flattened to shape (-1, out_dim).
    """

    def __init__(self, in_dim, out_dim=6890):
        super(Classifier, self).__init__()
        self.out_dim = out_dim
        # Fixed hidden width of 4096; sigmoid keeps every score in [0, 1].
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 4096, True),
            nn.ReLU(),
            nn.Linear(4096, out_dim, True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x``; the result is reshaped to (-1, out_dim)."""
        scores = self.classifier(x)
        return scores.reshape(-1, self.out_dim)