|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .module_util import initialize_weights_xavier |
|
from torch.nn import init |
|
from .common import DWT,IWT |
|
import cv2 |
|
from basicsr.archs.arch_util import flow_warp |
|
from models.modules.Subnet_constructor import subnet |
|
|
|
|
from pdb import set_trace as stx |
|
import numbers |
|
|
|
from einops import rearrange |
|
from models.bitnetwork.Encoder_U import DW_Encoder |
|
from models.bitnetwork.Decoder_U import DW_Decoder |
|
|
|
|
|
|
|
def to_3d(x): |
|
return rearrange(x, 'b c h w -> b (h w) c') |
|
|
|
|
|
def to_4d(x, h, w): |
|
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) |
|
|
|
|
|
class BiasFree_LayerNorm(nn.Module): |
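    """LayerNorm over the last dimension without mean subtraction or bias;
    only a learnable per-channel scale (Restormer-style bias-free variant)."""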
|
def __init__(self, normalized_shape): |
|
super(BiasFree_LayerNorm, self).__init__() |
|
if isinstance(normalized_shape, numbers.Integral): |
|
normalized_shape = (normalized_shape,) |
|
normalized_shape = torch.Size(normalized_shape) |
|
|
|
assert len(normalized_shape) == 1 |
|
|
|
self.weight = nn.Parameter(torch.ones(normalized_shape)) |
|
self.normalized_shape = normalized_shape |
|
|
|
def forward(self, x): |
|
sigma = x.var(-1, keepdim=True, unbiased=False) |
|
return x / torch.sqrt(sigma + 1e-5) * self.weight |
|
|
|
|
|
class WithBias_LayerNorm(nn.Module): |
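    """Standard LayerNorm over the last dimension with learnable scale and bias."""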
|
def __init__(self, normalized_shape): |
|
super(WithBias_LayerNorm, self).__init__() |
|
if isinstance(normalized_shape, numbers.Integral): |
|
normalized_shape = (normalized_shape,) |
|
normalized_shape = torch.Size(normalized_shape) |
|
|
|
assert len(normalized_shape) == 1 |
|
|
|
self.weight = nn.Parameter(torch.ones(normalized_shape)) |
|
self.bias = nn.Parameter(torch.zeros(normalized_shape)) |
|
self.normalized_shape = normalized_shape |
|
|
|
def forward(self, x): |
|
mu = x.mean(-1, keepdim=True) |
|
sigma = x.var(-1, keepdim=True, unbiased=False) |
|
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias |
|
|
|
|
|
class LayerNorm(nn.Module): |
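    """Reshapes NCHW features to (B, HW, C) tokens, normalizes over channels,
    and reshapes back. Any LayerNorm_type other than 'BiasFree' selects the
    with-bias variant."""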
|
def __init__(self, dim, LayerNorm_type): |
|
super(LayerNorm, self).__init__() |
|
if LayerNorm_type == 'BiasFree': |
|
self.body = BiasFree_LayerNorm(dim) |
|
else: |
|
self.body = WithBias_LayerNorm(dim) |
|
|
|
def forward(self, x): |
|
h, w = x.shape[-2:] |
|
return to_4d(self.body(to_3d(x)), h, w) |
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module): |
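    """Gated depthwise-conv feed-forward (Restormer-style GDFN): 1x1 expansion
    into two branches, 3x3 depthwise conv, GELU-gated product, 1x1 projection."""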
|
def __init__(self, dim, ffn_expansion_factor, bias): |
|
super(FeedForward, self).__init__() |
|
|
|
hidden_features = int(dim * ffn_expansion_factor) |
|
|
|
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) |
|
|
|
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, |
|
groups=hidden_features * 2, bias=bias) |
|
|
|
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) |
|
|
|
def forward(self, x): |
|
x = self.project_in(x) |
|
x1, x2 = self.dwconv(x).chunk(2, dim=1) |
|
x = F.gelu(x1) * x2 |
|
x = self.project_out(x) |
|
return x |
|
|
|
|
|
|
|
|
|
class Attention(nn.Module): |
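    """Multi-head transposed self-attention (Restormer-style MDTA): attention is
    computed across the channel dimension with L2-normalized queries/keys and a
    learnable per-head temperature."""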
|
def __init__(self, dim, num_heads, bias): |
|
super(Attention, self).__init__() |
|
self.num_heads = num_heads |
|
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) |
|
|
|
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) |
|
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) |
|
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) |
|
|
|
def forward(self, x): |
|
b, c, h, w = x.shape |
|
|
|
qkv = self.qkv_dwconv(self.qkv(x)) |
|
q, k, v = qkv.chunk(3, dim=1) |
|
|
|
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) |
|
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) |
|
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) |
|
|
|
q = torch.nn.functional.normalize(q, dim=-1) |
|
k = torch.nn.functional.normalize(k, dim=-1) |
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.temperature |
|
attn = attn.softmax(dim=-1) |
|
|
|
out = (attn @ v) |
|
|
|
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) |
|
|
|
out = self.project_out(out) |
|
return out |
|
|
|
|
|
|
|
class TransformerBlock(nn.Module): |
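    """Pre-norm transformer block: channel attention followed by the gated
    feed-forward network, each wrapped in a residual connection."""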
|
def __init__(self, dim, num_heads=4, ffn_expansion_factor=4, bias=False, LayerNorm_type="withbias"): |
|
super(TransformerBlock, self).__init__() |
|
|
|
self.norm1 = LayerNorm(dim, LayerNorm_type) |
|
self.attn = Attention(dim, num_heads, bias) |
|
self.norm2 = LayerNorm(dim, LayerNorm_type) |
|
self.ffn = FeedForward(dim, ffn_expansion_factor, bias) |
|
|
|
def forward(self, x): |
|
x = x + self.attn(self.norm1(x)) |
|
x = x + self.ffn(self.norm2(x)) |
|
|
|
return x |
|
|
|
dwt = DWT()
iwt = IWT()
|
|
|
class LayerNormFunction(torch.autograd.Function): |
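    """LayerNorm over the channel dimension of NCHW tensors with a hand-written
    backward pass (NAFNet-style)."""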
|
|
|
@staticmethod |
|
def forward(ctx, x, weight, bias, eps): |
|
ctx.eps = eps |
|
N, C, H, W = x.size() |
|
mu = x.mean(1, keepdim=True) |
|
var = (x - mu).pow(2).mean(1, keepdim=True) |
|
y = (x - mu) / (var + eps).sqrt() |
|
ctx.save_for_backward(y, var, weight) |
|
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) |
|
return y |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
eps = ctx.eps |
|
|
|
N, C, H, W = grad_output.size() |
|
        y, var, weight = ctx.saved_tensors
|
g = grad_output * weight.view(1, C, 1, 1) |
|
mean_g = g.mean(dim=1, keepdim=True) |
|
|
|
mean_gy = (g * y).mean(dim=1, keepdim=True) |
|
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) |
|
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( |
|
dim=0), None |
|
|
|
class LayerNorm2d(nn.Module): |
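    """Channel-wise LayerNorm for NCHW tensors built on LayerNormFunction."""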
|
|
|
def __init__(self, channels, eps=1e-6): |
|
super(LayerNorm2d, self).__init__() |
|
self.register_parameter('weight', nn.Parameter(torch.ones(channels))) |
|
self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) |
|
self.eps = eps |
|
|
|
def forward(self, x): |
|
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) |
|
|
|
class SimpleGate(nn.Module): |
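    """Splits the channels into two halves and multiplies them element-wise."""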
|
def forward(self, x): |
|
x1, x2 = x.chunk(2, dim=1) |
|
return x1 * x2 |
|
|
|
class NAFBlock(nn.Module): |
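    """NAFNet block: LayerNorm -> 1x1 expand -> 3x3 depthwise -> SimpleGate ->
    simplified channel attention -> 1x1, plus a gated FFN branch; both branches
    use residual connections scaled by the learnable beta/gamma."""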
|
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): |
|
super().__init__() |
|
dw_channel = c * DW_Expand |
|
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) |
|
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, |
|
bias=True) |
|
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) |
|
|
|
|
|
self.sca = nn.Sequential( |
|
nn.AdaptiveAvgPool2d(1), |
|
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, |
|
groups=1, bias=True), |
|
) |
|
|
|
|
|
self.sg = SimpleGate() |
|
|
|
ffn_channel = FFN_Expand * c |
|
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) |
|
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) |
|
|
|
self.norm1 = LayerNorm2d(c) |
|
self.norm2 = LayerNorm2d(c) |
|
|
|
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() |
|
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() |
|
|
|
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) |
|
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) |
|
|
|
def forward(self, inp): |
|
x = inp |
|
|
|
x = self.norm1(x) |
|
|
|
x = self.conv1(x) |
|
x = self.conv2(x) |
|
x = self.sg(x) |
|
x = x * self.sca(x) |
|
x = self.conv3(x) |
|
|
|
x = self.dropout1(x) |
|
|
|
y = inp + x * self.beta |
|
|
|
x = self.conv4(self.norm2(y)) |
|
x = self.sg(x) |
|
x = self.conv5(x) |
|
|
|
x = self.dropout2(x) |
|
|
|
return y + x * self.gamma |
|
|
|
def thops_mean(tensor, dim=None, keepdim=False): |
|
if dim is None: |
|
|
|
return torch.mean(tensor) |
|
else: |
|
if isinstance(dim, int): |
|
dim = [dim] |
|
dim = sorted(dim) |
|
for d in dim: |
|
tensor = tensor.mean(dim=d, keepdim=True) |
|
if not keepdim: |
|
for i, d in enumerate(dim): |
|
tensor.squeeze_(d-i) |
|
return tensor |
|
|
|
|
|
class ResidualBlockNoBN(nn.Module): |
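    """Two 3x3 convs with a residual connection and no batch norm. ReLU is used
    for 'LSTM-VRN' and LeakyReLU(0.2) for 'MIMO-VRN'; weights are
    Xavier-initialized and scaled by 0.1."""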
|
def __init__(self, nf=64, model='MIMO-VRN'): |
|
super(ResidualBlockNoBN, self).__init__() |
|
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
|
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
|
|
|
|
|
        if model == 'LSTM-VRN':
            self.relu = nn.ReLU(inplace=True)
        elif model == 'MIMO-VRN':
            self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise ValueError(f'unsupported model type: {model}')
|
|
|
|
|
initialize_weights_xavier([self.conv1, self.conv2], 0.1) |
|
|
|
def forward(self, x): |
|
identity = x |
|
out = self.relu(self.conv1(x)) |
|
out = self.conv2(out) |
|
return identity + out |
|
|
|
|
|
class InvBlock(nn.Module): |
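    """Invertible affine coupling block: the host branch x1 is updated additively
    from x2, then each tensor in the hidden branch x2 (a list) is scaled by
    exp(s) and shifted, where s is predicted from the updated x1. The reverse
    pass inverts both steps exactly."""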
|
def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.): |
|
super(InvBlock, self).__init__() |
|
self.split_len1 = channel_num_ho |
|
self.split_len2 = channel_num_hi |
|
self.clamp = clamp |
|
|
|
self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups) |
|
self.NF = NAFBlock(self.split_len2) |
|
if groups == 1: |
|
self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups) |
|
self.NG = NAFBlock(self.split_len1) |
|
self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups) |
|
self.NH = NAFBlock(self.split_len1) |
|
else: |
|
self.G = subnet_constructor(self.split_len1, self.split_len2) |
|
self.NG = NAFBlock(self.split_len1) |
|
self.H = subnet_constructor(self.split_len1, self.split_len2) |
|
self.NH = NAFBlock(self.split_len1) |
|
|
|
def forward(self, x1, x2, rev=False): |
|
if not rev: |
|
y1 = x1 + self.NF(self.F(x2)) |
|
self.s = self.clamp * (torch.sigmoid(self.NH(self.H(y1))) * 2 - 1) |
|
y2 = [x2i.mul(torch.exp(self.s)) + self.NG(self.G(y1)) for x2i in x2] |
|
else: |
|
self.s = self.clamp * (torch.sigmoid(self.NH(self.H(x1))) * 2 - 1) |
|
y2 = [(x2i - self.NG(self.G(x1))).div(torch.exp(self.s)) for x2i in x2] |
|
y1 = x1 - self.NF(self.F(y2)) |
|
|
|
return y1, y2 |
|
|
|
def jacobian(self, x, rev=False): |
|
if not rev: |
|
jac = torch.sum(self.s) |
|
else: |
|
jac = -torch.sum(self.s) |
|
|
|
return jac / x.shape[0] |
|
|
|
class InvNN(nn.Module): |
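    """Stack of InvBlock coupling layers shared by the forward (hiding) and
    reverse (recovery) passes; can optionally accumulate the Jacobian
    log-determinant."""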
|
def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None): |
|
super(InvNN, self).__init__() |
|
operations = [] |
|
|
|
current_channel_ho = channel_in_ho |
|
current_channel_hi = channel_in_hi |
|
for i in range(down_num): |
|
for j in range(block_num[i]): |
|
b = InvBlock(subnet_constructor, subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups) |
|
operations.append(b) |
|
|
|
self.operations = nn.ModuleList(operations) |
|
|
|
def forward(self, x, x_h, rev=False, cal_jacobian=False): |
|
|
|
jacobian = 0 |
|
|
|
if not rev: |
|
for op in self.operations: |
|
x, x_h = op.forward(x, x_h, rev) |
|
if cal_jacobian: |
|
jacobian += op.jacobian(x, rev) |
|
else: |
|
for op in reversed(self.operations): |
|
x, x_h = op.forward(x, x_h, rev) |
|
if cal_jacobian: |
|
jacobian += op.jacobian(x, rev) |
|
|
|
if cal_jacobian: |
|
return x, x_h, jacobian |
|
else: |
|
return x, x_h |
|
|
|
class PredictiveModuleMIMO(nn.Module): |
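    """Feature predictor: conv stem, residual blocks, then transformer blocks
    with a long skip connection. Used in the reverse pass to estimate the branch
    discarded during embedding."""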
|
def __init__(self, channel_in, nf, block_num_rbm=8, block_num_trans=4): |
|
super(PredictiveModuleMIMO, self).__init__() |
|
self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) |
|
res_block = [] |
|
trans_block = [] |
|
for i in range(block_num_rbm): |
|
res_block.append(ResidualBlockNoBN(nf)) |
|
for j in range(block_num_trans): |
|
trans_block.append(TransformerBlock(nf)) |
|
|
|
self.res_block = nn.Sequential(*res_block) |
|
self.transformer_block = nn.Sequential(*trans_block) |
|
|
|
def forward(self, x): |
|
x = self.conv_in(x) |
|
x = self.res_block(x) |
|
res = self.transformer_block(x) + x |
|
|
|
return res |
|
|
|
class ConvRelu(nn.Module): |
|
def __init__(self, channels_in, channels_out, stride=1, init_zero=False): |
|
super(ConvRelu, self).__init__() |
|
self.init_zero = init_zero |
|
if self.init_zero: |
|
self.layers = nn.Conv2d(channels_in, channels_out, 3, stride, padding=1) |
|
|
|
else: |
|
self.layers = nn.Sequential( |
|
nn.Conv2d(channels_in, channels_out, 3, stride, padding=1), |
|
nn.LeakyReLU(inplace=True) |
|
) |
|
|
|
def forward(self, x): |
|
return self.layers(x) |
|
|
|
class PredictiveModuleBit(nn.Module): |
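    """Variant of the predictive module that also reduces the features to a
    single-channel map via four stride-2 ConvRelu layers (1/16 resolution)."""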
|
def __init__(self, channel_in, nf, block_num_rbm=4, block_num_trans=2): |
|
super(PredictiveModuleBit, self).__init__() |
|
self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) |
|
res_block = [] |
|
trans_block = [] |
|
for i in range(block_num_rbm): |
|
res_block.append(ResidualBlockNoBN(nf)) |
|
for j in range(block_num_trans): |
|
trans_block.append(TransformerBlock(nf)) |
|
|
|
blocks = 4 |
|
layers = [ConvRelu(nf, 1, 2)] |
|
for _ in range(blocks - 1): |
|
layer = ConvRelu(1, 1, 2) |
|
layers.append(layer) |
|
self.layers = nn.Sequential(*layers) |
|
|
|
self.res_block = nn.Sequential(*res_block) |
|
self.transformer_block = nn.Sequential(*trans_block) |
|
|
|
def forward(self, x): |
|
x = self.conv_in(x) |
|
x = self.res_block(x) |
|
res = self.transformer_block(x) + x |
|
res = self.layers(res) |
|
|
|
return res |
|
|
|
|
|
|
|
class PromptGenBlock(nn.Module): |
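    """Input-conditioned prompt generator (PromptIR-style): softmax weights from
    a global-average-pooled embedding select among learned prompt tensors, which
    are resized to the input resolution and refined by a 3x3 conv."""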
|
    def __init__(self, prompt_dim=12, prompt_len=3, prompt_size=36, lin_dim=12):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)
|
|
|
|
|
    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1, 1, 1, 1, 1).squeeze(1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)
        return prompt
|
|
|
class PredictiveModuleMIMO_prompt(nn.Module): |
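    """PredictiveModuleMIMO augmented with a PromptGenBlock: the generated prompt
    is concatenated with the features and fused by a 3x3 conv."""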
|
def __init__(self, channel_in, nf, prompt_len=3, block_num_rbm=8, block_num_trans=4): |
|
super(PredictiveModuleMIMO_prompt, self).__init__() |
|
self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) |
|
res_block = [] |
|
trans_block = [] |
|
for i in range(block_num_rbm): |
|
res_block.append(ResidualBlockNoBN(nf)) |
|
for j in range(block_num_trans): |
|
trans_block.append(TransformerBlock(nf)) |
|
|
|
self.res_block = nn.Sequential(*res_block) |
|
self.transformer_block = nn.Sequential(*trans_block) |
|
        self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_len, prompt_size=36, lin_dim=nf)
|
self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) |
|
|
|
def forward(self, x): |
|
x = self.conv_in(x) |
|
x = self.res_block(x) |
|
res = self.transformer_block(x) + x |
|
prompt = self.prompt(res) |
|
|
|
result = self.fuse(torch.cat([res, prompt], dim=1)) |
|
|
|
return result |
|
|
|
def gauss_noise(shape): |
|
noise = torch.zeros(shape).cuda() |
|
for i in range(noise.shape[0]): |
|
noise[i] = torch.randn(noise[i].shape).cuda() |
|
|
|
return noise |
|
|
|
def gauss_noise_mul(shape): |
|
noise = torch.randn(shape).cuda() |
|
|
|
return noise |
|
|
|
class PredictiveModuleBit_prompt(nn.Module): |
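    """PredictiveModuleBit augmented with a PromptGenBlock and a 3x3 fusion conv
    before the stride-2 reduction layers."""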
|
def __init__(self, channel_in, nf, prompt_length, block_num_rbm=4, block_num_trans=2): |
|
super(PredictiveModuleBit_prompt, self).__init__() |
|
self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) |
|
res_block = [] |
|
trans_block = [] |
|
for i in range(block_num_rbm): |
|
res_block.append(ResidualBlockNoBN(nf)) |
|
for j in range(block_num_trans): |
|
trans_block.append(TransformerBlock(nf)) |
|
|
|
blocks = 4 |
|
layers = [ConvRelu(nf, 1, 2)] |
|
for _ in range(blocks - 1): |
|
layer = ConvRelu(1, 1, 2) |
|
layers.append(layer) |
|
self.layers = nn.Sequential(*layers) |
|
|
|
self.res_block = nn.Sequential(*res_block) |
|
self.transformer_block = nn.Sequential(*trans_block) |
|
        self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_length, prompt_size=36, lin_dim=nf)
|
self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) |
|
|
|
def forward(self, x): |
|
x = self.conv_in(x) |
|
x = self.res_block(x) |
|
res = self.transformer_block(x) + x |
|
prompt = self.prompt(res) |
|
res = self.fuse(torch.cat([res, prompt], dim=1)) |
|
res = self.layers(res) |
|
|
|
return res |
|
|
|
class VSN(nn.Module): |
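    """Top-level network: combines a bit-message encoder/decoder
    (DW_Encoder/DW_Decoder), an invertible network (InvNN) operating on DWT
    coefficients, and predictive modules that estimate the discarded branch
    during recovery."""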
|
def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2): |
|
super(VSN, self).__init__() |
|
self.model = opt['model'] |
|
self.mode = opt['mode'] |
|
opt_net = opt['network_G'] |
|
self.num_image = opt['num_image'] |
|
self.gop = opt['gop'] |
|
self.channel_in = opt_net['in_nc'] * self.gop |
|
self.channel_out = opt_net['out_nc'] * self.gop |
|
self.channel_in_hi = opt_net['in_nc'] * self.gop |
|
self.channel_in_ho = opt_net['in_nc'] * self.gop |
|
self.message_len = opt['message_length'] |
|
|
|
self.block_num = opt_net['block_num'] |
|
self.block_num_rbm = opt_net['block_num_rbm'] |
|
self.block_num_trans = opt_net['block_num_trans'] |
|
self.nf = self.channel_in_hi |
|
|
|
        self.bitencoder = DW_Encoder(self.message_len, attention="se")
        self.bitdecoder = DW_Decoder(self.message_len, attention="se")
|
self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_image) |
|
|
|
        if opt['prompt']:
            self.pm = PredictiveModuleMIMO_prompt(self.channel_in_ho, self.nf * self.num_image, opt['prompt_len'], block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
        else:
            self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf * self.num_image, block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
|
self.BitPM = PredictiveModuleBit(3, 4, block_num_rbm=4, block_num_trans=2) |
|
|
|
|
|
def forward(self, x, x_h=None, message=None, rev=False, hs=[], direction='f'): |
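        """Embedding pass when rev is False, extraction pass when rev is True.

        mode 'image', forward: run (x, x_h) through the invertible network, apply
        the inverse DWT, then embed `message` with the bit encoder.
        mode 'image', reverse: decode the message from x, take its DWT, predict
        the missing branch with self.pm, and run the invertible network in
        reverse. In 'bit' mode the invertible network is skipped and only the
        message is embedded/decoded. `hs` and `direction` are currently unused.
        """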
|
if not rev: |
|
if self.mode == "image": |
|
out_y, out_y_h = self.irn(x, x_h, rev) |
|
out_y = iwt(out_y) |
|
encoded_image = self.bitencoder(out_y, message) |
|
return out_y, encoded_image |
|
|
|
elif self.mode == "bit": |
|
out_y = iwt(x) |
|
encoded_image = self.bitencoder(out_y, message) |
|
return out_y, encoded_image |
|
|
|
else: |
|
if self.mode == "image": |
|
recmessage = self.bitdecoder(x) |
|
|
|
x = dwt(x) |
|
out_z = self.pm(x).unsqueeze(1) |
|
out_z_new = out_z.view(-1, self.num_image, self.channel_in, x.shape[-2], x.shape[-1]) |
|
out_z_new = [out_z_new[:,i] for i in range(self.num_image)] |
|
out_x, out_x_h = self.irn(x, out_z_new, rev) |
|
|
|
return out_x, out_x_h, out_z, recmessage |
|
|
|
elif self.mode == "bit": |
|
recmessage = self.bitdecoder(x) |
|
return recmessage |
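

# Minimal shape-check sketch: both blocks below preserve the NCHW shape of their
# input (assumes this module's own dependencies are importable in the package).
if __name__ == '__main__':
    feats = torch.randn(1, 12, 36, 36)
    assert TransformerBlock(dim=12)(feats).shape == feats.shape
    assert NAFBlock(12)(feats).shape == feats.shape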
|
|
|
|