FluttyProger
commited on
Commit
•
1c70b30
1
Parent(s):
9b7dd4a
Delete maskgen.py
Browse files- maskgen.py +0 -63
maskgen.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
import albumentations as albu
|
2 |
-
import numpy as np
|
3 |
-
from iglovikov_helper_functions.utils.image_utils import pad, unpad
|
4 |
-
from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
|
5 |
-
import cv2
|
6 |
-
import torch
|
7 |
-
from PIL import Image
|
8 |
-
|
9 |
-
def pil2tensor(image):
|
10 |
-
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
11 |
-
|
12 |
-
def tensor2np(image):
|
13 |
-
return np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
|
14 |
-
|
15 |
-
class GenImageMask:
|
16 |
-
|
17 |
-
def __init__(self):
|
18 |
-
pass
|
19 |
-
|
20 |
-
@classmethod
|
21 |
-
def INPUT_TYPES(s):
|
22 |
-
"""
|
23 |
-
Return a dictionary which contains config for all input fields.
|
24 |
-
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
|
25 |
-
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
|
26 |
-
The type can be a list for selection.
|
27 |
-
|
28 |
-
Returns: `dict`:
|
29 |
-
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
|
30 |
-
- Value input_fields (`dict`): Contains input fields config:
|
31 |
-
* Key field_name (`string`): Name of a entry-point method's argument
|
32 |
-
* Value field_config (`tuple`):
|
33 |
-
+ First value is a string indicate the type of field or a list for selection.
|
34 |
-
+ Secound value is a config for type "INT", "STRING" or "FLOAT".
|
35 |
-
"""
|
36 |
-
return {"required": {"image": ("IMAGE",), "segm_model": ("SEGM_MODEL",)}}
|
37 |
-
|
38 |
-
CATEGORY = "image"
|
39 |
-
RETURN_TYPES = ("IMAGE",)
|
40 |
-
RETURN_NAMES = ("BW Mask",)
|
41 |
-
|
42 |
-
FUNCTION = "mask_image"
|
43 |
-
def mask_image(self, image, segm_model):
|
44 |
-
|
45 |
-
transform = albu.Compose([albu.Normalize(p=1)], p=1)
|
46 |
-
image = tensor2np(image)
|
47 |
-
padded_image, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)
|
48 |
-
x = transform(image=padded_image)["image"]
|
49 |
-
x = torch.unsqueeze(tensor_from_rgb_image(x), 0)
|
50 |
-
with torch.no_grad():
|
51 |
-
prediction = segm_model(x)[0][0]
|
52 |
-
|
53 |
-
mask = (prediction > 0.01).cpu().numpy().astype(np.uint8)
|
54 |
-
|
55 |
-
mask = unpad(mask, pads)
|
56 |
-
|
57 |
-
imag = pil2tensor(Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) * 255))
|
58 |
-
return (imag,)
|
59 |
-
|
60 |
-
|
61 |
-
NODE_CLASS_MAPPINGS = {
|
62 |
-
"GenImageMask": GenImageMask,
|
63 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|