from . import *
from .Encoder_U import DW_Encoder
from .Decoder_U import DW_Decoder
from .Noise import Noise
from .Random_Noise import Random_Noise


class DW_EncoderDecoder(nn.Module):
	'''
	Encoder -> random noise -> two decoders, wired together as one module.

	The encoder embeds ``message`` into ``image``.  The noise module then
	produces three noised variants (suffixed C, R and F) which are decoded
	by ``decoder_C`` (variant C) and by the shared ``decoder_RF`` (variants
	R and F).

	NOTE(review): the exact meaning of the C / R / F suffixes is defined by
	``Random_Noise`` and is not visible from this file -- confirm there.
	'''

	def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder):
		'''
		Build the four sub-modules.

		:param message_length: forwarded to the encoder and to both decoders.
		:param noise_layers_R: list of noise-layer specs; concatenated in
			front of ``noise_layers_F`` and its length passed to ``Random_Noise``.
		:param noise_layers_F: list of noise-layer specs; appended after
			``noise_layers_R`` and its length passed to ``Random_Noise``.
		:param attention_encoder: forwarded as ``attention`` to ``DW_Encoder``.
		:param attention_decoder: forwarded as ``attention`` to both ``DW_Decoder`` instances.
		'''
		super().__init__()

		# Sub-modules are created in a fixed order (encoder, noise, decoder_C,
		# decoder_RF) so parameter registration order stays stable.
		self.encoder = DW_Encoder(message_length, attention=attention_encoder)

		# A single flat list holds both groups; the two counts tell
		# Random_Noise how many entries belong to each group.
		all_noise_layers = noise_layers_R + noise_layers_F
		count_R = len(noise_layers_R)
		count_F = len(noise_layers_F)
		self.noise = Random_Noise(all_noise_layers, count_R, count_F)

		self.decoder_C = DW_Decoder(message_length, attention=attention_decoder)
		self.decoder_RF = DW_Decoder(message_length, attention=attention_decoder)

	def forward(self, image, message, mask):
		'''
		Run the full encode -> noise -> decode pipeline.

		:param image: input passed to the encoder and, unmodified, to the noise module.
		:param message: message passed to the encoder together with ``image``.
		:param mask: passed through to the noise module only.
		:return: tuple ``(encoded_image, noised_image_C, decoded_message_C,
			decoded_message_R, decoded_message_F)``.  The R and F noised
			images are used internally and not returned.
		'''
		encoded_image = self.encoder(image, message)

		# The noise module takes one list argument and yields three images.
		noise_inputs = [encoded_image, image, mask]
		noised_image_C, noised_image_R, noised_image_F = self.noise(noise_inputs)

		# decoder_C handles the C variant; decoder_RF is shared by R and F
		# (called on R first, then F).
		decoded_message_C = self.decoder_C(noised_image_C)
		decoded_message_R = self.decoder_RF(noised_image_R)
		decoded_message_F = self.decoder_RF(noised_image_F)

		return (
			encoded_image,
			noised_image_C,
			decoded_message_C,
			decoded_message_R,
			decoded_message_F,
		)