# LoRA_Nastya_Renz / maskgen.py
# Uploaded by FluttyProger ("Upload 3 files", commit a17c757, 2.38 kB).
import albumentations as albu
import numpy as np
from iglovikov_helper_functions.utils.image_utils import pad, unpad
from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
import cv2
import torch
from PIL import Image
def pil2tensor(image):
    """Convert a PIL image (or anything ``np.array`` accepts) to a float32 tensor.

    Pixel values are divided by 255 and a leading axis of size 1 is
    prepended, so an (H, W, C) input becomes a (1, H, W, C) tensor.
    """
    scaled = np.array(image).astype(np.float32) / 255.0
    # Indexing with None adds the leading batch axis (same as unsqueeze(0)).
    return torch.from_numpy(scaled)[None]
def tensor2np(image):
    """Convert a float image tensor to a ``uint8`` NumPy array.

    The tensor is copied to the CPU, every size-1 axis is squeezed away
    (so a (1, H, W, C) batch of one becomes (H, W, C)), values are scaled
    by 255, rounded to the nearest integer and clipped to [0, 255].

    Note: ``squeeze()`` drops *all* size-1 axes, so an image whose height or
    width is 1 also loses that axis.

    Rounding before the cast matters: ``astype(np.uint8)`` alone truncates,
    which would turn e.g. 0.999 * 255 = 254.7 into 254 instead of 255 and
    bias every converted pixel downward.
    """
    scaled = 255.0 * image.cpu().numpy().squeeze()
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
class GenImageMask:
    """Node that turns an image (batch) into a black-and-white mask.

    Each image is normalised, padded to a multiple of 32, run through
    ``segm_model``, and thresholded. The result is returned as a 3-channel
    float tensor whose values are exactly 0.0 or 1.0.
    NOTE(review): assumes IMAGE tensors are (B, H, W, C) floats in [0, 1]
    (the ComfyUI convention); confirm against the host application.
    """

    # Model outputs strictly greater than this value are marked foreground.
    # NOTE(review): whether the model emits logits or probabilities is not
    # visible here; 0.01 is the original hard-coded cut-off.
    MASK_THRESHOLD = 0.01

    CATEGORY = "image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("BW Mask",)
    FUNCTION = "mask_image"

    @classmethod
    def INPUT_TYPES(cls):
        """
        Return a dictionary which contains config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.
        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
            - Value input_fields (`dict`): Contains input fields config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + First value is a string indicating the type of field or a list for selection.
                    + Second value is a config for type "INT", "STRING" or "FLOAT".
        """
        return {"required": {"image": ("IMAGE",), "segm_model": ("SEGM_MODEL",)}}

    @staticmethod
    def _model_device(segm_model):
        """Return the device of the model's first parameter, or CPU if it has none."""
        parameters = getattr(segm_model, "parameters", None)
        if parameters is None:
            return torch.device("cpu")
        first = next(iter(parameters()), None)
        return torch.device("cpu") if first is None else first.device

    def _mask_single(self, image_np, segm_model, transform, device):
        """Return a uint8 0/1 mask (unpadded to the input size) for one uint8 image."""
        # Pad so both spatial sizes are multiples of 32; `pads` undoes it later.
        padded_image, pads = pad(image_np, factor=32, border=cv2.BORDER_CONSTANT)
        x = transform(image=padded_image)["image"]
        # Send the input to wherever the model lives to avoid a device mismatch.
        x = torch.unsqueeze(tensor_from_rgb_image(x), 0).to(device)
        with torch.no_grad():
            prediction = segm_model(x)[0][0]
        mask = (prediction > self.MASK_THRESHOLD).cpu().numpy().astype(np.uint8)
        return unpad(mask, pads)

    def mask_image(self, image, segm_model):
        """Compute a black-and-white mask for every image in ``image``.

        Args:
            image: image tensor; a 4-D tensor is treated as a batch along
                dim 0, a 3-D tensor as a single image.
            segm_model: callable segmentation model; ``segm_model(x)[0][0]``
                is taken as the prediction for the single input in ``x``.

        Returns:
            A 1-tuple holding a float32 tensor of shape (N, H, W, 3) with
            values 0.0 (background) or 1.0 (foreground), one mask per image.
        """
        transform = albu.Compose([albu.Normalize(p=1)], p=1)
        device = self._model_device(segm_model)
        # Iterate the batch explicitly: tensor2np() only squeezes a batch of
        # one, so a larger batch would otherwise reach pad() as a 4-D array.
        frames = image if image.dim() == 4 else image.unsqueeze(0)
        masks = []
        for frame in frames:
            mask = self._mask_single(tensor2np(frame), segm_model, transform, device)
            # Replicate the 0/1 mask into 3 channels as float32 0.0/1.0.
            rgb = np.repeat(mask[..., None], 3, axis=-1).astype(np.float32)
            masks.append(torch.from_numpy(rgb))
        return (torch.stack(masks),)
# Node registry exported by this module, keyed by each node class's own name
# (the host application, presumably ComfyUI, discovers nodes through it).
NODE_CLASS_MAPPINGS = {node.__name__: node for node in (GenImageMask,)}