|
import albumentations as albu |
|
import numpy as np |
|
from iglovikov_helper_functions.utils.image_utils import pad, unpad |
|
from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image |
|
import cv2 |
|
import torch |
|
from PIL import Image |
|
|
|
def pil2tensor(image):
    """Convert an image to a float32 torch tensor with a leading batch axis.

    ``image`` may be a PIL image or anything ``numpy`` can turn into an array.
    Pixel values are divided by 255, so 8-bit input lands in [0, 1].

    Returns:
        ``torch.Tensor`` of dtype float32 whose shape is the array's shape with
        an extra leading dimension of size 1.
    """
    pixels = np.asarray(image).astype(np.float32)
    normalized = pixels / 255.0
    # Indexing with None inserts the leading size-1 axis (same as unsqueeze(0)).
    return torch.from_numpy(normalized)[None, ...]
|
|
|
def tensor2np(image):
    """Convert a float torch tensor in [0, 1] to a uint8 numpy array in [0, 255].

    All size-1 dimensions are squeezed away. Values are scaled by 255, clipped
    to [0, 255] and then cast, which truncates any fractional part.
    """
    array = image.cpu().numpy().squeeze()
    scaled = array * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)
|
|
|
class GenImageMask:
    """ComfyUI node that turns an IMAGE into a black-and-white segmentation mask.

    Each image in the batch is normalized, padded to a multiple of 32 and run
    through ``segm_model``. The model output is thresholded at
    ``MASK_THRESHOLD``. The resulting 0/1 mask is replicated to three channels,
    so the node's output is itself an IMAGE: white (1.0) for foreground and
    black (0.0) for background.
    """

    CATEGORY = "image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("BW Mask",)
    FUNCTION = "mask_image"

    # Model outputs strictly greater than this value count as foreground.
    # NOTE(review): presumably applied to raw logits -- confirm for the model in use.
    MASK_THRESHOLD = 0.01

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """
        Return a dictionary which contains config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        This node requires an "image" (IMAGE) and a "segm_model" (SEGM_MODEL,
        a custom type presumably produced by a model-loader node -- verify).

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
            - Value input_fields (`dict`): Contains input fields config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + First value is a string indicating the type of field or a list for selection.
                    + Second value is a config for type "INT", "STRING" or "FLOAT".
        """
        return {"required": {"image": ("IMAGE",), "segm_model": ("SEGM_MODEL",)}}

    @staticmethod
    def _model_device(segm_model):
        """Return the device of the model's first parameter, falling back to CPU.

        Objects without ``parameters()``, or with no parameters at all, are
        treated as CPU models.
        """
        try:
            return next(segm_model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def _mask_frame(self, frame, segm_model, transform, device):
        """Segment one image and return a 2-D uint8 mask holding 0s and 1s."""
        rgb = tensor2np(frame)
        # Pad so both spatial sizes are multiples of 32; presumably required by
        # the model's downsampling stages. The padding is removed again below.
        padded_image, pads = pad(rgb, factor=32, border=cv2.BORDER_CONSTANT)
        x = transform(image=padded_image)["image"]
        # Move the input to wherever the model lives; the original always fed a
        # CPU tensor, which fails for a model placed on GPU.
        x = torch.unsqueeze(tensor_from_rgb_image(x), 0).to(device)
        with torch.no_grad():
            # First sample of the output batch, first output channel.
            prediction = segm_model(x)[0][0]
        mask = (prediction > self.MASK_THRESHOLD).cpu().numpy().astype(np.uint8)
        return unpad(mask, pads)

    def mask_image(self, image, segm_model):
        """Segment every image in the batch and return black-and-white masks.

        Args:
            image: ComfyUI IMAGE tensor, assumed to be (B, H, W, C) with values
                in [0, 1] (ComfyUI convention -- TODO confirm). A 3-D tensor
                is treated as a batch of one.
            segm_model: Segmentation network, called as ``segm_model(x)`` on
                the normalized, padded image tensor.

        Returns:
            One-element tuple holding a float32 tensor with one (H, W, 3) mask
            per input image, stacked along dim 0. Values are 1.0 for
            foreground and 0.0 for background.

        Raises:
            ValueError: if the image batch is empty.
        """
        transform = albu.Compose([albu.Normalize(p=1)], p=1)
        device = self._model_device(segm_model)
        frames = image if image.dim() == 4 else image.unsqueeze(0)
        if frames.shape[0] == 0:
            raise ValueError("GenImageMask received an empty image batch")

        masks = []
        for frame in frames:
            mask = self._mask_frame(frame, segm_model, transform, device)
            # Replicate the single-channel 0/1 mask into three channels and
            # scale to 0/255 so pil2tensor maps it back to exactly 0.0/1.0.
            rgb_mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2) * 255
            masks.append(pil2tensor(rgb_mask))
        return (torch.cat(masks, dim=0),)
|
|
|
|
|
# Registry read by ComfyUI when it loads this module: maps each node's type
# name to the class implementing it.
NODE_CLASS_MAPPINGS = {"GenImageMask": GenImageMask}
|
|