# EditGuard / test_gradio.py
# Ricoooo: Add local files to repository (commit 8da8f47)
import sys
import os
import math
import argparse
import random
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data.util import read_img
from data import create_dataloader, create_dataset
from models import create_model
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
def init_dist(backend='nccl', **kwargs):
    """Initialize the distributed process group for the current process.

    Selects the 'spawn' multiprocessing start method, binds this process to
    a CUDA device chosen from the ``RANK`` environment variable, and then
    calls ``dist.init_process_group``.

    Args:
        backend: Backend name passed to ``dist.init_process_group``.
        **kwargs: Forwarded unchanged to ``dist.init_process_group``.

    Raises:
        KeyError: If the ``RANK`` environment variable is not set.
        RuntimeError: If no CUDA device is visible to this process.
    """
    if mp.get_start_method(allow_none=True) != 'spawn':
        # force=True: without it, set_start_method raises RuntimeError when
        # a different start method has already been fixed for the process.
        mp.set_start_method('spawn', force=True)

    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Fail with a readable message instead of ZeroDivisionError below.
        raise RuntimeError(
            'init_dist requires at least one visible CUDA device, found none'
        )

    # Ranks are mapped round-robin onto the locally visible devices.
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
def load_image(image, message=None):
    """Pack an image array into the input dict built by this script.

    The image is scaled by 1/255, channels 0 and 2 are swapped (RGB -> BGR,
    assuming the caller passes RGB), and the result is resized to 512x512
    with nearest-neighbour interpolation.

    Args:
        image: Array-like of shape (H, W, C) with C >= 3; only the first
            three channels are used. Grayscale input of shape (H, W) or
            (H, W, 1) is replicated to three channels. Values are presumably
            in [0, 255] -- confirm against the caller.
        message: Opaque payload returned unchanged under the ``'MES'`` key.

    Returns:
        dict with
            ``'GT'``: float32 tensor of shape (1, 1, 3, 512, 512),
            ``'LQ'``: float32 tensor of shape (1, 1, 1, 3, 512, 512) holding
                the constant colour (0, 0, 1) in every pixel,
            ``'MES'``: ``message``.

    Raises:
        TypeError: If ``image`` is None.
    """
    if image is None:
        raise TypeError('load_image() expected an image array, got None')

    pixels = np.asarray(image)
    # Promote grayscale input to three identical channels so that the
    # channel swap below does not fail with an IndexError.
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    scaled = pixels / 255
    swapped = scaled[:, :, [2, 1, 0]]  # reverse the first three channels
    chw = np.ascontiguousarray(np.transpose(swapped, (2, 0, 1)))
    img_gt = torch.from_numpy(chw).float().unsqueeze(0)  # (1, 3, H, W)
    img_gt = torch.nn.functional.interpolate(
        img_gt, size=(512, 512), mode='nearest'
    )
    img_gt = img_gt.unsqueeze(0)  # (1, 1, 3, 512, 512)

    # Constant placeholder: one frame per GT frame, filled with colour
    # (0, 0, 255) / 255. Built directly as a tensor; values and shape match
    # what a solid-colour RGB image converted to a tensor would give.
    _, num_frames, _, dim_a, dim_b = img_gt.shape
    imgs_lq = torch.zeros((num_frames, 3, dim_a, dim_b), dtype=torch.float32)
    imgs_lq[:, 2] = 1.0
    imgs_lq = imgs_lq.unsqueeze(0).unsqueeze(0)  # (1, 1, T, 3, 512, 512)

    return {
        'LQ': imgs_lq,
        'GT': img_gt,
        'MES': message,
    }
def image_editing(image_numpy, mask_image, prompt):
    """Inpaint the masked region of an image and blend it with the original.

    The image is resized to 512x512, inpainted with the
    ``stabilityai/stable-diffusion-2-inpainting`` pipeline under ``prompt``,
    and the inpainted pixels are pasted over the original where the mask
    selects them.

    Args:
        image_numpy: Image array accepted by ``Image.fromarray``.
        mask_image: Mask array; it is cast to uint8 and converted to a
            single-channel image. A mask whose size differs from the
            512x512 working image is resized to match.
        prompt: Text prompt forwarded to the inpainting pipeline.

    Returns:
        Float numpy array with the blended image, scaled by 1/255. In the
        blend only mask pixels equal to 255 take the inpainted value; all
        other pixels keep the (resized) original.
    """
    # NOTE(review): the pipeline is re-created on every call, which repeats
    # the model load each time. Caching it would trade that cost for
    # permanently held GPU memory; left as-is.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    ).to("cuda")

    pil_image = Image.fromarray(image_numpy)
    print(mask_image.shape)
    print("maskmin", mask_image.min(), "maskmax", mask_image.max())

    image_init = pil_image.convert("RGB").resize((512, 512))
    mask_pil = Image.fromarray(mask_image.astype(np.uint8)).convert("L")
    if mask_pil.size != image_init.size:
        # Keep mask and image aligned; otherwise the blend below fails with
        # a broadcasting error. NEAREST avoids inventing intermediate gray
        # values that the == 255 selection would drop.
        mask_pil = mask_pil.resize(image_init.size, Image.NEAREST)

    # PIL reports size as (width, height).
    width, height = mask_pil.size
    image_inpaint = pipe(
        prompt=prompt,
        image=image_init,
        mask_image=mask_pil,
        height=height,
        width=width,
    ).images[0]

    inpainted = np.array(image_inpaint) / 255.
    original = np.array(image_init) / 255.

    # The uint8 cast truncates, so x / 255 becomes 1 only where x == 255.
    mask = np.stack([np.array(mask_pil)] * 3, axis=-1) / 255.
    mask = mask.astype(np.uint8)

    image_fuse = original * (1 - mask) + inpainted * mask
    return image_fuse