import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor


class _bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample,
                      padding=(fw // 2, fh // 2), padding_mode='zeros')
        )

    def forward(self, x):
        return self.model(x)

    # Debug variant of forward, kept as a comment: it applies the layers one
    # by one and prints per-layer activation statistics.
    # def forward(self, x):
    #     print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()),
    #           np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
    #     for i, layer in enumerate(self.model):
    #         x = layer(x)
    #         print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()),
    #               np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
    #     print(x[0])
    #     return x


class _u_bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_u_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample,
                      padding=(fw // 2, fh // 2)),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        return self.model(x)


class _shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample=1):
        super(_shortcut, self).__init__()
        self.process = False
        self.model = None
        # A 1x1 projection is only needed when the channel count or the
        # spatial resolution changes.
        if in_filters != nb_filters or subsample != 1:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            return x + y


class _u_shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample):
        super(_u_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
                nn.Upsample(scale_factor=2, mode='nearest')
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            return x + y


class basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.residual(x1)
        return self.shortcut(x, x2)


class _u_basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)


class _residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            # Downsample (stride 2) in the last repetition, except in the
            # first stage of the network.
            init_subsample = 1
            if i == repetitions - 1 and not is_first_layer:
                init_subsample = 2
            if i == 0:
                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            layers.append(l)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class _upsampling_residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            # The first repetition upsamples (2x); the rest keep the resolution.
            if i == 0:
                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)
            layers.append(l)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class res_skip(nn.Module):

    def __init__(self):
        super(res_skip, self).__init__()
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)

        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)

        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)

        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)

        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)

        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)

    def forward(self, x):
        # Encoder
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        # Decoder with additive skip connections
        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)

        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)

        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)

        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)

        x9 = self.block9(res4)
        y = self.conv15(x9)

        return y
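

# A minimal usage sketch for res_skip (a sketch, not called in this module);
# the checkpoint path is a placeholder. Input is a 1-channel grayscale image
# whose height and width are multiples of 16, so the four stride-2 encoder
# stages divide evenly and the decoder restores the resolution.
def _demo_res_skip(checkpoint_path=None):
    model = res_skip()
    if checkpoint_path is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    with torch.no_grad():
        y = model(torch.rand(1, 1, 512, 512))  # -> (1, 1, 512, 512)
    return y

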
class MyDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def get_class_label(self, image_name):
        # The file name itself serves as the label.
        head, tail = os.path.split(image_name)
        return tail

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        x = Image.open(image_path)
        y = self.get_class_label(os.path.basename(image_path))
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.image_paths)


def loadImages(folder):
    matches = []
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isfile(file_path):
            matches.append(file_path)
    return matches
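

# A small usage sketch (hypothetical, not called in this module) tying
# loadImages and MyDataset together with a torchvision transform and a
# DataLoader. The folder path is a placeholder, and the images are assumed
# to share one mode so that they batch cleanly.
def _demo_dataset(folder='./images'):
    from torch.utils.data import DataLoader
    dataset = MyDataset(
        loadImages(folder),
        transform=transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ]),
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    images, names = next(iter(loader))  # images: (4, C, 512, 512); names: file names
    return images, names

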
def crop_center_square(image):
    """
    Center-crop an image to a square.

    :param image: PIL.Image object
    :return: the cropped PIL.Image object
    """
    width, height = image.size
    side_length = min(width, height)

    left = (width - side_length) // 2
    top = (height - side_length) // 2
    right = left + side_length
    bottom = top + side_length

    cropped_image = image.crop((left, top, right, bottom))
    return cropped_image


def crop_image(image, crop_size, stride):
    """
    Crop an image into overlapping patches with the given crop size and
    stride, and return the list of crops.

    :param image: PIL.Image object
    :param crop_size: crop size, e.g. (384, 384)
    :param stride: overlap stride, e.g. 128
    :return: list of cropped PIL.Image objects
    """
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []

    for j in range(0, height - crop_height + 1, stride):
        for i in range(0, width - crop_width + 1, stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_images.append(image.crop(crop_box))

    return cropped_images


def process_image_ref(image):
    """
    Process an input PIL image and return a list containing the full resized
    image followed by all of its crops.

    :param image: PIL.Image object
    :return: list of PIL.Image objects
    """
    resized_image_512 = image.resize((512, 512))

    # Keep the full 512x512 image, then append the overlapping 384x384 crops.
    image_list = [resized_image_512]

    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))

    return image_list


def process_image_Q(image):
    """
    Process an input PIL image and return the list of all crops.

    :param image: PIL.Image object
    :return: list of cropped PIL.Image objects
    """
    resized_image_512 = image.resize((512, 512)).convert("RGB")

    image_list = []

    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))

    return image_list
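

# A quick sanity check (hypothetical, not called in this module) of the crop
# arithmetic: on a 512x512 image, 384x384 crops at stride 128 start at
# offsets 0 and 128 along each axis, giving a 2x2 grid of 4 crops. So
# process_image_Q yields 4 images and process_image_ref yields 5 (the full
# resize plus the 4 crops).
def _demo_crop_counts():
    img = Image.new("RGB", (768, 1024))
    assert len(process_image_Q(img)) == 4
    assert len(process_image_ref(img)) == 5

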
def process_image(image, target_width=512, target_height=512):
    # Resize to the target size; if the aspect ratios differ by more than
    # 15%, center-crop to the target ratio first.
    img_width, img_height = image.size
    img_ratio = img_width / img_height

    target_ratio = target_width / target_height

    ratio_error = abs(img_ratio - target_ratio) / target_ratio

    if ratio_error < 0.15:
        resized_image = image.resize((target_width, target_height), Image.BICUBIC)
    else:
        if img_ratio > target_ratio:
            # Image is too wide: crop the width.
            new_width = int(img_height * target_ratio)
            left = (img_width - new_width) // 2
            top = 0
            right = left + new_width
            bottom = img_height
        else:
            # Image is too tall: crop the height.
            new_height = int(img_width / target_ratio)
            left = 0
            top = (img_height - new_height) // 2
            right = img_width
            bottom = top + new_height

        cropped_image = image.crop((left, top, right, bottom))
        resized_image = cropped_image.resize((target_width, target_height), Image.BICUBIC)

    return resized_image.convert('RGB')


def crop_image_varres(image, crop_size, h_stride, w_stride):
    """
    Crop an image into overlapping patches with the given crop size and
    per-axis strides, and return the list of crops.

    :param image: PIL.Image object
    :param crop_size: crop size, e.g. (384, 384)
    :param h_stride: vertical stride
    :param w_stride: horizontal stride
    :return: list of cropped PIL.Image objects
    """
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []

    for j in range(0, height - crop_height + 1, h_stride):
        for i in range(0, width - crop_width + 1, w_stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_images.append(image.crop(crop_box))

    return cropped_images


def process_image_ref_varres(image, target_width=512, target_height=512):
    """
    Process an input PIL image and return a list containing the full resized
    image followed by all of its crops.

    :param image: PIL.Image object
    :return: list of PIL.Image objects
    """
    resized_image = image.resize((target_width, target_height))

    # Keep the full resized image, then append overlapping crops covering
    # 3/4 of each side, taken at a stride of 1/4 of each side.
    image_list = [resized_image]

    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))

    return image_list


def process_image_Q_varres(image, target_width=512, target_height=512):
    """
    Process an input PIL image and return the list of all crops.

    :param image: PIL.Image object
    :return: list of cropped PIL.Image objects
    """
    resized_image = image.resize((target_width, target_height)).convert("RGB")

    image_list = []

    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))

    return image_list


class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class TwoLayerResNet(nn.Module):
    # Note: despite the name, this module stacks four ResNetBlocks; the first
    # maps in_channels to out_channels and the rest keep out_channels.
    def __init__(self, in_channels, out_channels):
        super(TwoLayerResNet, self).__init__()
        self.block1 = ResNetBlock(in_channels, out_channels)
        self.block2 = ResNetBlock(out_channels, out_channels)
        self.block3 = ResNetBlock(out_channels, out_channels)
        self.block4 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x


class MultiHiddenResNetModel(nn.Module):
    def __init__(self, channels_list, num_tensors):
        super(MultiHiddenResNetModel, self).__init__()
        # One TwoLayerResNet per tensor: its input width is twice the tensor's
        # channel count, and its output width is the channel count two stages
        # later (clamped to the last entry of channels_list).
        self.two_layer_resnets = nn.ModuleList([
            TwoLayerResNet(channels_list[idx] * 2,
                           channels_list[min(len(channels_list) - 1, idx + 2)])
            for idx in range(num_tensors)
        ])

    def forward(self, tensor_list):
        processed_list = []
        for i, tensor in enumerate(tensor_list):
            tensor = self.two_layer_resnets[i](tensor)
            processed_list.append(tensor)
        return processed_list
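

# A shape sketch (hypothetical, not called in this module) for
# MultiHiddenResNetModel: the i-th input tensor must carry
# channels_list[i] * 2 channels, and the i-th output carries
# channels_list[min(len(channels_list) - 1, i + 2)] channels.
def _demo_multi_hidden():
    channels_list = [32, 64, 128, 128]
    model = MultiHiddenResNetModel(channels_list, num_tensors=4).eval()
    inputs = [torch.rand(1, c * 2, 16, 16) for c in channels_list]
    with torch.no_grad():
        outputs = model(inputs)
    # Expected output channels for this channels_list: 128, 128, 128, 128.
    return [o.shape for o in outputs]

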
def calculate_target_size(h, w):
    # Round both sides down to a multiple of 8.
    target_h = (h // 8) * 8
    target_w = (w // 8) * 8

    # Never return a zero-sized dimension.
    if target_h == 0:
        target_h = 8
    if target_w == 0:
        target_w = 8

    return target_h, target_w


def downsample_tensor(tensor):
    # Resize a (B, C, H, W) tensor so both spatial sides are multiples of 8.
    b, c, h, w = tensor.shape

    target_h, target_w = calculate_target_size(h, w)

    downsampled_tensor = F.interpolate(tensor, size=(target_h, target_w), mode='bilinear', align_corners=False)

    return downsampled_tensor
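

# A quick example (hypothetical, not called in this module): a 1x3x131x517
# tensor is resized to 128x512, the nearest multiples of 8 below each side.
def _demo_downsample():
    out = downsample_tensor(torch.rand(1, 3, 131, 517))
    assert out.shape == (1, 3, 128, 512)
    return out

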
def get_pixart_config():
    pixart_config = {
        "_class_name": "Transformer2DModel",
        "_diffusers_version": "0.22.0.dev0",
        "activation_fn": "gelu-approximate",
        "attention_bias": True,
        "attention_head_dim": 72,
        "attention_type": "default",
        "caption_channels": 4096,
        "cross_attention_dim": 1152,
        "double_self_attention": False,
        "dropout": 0.0,
        "in_channels": 4,
        "norm_elementwise_affine": False,
        "norm_eps": 1e-06,
        "norm_num_groups": 32,
        "norm_type": "ada_norm_single",
        "num_attention_heads": 16,
        "num_embeds_ada_norm": 1000,
        "num_layers": 28,
        "num_vector_embeds": None,
        "only_cross_attention": False,
        "out_channels": 8,
        "patch_size": 2,
        "sample_size": 128,
        "upcast_attention": False,
        "use_linear_projection": False
    }
    return pixart_config
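

# A usage sketch (an assumption, not exercised in this module): diffusers'
# Transformer2DModel can typically be instantiated from such a config dict
# via from_config, yielding a randomly initialized PixArt-style transformer.
def _demo_build_pixart_transformer():
    from diffusers.models import Transformer2DModel
    return Transformer2DModel.from_config(get_pixart_config())

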
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder (contracting path); the input carries 6 channels.
        self.left_conv_1 = DoubleConv(6, 64)
        self.down_1 = nn.MaxPool2d(2, 2)

        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)

        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)

        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.center_conv = DoubleConv(512, 1024)

        # Decoder (expanding path) with concatenated skip connections
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)

        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)

        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)

        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)

        # 1x1 output projection to 3 channels
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # Encoder
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)

        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)

        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)

        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)

        # Bottleneck
        x5 = self.center_conv(x4_down)

        # Decoder: upsample, concatenate the skip, then convolve
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)

        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)

        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)

        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)

        output = self.output(x9)

        return output
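

# A shape sketch (hypothetical, not called in this module): the network maps
# a 6-channel input to a 3-channel output at the same resolution; H and W
# should be multiples of 16 so the four pooling stages divide evenly.
def _demo_unet_shapes():
    model = UNet().eval()
    with torch.no_grad():
        y = model(torch.rand(1, 6, 256, 256))
    assert y.shape == (1, 3, 256, 256)
    return y

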
import sys

sys.path.append('./BidirectionalTranslation')
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset


def get_ScreenVAE_input(A_img, opt):
    # The line-art image L aliases the color image A here, so the size check
    # below never triggers a resize in practice.
    L_img = A_img

    if A_img.size != L_img.size:
        A_img = A_img.resize(L_img.size, Image.LANCZOS)
    if A_img.size[1] > 2500:
        A_img = A_img.resize((A_img.size[0] // 2, A_img.size[1] // 2), Image.LANCZOS)

    ow, oh = A_img.size
    transform_params = get_params(opt, A_img.size)

    A_transform = get_transform(opt, transform_params, grayscale=False)
    L_transform = get_transform(opt, transform_params, grayscale=True)
    A = A_transform(A_img)
    L = L_transform(L_img)

    # ITU-R BT.601 luma of A as the intensity channel Ai.
    tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
    Ai = tmp.unsqueeze(0)

    return {'A': A.unsqueeze(0), 'Ai': Ai.unsqueeze(0), 'L': L.unsqueeze(0),
            'A_paths': '', 'h': oh, 'w': ow,
            'B': torch.zeros(1),
            'Bs': torch.zeros(1),
            'Bi': torch.zeros(1),
            'Bl': torch.zeros(1)}


def get_bidirectional_translation_opt(opt):
    opt.results_dir = './results/test/western2manga'
    opt.dataroot = './datasets/color2manga'
    opt.checkpoints_dir = '/group/40034/zhuangjunhao/ScreenStyle/BidirectionalTranslation/checkpoints/color2manga/'
    opt.name = 'color2manga_cycle_ganstft'
    opt.model = 'cycle_ganstft'
    opt.direction = 'BtoA'
    opt.preprocess = 'none'
    opt.load_size = 512
    opt.crop_size = 1024
    opt.input_nc = 1
    opt.output_nc = 3
    opt.nz = 64
    opt.netE = 'conv_256'
    opt.num_test = 30
    opt.n_samples = 1
    opt.upsample = 'bilinear'
    opt.ngf = 48
    opt.nef = 48
    opt.ndf = 32
    opt.center_crop = True
    opt.color2screen = True
    opt.no_flip = True

    # Dataloader settings for single-image inference
    opt.num_threads = 1
    opt.batch_size = 1
    opt.serial_batches = True
    return opt