EditGuard / models /bitnetwork /DW_EncoderDecoder.py
Ricoooo's picture
'folder'
5d21dd2
raw
history blame
1.16 kB
from . import *
from .Encoder_U import DW_Encoder
from .Decoder_U import DW_Decoder
from .Noise import Noise
from .Random_Noise import Random_Noise
class DW_EncoderDecoder(nn.Module):
    """Watermark encoder, noise stage and two message decoders chained together.

    ``forward`` embeds ``message`` into ``image`` with ``self.encoder``, then passes
    ``[encoded_image, image, mask]`` to ``self.noise`` (a ``Random_Noise``), which
    must return three outputs; they are unpacked as the C, R and F branches.
    ``self.decoder_C`` decodes the C branch, and a single shared
    ``self.decoder_RF`` decodes both the R and F branches.

    NOTE(review): the meaning of the R / F branches and the tensor shapes involved
    are defined by ``Random_Noise`` / ``DW_Encoder`` / ``DW_Decoder``, which are not
    visible here — confirm against those modules.
    """

    def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder):
        """Build the encoder, the noise stage and the two decoders.

        Args:
            message_length: forwarded to ``DW_Encoder`` and to both ``DW_Decoder``s.
            noise_layers_R: sequence of noise-layer specs for the R branch.
            noise_layers_F: sequence of noise-layer specs for the F branch.
            attention_encoder: passed as ``attention=`` to the encoder.
            attention_decoder: passed as ``attention=`` to both decoders.
        """
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # registration order (state_dict keys) and random init are unchanged.
        self.encoder = DW_Encoder(message_length, attention=attention_encoder)
        # Normalise to lists so mixed sequence types (e.g. list + tuple) can be
        # concatenated; R layers come first, then F layers, and Random_Noise is
        # told how many of each there are.
        noise_layers_R = list(noise_layers_R)
        noise_layers_F = list(noise_layers_F)
        self.noise = Random_Noise(noise_layers_R + noise_layers_F, len(noise_layers_R), len(noise_layers_F))
        self.decoder_C = DW_Decoder(message_length, attention=attention_decoder)
        # One decoder instance is reused for both the R and F branches in forward().
        self.decoder_RF = DW_Decoder(message_length, attention=attention_decoder)

    def forward(self, image, message, mask):
        """Encode, apply the noise stage, and decode every noised branch.

        Args:
            image: cover input given to the encoder and to the noise stage.
            message: payload embedded by the encoder.
            mask: extra input handed to the noise stage alongside the images.

        Returns:
            Tuple ``(encoded_image, noised_image_C, decoded_message_C,
            decoded_message_R, decoded_message_F)``. The R/F noised images
            themselves are not returned.
        """
        encoded_image = self.encoder(image, message)
        noised_image_C, noised_image_R, noised_image_F = self.noise([encoded_image, image, mask])
        decoded_message_C = self.decoder_C(noised_image_C)
        # Same decoder for R then F (order kept: it matters for stateful layers
        # such as batch-norm running statistics in training mode).
        decoded_message_R = self.decoder_RF(noised_image_R)
        decoded_message_F = self.decoder_RF(noised_image_F)
        return encoded_image, noised_image_C, decoded_message_C, decoded_message_R, decoded_message_F