|
from . import * |
|
from .Encoder_U import DW_Encoder |
|
from .Decoder_U import DW_Decoder |
|
from .Noise import Noise |
|
from .Random_Noise import Random_Noise |
|
|
|
|
|
class DW_EncoderDecoder(nn.Module):
    '''
    Encoder -> random noise -> two-decoder pipeline.

    Bundles one DW_Encoder, one Random_Noise stage and two DW_Decoder
    instances (``decoder_C`` and ``decoder_RF``) into a single module.
    '''

    def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder):
        '''
        Create the encoder, the noise stage and both decoders.

        message_length    -- forwarded to the encoder and to both decoders.
        noise_layers_R    -- first group of noise layers.
        noise_layers_F    -- second group of noise layers; it is appended to
                             ``noise_layers_R`` with ``+`` and the combined
                             sequence is given to Random_Noise together with
                             the length of each group.
        attention_encoder -- passed as ``attention`` to DW_Encoder.
        attention_decoder -- passed as ``attention`` to both DW_Decoders.
        '''
        super().__init__()

        # Sub-modules are assigned in a fixed order (encoder, noise,
        # decoder_C, decoder_RF) so registration order stays stable.
        self.encoder = DW_Encoder(message_length, attention=attention_encoder)

        # R layers come first, F layers second; Random_Noise also gets
        # the size of each group.
        combined_layers = noise_layers_R + noise_layers_F
        self.noise = Random_Noise(
            combined_layers,
            len(noise_layers_R),
            len(noise_layers_F),
        )

        self.decoder_C = DW_Decoder(message_length, attention=attention_decoder)
        self.decoder_RF = DW_Decoder(message_length, attention=attention_decoder)

    def forward(self, image, message, mask):
        '''
        Run encode -> noise -> decode.

        Returns the 5-tuple ``(encoded_image, noised_image_C,
        decoded_message_C, decoded_message_R, decoded_message_F)``.
        The R and F noised images are decoded but not returned.
        '''
        encoded_image = self.encoder(image, message)

        # The noise stage takes one list argument and yields three images.
        noise_inputs = [encoded_image, image, mask]
        noised_C, noised_R, noised_F = self.noise(noise_inputs)

        # decoder_C handles the first output; the same decoder_RF instance
        # handles both the second and the third.
        message_C = self.decoder_C(noised_C)
        message_R = self.decoder_RF(noised_R)
        message_F = self.decoder_RF(noised_F)

        return encoded_image, noised_C, message_C, message_R, message_F
|
|
|
|