# Page header captured with this file: Spaces: Sleeping.
# File size: 4,814 Bytes; 99a05f0.
import torch
import torchvision
import torch.nn as nn
import numpy as np
from utils.hrnet import hrnet_w32
class Encoder(nn.Module):
    """Backbone wrapper that selects a Swin-B or HRNet-W32 encoder by name.

    Args:
        encoder: Backbone to build, ``'swin'`` or ``'hrnet'``.
        pretrained: Whether the backbone is created with pretrained weights.
            Applies to both backbones.

    Raises:
        NotImplementedError: If ``encoder`` is not one of the supported names.
    """

    def __init__(self, encoder='hrnet', pretrained=True):
        super(Encoder, self).__init__()
        if encoder == 'swin':
            # Swin Transformer encoder. `weights=None` builds the network
            # without loading a checkpoint, so `pretrained=False` is honoured
            # here just as it is for HRNet.
            weights = 'DEFAULT' if pretrained else None
            self.encoder = torchvision.models.swin_b(weights=weights)
            # Replace the model's `head` module with a GELU activation.
            self.encoder.head = nn.GELU()
        elif encoder == 'hrnet':
            # HRNet encoder.
            self.encoder = hrnet_w32(pretrained=pretrained)
        else:
            raise NotImplementedError('Encoder not implemented')

    def forward(self, x):
        """Run ``x`` through the selected backbone and return its output."""
        return self.encoder(x)
class Self_Attn(nn.Module):
    """Attention layer built from three kernel-size-1 ``Conv1d`` projections.

    Each input is transposed from ``(B, L, in_dim)`` to ``(B, in_dim, L)``
    before projection, so the last axis of every input must equal ``in_dim``.
    The attention matrix is formed between the ``out_dim`` projected channels
    (``B x out_dim x out_dim``), not between the ``L`` positions.

    Args:
        in_dim: Size of the last axis of the inputs.
        out_dim: Number of channels produced by each projection.
    """

    def __init__(self, in_dim, out_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k`` and apply the weights to ``v``.

        Args:
            q: Query tensor of shape ``(B, L, in_dim)``.
            k: Key tensor of shape ``(B, L, in_dim)``.
            v: Value tensor of shape ``(B, L, in_dim)``.

        Returns:
            Tensor of shape ``(B, L, out_dim)``: the attended values divided
            by ``sqrt(in_dim)``.
        """
        # Projections operate channel-first: (B, in_dim, L) -> (B, out_dim, L).
        proj_query = self.query_conv(q.permute(0, 2, 1))
        proj_key = self.key_conv(k.permute(0, 2, 1))
        proj_value = self.value_conv(v.permute(0, 2, 1))

        # energy / attention: (B, out_dim, out_dim), normalised over the last axis.
        energy = torch.bmm(proj_query, proj_key.permute(0, 2, 1))
        attention = self.softmax(energy)

        # (B, out_dim, out_dim) @ (B, out_dim, L) -> (B, out_dim, L)
        out = torch.bmm(attention, proj_value)

        # Return to the (B, L, out_dim) layout of the inputs with a real
        # transpose. A plain `view` would reinterpret memory and interleave
        # channels with positions whenever L > 1; for L == 1 both give the
        # same values.
        out = out.transpose(1, 2).contiguous()

        # The scale is applied to the output (after the softmax), not to the
        # energy; kept as-is so results for existing weights do not change.
        return out / float(np.sqrt(self.channel_in))
class Cross_Att(nn.Module):
    """Bidirectional cross-attention between two feature tensors.

    One attention module attends from ``sem_seg`` to ``part_seg`` and a
    second, separately parameterised module attends in the opposite
    direction. The two results are multiplied element-wise and passed through
    a ``LayerNorm`` over the trailing ``[1, in_dim]`` axes.

    Args:
        in_dim: Feature size of the inputs; also the normalised size.
        out_dim: Projection size handed to both attention modules.
    """

    def __init__(self, in_dim, out_dim):
        super(Cross_Att, self).__init__()
        self.cross_attn_1 = Self_Attn(in_dim, out_dim)
        self.cross_attn_2 = Self_Attn(in_dim, out_dim)
        self.layer_norm = nn.LayerNorm([1, in_dim])

    def forward(self, sem_seg, part_seg):
        """Fuse ``sem_seg`` and ``part_seg`` through attention in both directions.

        Args:
            sem_seg: First feature tensor; the layer norm above expects the
                attention outputs to end in ``[1, in_dim]``.
            part_seg: Second feature tensor with the same shape as ``sem_seg``.

        Returns:
            The layer-normalised element-wise product of both attention outputs.
        """
        # sem_seg queries part_seg.
        cross1 = self.cross_attn_1(sem_seg, part_seg, part_seg)
        # part_seg queries sem_seg, using its own module rather than reusing
        # cross_attn_1, which would leave cross_attn_2 as dead parameters.
        # NOTE(review): weights trained while both directions shared
        # cross_attn_1 will produce different outputs here - verify against
        # existing checkpoints.
        cross2 = self.cross_attn_2(part_seg, sem_seg, sem_seg)
        out = cross1 * cross2
        return self.layer_norm(out)
class Decoder(nn.Module):
    """Upsampling head made of stride-2 transposed convolutions.

    Every stage is ``ConvTranspose2d -> BatchNorm2d`` and doubles the spatial
    size; stages are separated by ``ReLU`` and the stack ends with a softmax
    over the channel axis. ``'swin'`` uses three stages (8x upsampling) and
    ``'hrnet'`` uses two (4x upsampling).

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        encoder: Name of the paired encoder, ``'swin'`` or ``'hrnet'``.

    Raises:
        NotImplementedError: If ``encoder`` is not a supported name.
    """

    def __init__(self, in_dim, out_dim, encoder='hrnet'):
        super(Decoder, self).__init__()
        self.out_dim = out_dim
        if encoder == 'swin':
            num_stages = 3
        elif encoder == 'hrnet':
            num_stages = 2
        else:
            raise NotImplementedError('Decoder not implemented')
        self.upsample = self._build_upsampler(in_dim, out_dim, num_stages)

    @staticmethod
    def _build_upsampler(in_dim, out_dim, num_stages):
        """Build ``num_stages`` 2x upsampling stages followed by a channel softmax."""
        layers = []
        channels = in_dim
        for stage in range(num_stages):
            if stage > 0:
                # Non-linearity between stages only; the last stage feeds the softmax.
                layers.append(nn.ReLU())
            layers.append(
                nn.ConvTranspose2d(channels, out_dim, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
            )
            layers.append(nn.BatchNorm2d(out_dim))
            # Only the first stage changes the channel count.
            channels = out_dim
        layers.append(nn.Softmax(1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` and return per-pixel channel probabilities."""
        return self.upsample(x)
class Classifier(nn.Module):
    """Two-layer MLP head producing per-output scores in ``[0, 1]``.

    The network is ``Linear -> ReLU -> Linear -> Sigmoid``.

    Args:
        in_dim: Size of the input feature vector.
        out_dim: Number of outputs per sample (default 6890).
        hidden_dim: Width of the hidden layer (default 4096).
    """

    def __init__(self, in_dim, out_dim=6890, hidden_dim=4096):
        super(Classifier, self).__init__()
        self.out_dim = out_dim
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` and flatten all leading axes into one.

        Args:
            x: Tensor whose last axis has size ``in_dim``.

        Returns:
            Tensor of shape ``(-1, out_dim)`` with values in ``[0, 1]``.
        """
        scores = self.classifier(x)
        return scores.reshape(-1, self.out_dim)