|
import sys |
|
|
|
import os |
|
import math |
|
import argparse |
|
import random |
|
import logging |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
from data.data_sampler import DistIterSampler |
|
|
|
import options.options as option |
|
from utils import util |
|
from data.util import read_img |
|
from data import create_dataloader, create_dataset |
|
from models import create_model |
|
import numpy as np |
|
from PIL import Image |
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
|
|
def init_dist(backend='nccl', **kwargs):
    """Prepare the current process for distributed training.

    Switches the multiprocessing start method to ``'spawn'`` when it is not
    already set to that, selects a CUDA device for this process based on the
    ``RANK`` environment variable, and initialises the default process group.

    Args:
        backend: Backend name passed to ``dist.init_process_group``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``dist.init_process_group``.

    Raises:
        KeyError: If the ``RANK`` environment variable is not set.
    """
    current_method = mp.get_start_method(allow_none=True)
    if current_method != 'spawn':
        mp.set_start_method('spawn')

    # Ranks are mapped onto the visible GPUs round-robin, so a rank larger
    # than the local device count wraps around.
    global_rank = int(os.environ['RANK'])
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)

    dist.init_process_group(backend=backend, **kwargs)
|
|
|
|
|
def load_image(image, message=None, size=(512, 512)):
    """Pack a single image and a placeholder frame into a batch dictionary.

    The input is divided by 255, its first three channels are reversed
    (index order ``[2, 1, 0]``; presumably BGR -> RGB for OpenCV-style input
    -- confirm against the caller), converted to a channel-first float32
    tensor and resized to ``size`` with nearest-neighbour interpolation.

    Alongside it, a solid placeholder frame is generated with the same
    spatial size: channels 0 and 1 are 0 and channel 2 is 1 (the colour
    ``(0, 0, 255)`` scaled by 1/255).

    Args:
        image: ``numpy`` array indexed as ``(height, width, channel)`` with
            at least three channels.
        message: Arbitrary payload returned unchanged under ``'MES'``.
        size: Target ``(height, width)`` of the resized image. Defaults to
            ``(512, 512)``.

    Returns:
        dict with three entries:
            ``'LQ'``: float32 tensor of shape ``(1, 1, 1, 3, H, W)`` holding
                the placeholder frame.
            ``'GT'``: float32 tensor of shape ``(1, 1, 3, H, W)`` holding the
                resized input image.
            ``'MES'``: ``message`` as passed in.
    """
    # Normalise, then reverse the order of the first three channels.
    scaled = image / 255
    reordered = scaled[:, :, [2, 1, 0]]

    # HWC -> CHW; interpolate() needs a leading batch axis.
    chw = np.ascontiguousarray(np.transpose(reordered, (2, 0, 1)))
    img_GT = torch.from_numpy(chw).float().unsqueeze(0)
    img_GT = torch.nn.functional.interpolate(
        img_GT, size=tuple(size), mode='nearest'
    )
    img_GT = img_GT.unsqueeze(0)  # (1, T, C, H, W) with T == 1

    num_frames = img_GT.shape[1]
    height, width = img_GT.shape[-2:]

    # Build the solid placeholder directly as a tensor: one frame per GT
    # frame, same spatial size, only the last colour channel switched on.
    placeholder = torch.zeros(
        num_frames, 3, height, width, dtype=torch.float32
    )
    placeholder[:, 2] = 1.0

    # Two leading singleton axes keep the 6-D layout callers receive.
    imgs_LQ = placeholder[None, None]

    return {
        'LQ': imgs_LQ,
        'GT': img_GT,
        'MES': message,
    }
|
|
|
|
|
def image_editing(image_numpy, mask_image, prompt):
    """Inpaint part of an image with Stable Diffusion and blend the result.

    The image is converted to RGB and resized to 512x512. The mask is
    converted to an 8-bit greyscale image and, if its size differs, resized
    to the same 512x512 so that the pipeline output and the working image
    can be blended element-wise. The inpainting pipeline is run with
    ``prompt``; the output is then fused with the working image so that only
    pixels whose mask value is exactly 255 take the inpainted content.

    Args:
        image_numpy: Array accepted by ``PIL.Image.fromarray`` (presumably
            an ``uint8`` HxWx3 image -- confirm against the caller).
        mask_image: 2-D ``numpy`` array; it is cast to ``uint8`` as-is, so it
            is expected on a 0-255 scale. Only pixels equal to 255 after the
            cast select the inpainted result in the returned blend.
        prompt: Text prompt passed to the inpainting pipeline.

    Returns:
        ``numpy`` float array of shape ``(512, 512, 3)`` with values scaled
        by 1/255: the working image outside the mask and the inpainted image
        inside it.
    """
    # NOTE(review): the pipeline is loaded and moved to the GPU on every
    # call; callers invoking this repeatedly may want to cache it.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    ).to("cuda")

    pil_image = Image.fromarray(image_numpy)
    print(mask_image.shape)
    print("maskmin", mask_image.min(), "maskmax", mask_image.max())

    image_init = pil_image.convert("RGB").resize((512, 512))
    mask_pil = Image.fromarray(mask_image.astype(np.uint8)).convert("L")
    if mask_pil.size != image_init.size:
        # Without this the pipeline output follows the mask size and cannot
        # be blended with the 512x512 working image. NEAREST keeps the mask
        # values untouched (no new grey levels are introduced).
        mask_pil = mask_pil.resize(image_init.size, Image.NEAREST)

    # PIL reports size as (width, height).
    width, height = mask_pil.size

    inpainted_pil = pipe(
        prompt=prompt,
        image=image_init,
        mask_image=mask_pil,
        height=height,
        width=width,
    ).images[0]

    inpainted = np.array(inpainted_pil) / 255.
    original = np.array(image_init) / 255.

    # Hard 0/1 mask replicated over the three colour channels: integer
    # division by 255 yields 1 only where the mask is exactly 255.
    select_inpaint = np.stack([np.array(mask_pil) // 255] * 3, axis=-1)

    return original * (1 - select_inpaint) + inpainted * select_inpaint