import torch
from torch import nn


class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map."""

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 convolutions project features into query/key/value spaces;
        # query and key use a reduced (in_channels // 8) dimension to keep
        # the attention matrix cheap to compute.
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        # Learnable residual gate, initialized to zero so the block starts
        # out as an identity mapping and learns how much attention to mix in.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Flatten the spatial grid: q is (B, H*W, C//8), k is (B, C//8, H*W).
        q = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        k = self.key(x).view(batch_size, -1, H * W)
        v = self.value(x).view(batch_size, -1, H * W)
        # (B, H*W, H*W) attention weights, normalized over key positions.
        attention = torch.softmax(torch.bmm(q, k), dim=-1)
        # Weighted sum of the values, then restore the (B, C, H, W) layout.
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)


class aeModel(nn.Module):
    """Convolutional autoencoder. Four stride-2 encoder stages take a
    3-channel input down to a 256-channel bottleneck at 1/16 resolution;
    the decoder mirrors them with transposed convolutions. Self-attention
    is applied only at the two deepest (and therefore cheapest) scales."""

    def __init__(self):
        super().__init__()
        # Each encoder stage halves the spatial resolution and widens the
        # channels: 3 -> 32 -> 64 -> 128 -> 256.
        self.encoder = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 32, 3, stride=2, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ResidualBlock(32),
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, 3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ResidualBlock(64),
            ),
            nn.Sequential(
                nn.Conv2d(64, 128, 3, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                ResidualBlock(128),
                SelfAttention(128),
            ),
            nn.Sequential(
                nn.Conv2d(128, 256, 3, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                ResidualBlock(256),
                SelfAttention(256),
            ),
        ])
        # Each decoder stage doubles the resolution; output_padding=1 makes
        # the stride-2 transposed conv exactly invert the encoder's
        # downsampling. The final Sigmoid maps reconstructions into [0, 1].
        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                ResidualBlock(128),
                SelfAttention(128),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ResidualBlock(64),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ResidualBlock(32),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1),
                nn.Sigmoid(),
            ),
        ])

    def forward(self, x):
        return self.decode(self.encode(x))

    def encode(self, x):
        for encoder_block in self.encoder:
            x = encoder_block(x)
        return x

    def decode(self, x):
        for decoder_block in self.decoder:
            x = decoder_block(x)
        return x
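

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). Assumes
# 3-channel inputs whose height and width are divisible by 16, since the
# encoder downsamples by 2 four times; 64x64 is an arbitrary example size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = aeModel()
    model.eval()  # avoid BatchNorm batch-statistics noise in this smoke test
    x = torch.rand(2, 3, 64, 64)  # batch of two 64x64 RGB images in [0, 1]
    with torch.no_grad():
        latent = model.encode(x)  # (2, 256, 4, 4) bottleneck features
        recon = model(x)          # (2, 3, 64, 64) reconstruction in [0, 1]
    assert recon.shape == x.shape
    print("latent:", tuple(latent.shape), "reconstruction:", tuple(recon.shape))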