FluttyProger commited on
Commit
1c70b30
1 Parent(s): 9b7dd4a

Delete maskgen.py

Browse files
Files changed (1) hide show
  1. maskgen.py +0 -63
maskgen.py DELETED
@@ -1,63 +0,0 @@
1
- import albumentations as albu
2
- import numpy as np
3
- from iglovikov_helper_functions.utils.image_utils import pad, unpad
4
- from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
5
- import cv2
6
- import torch
7
- from PIL import Image
8
-
9
- def pil2tensor(image):
10
- return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
11
-
12
- def tensor2np(image):
13
- return np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
14
-
15
- class GenImageMask:
16
-
17
- def __init__(self):
18
- pass
19
-
20
- @classmethod
21
- def INPUT_TYPES(s):
22
- """
23
- Return a dictionary which contains config for all input fields.
24
- Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
25
- Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
26
- The type can be a list for selection.
27
-
28
- Returns: `dict`:
29
- - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
30
- - Value input_fields (`dict`): Contains input fields config:
31
- * Key field_name (`string`): Name of a entry-point method's argument
32
- * Value field_config (`tuple`):
33
- + First value is a string indicate the type of field or a list for selection.
34
- + Secound value is a config for type "INT", "STRING" or "FLOAT".
35
- """
36
- return {"required": {"image": ("IMAGE",), "segm_model": ("SEGM_MODEL",)}}
37
-
38
- CATEGORY = "image"
39
- RETURN_TYPES = ("IMAGE",)
40
- RETURN_NAMES = ("BW Mask",)
41
-
42
- FUNCTION = "mask_image"
43
- def mask_image(self, image, segm_model):
44
-
45
- transform = albu.Compose([albu.Normalize(p=1)], p=1)
46
- image = tensor2np(image)
47
- padded_image, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)
48
- x = transform(image=padded_image)["image"]
49
- x = torch.unsqueeze(tensor_from_rgb_image(x), 0)
50
- with torch.no_grad():
51
- prediction = segm_model(x)[0][0]
52
-
53
- mask = (prediction > 0.01).cpu().numpy().astype(np.uint8)
54
-
55
- mask = unpad(mask, pads)
56
-
57
- imag = pil2tensor(Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) * 255))
58
- return (imag,)
59
-
60
-
61
- NODE_CLASS_MAPPINGS = {
62
- "GenImageMask": GenImageMask,
63
- }