File size: 4,701 Bytes
3719834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import PIL
from typing import List
import numpy as np
import torchvision.transforms.functional as F
import torch

def multiply_grayscale_images(image1, image2):
    # Convert the images to NumPy arrays
    image1_np = np.array(image1)
    image2_np = np.array(image2)

    # Perform element-wise multiplication (ensure to use np.float32 to avoid overflow)
    multiplied_image = image1_np.astype(np.float32) * image2_np.astype(np.float32)

    # Normalize the result to the range 0-255 (if needed)
    multiplied_image = np.clip(multiplied_image, 0, 255)

    # Convert back to uint8 (8-bit grayscale image)
    multiplied_image = multiplied_image.astype(np.uint8)

    # Convert back to an image and save the result
    result_image = PIL.Image.fromarray(multiplied_image)
    return result_image

def create_color_masks(image: PIL.Image.Image):
    # Load the image
    image = image.convert("RGB")
    image_np = np.array(image)  # Convert to numpy array (Height x Width x 3)
    # Find unique colors in the image
    unique_colors = np.unique(image_np.reshape(-1, 3), axis=0)
    output = []
    # Create masks for each color
    for color in unique_colors:
        if sum(color) == 0:
            continue
        mask = np.all(image_np == color, axis=-1)
        color_str = '_'.join(map(str, color))  # Create a string representation of the color
        output.append((color_str, mask))
    # Skip Background Mask Image
    background_area = 0.0
    background_mask_index = -1
    for idx, (color_str, mask) in enumerate(output):
        area = np.sum(mask > 0) / (mask.shape[0] * mask.shape[1])
        if area > background_area:
            background_area = area
            background_mask_index = idx
    # Final Elements
    elements = []
    for idx, (color_str, mask) in enumerate(output):
        if idx == background_mask_index:
            print(background_mask_index)
            continue
        mask_image = PIL.Image.fromarray(mask.astype(np.uint8) * 255)
        elements.append((color_str, mask_image))
    # Final Background
    final_background_mask_image = PIL.Image.new("L", (image.size[0], image.size[1]), 255)
    draw = PIL.ImageDraw.Draw(final_background_mask_image)
    for idx, (color_str, mask_image) in enumerate(elements):
        final_background_mask_image = multiply_grayscale_images(final_background_mask_image, PIL.ImageOps.invert(mask_image))

    return final_background_mask_image, elements


def create_text_masks(polygons, width, height):
    # Loop over each polygon in the list
    text_masks = []
    for i, polygon_coords in enumerate(polygons):
        # Create a new grayscale image (L mode) with a black background (0)
        mask = PIL.Image.new('L', (width, height), 0)

        # Create a drawing object
        draw = PIL.ImageDraw.Draw(mask)

        # Convert the list of polygon coordinates into a format ImageDraw can use (list of tuples)
        polygon_points = [(polygon_coords[j], polygon_coords[j + 1]) for j in range(0, len(polygon_coords), 2)]

        # Draw the polygon with white (255) fill
        draw.polygon(polygon_points, fill=255)
        text_masks.append(mask)
    return text_masks

class GetLayerMask:

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "json_data": ("JSON",),
            },
        }

    RETURN_TYPES = ("MASK", "MASK", "JSON")

    FUNCTION = "main"

    CATEGORY = "tensorops"

    def main(self, image: torch.Tensor, json_data: str):
        # Create PIL.Image
        image = image.permute(0, 3, 1, 2)
        image_pil = F.to_pil_image(image[0])
        # Create bg and elements
        bg, elements = create_color_masks(image_pil)
        # Create Text Masks
        print("items", json_data)
        items = [item for item in json_data]
        text_polygon_list = []
        text_label_list = []
        text_masks = []

        for item in items:
            text_polygon_list.append(item["polygon"])
            text_label_list.append(item["label"])

        for mask_image in create_text_masks(text_polygon_list, bg.size[0], bg.size[1]):
            img = np.array(mask_image).astype(np.float32) / 255.0
            img = torch.from_numpy(img)[None,]
            text_masks.append(img)

        output = []
        bg = np.array(bg).astype(np.float32) / 255.0
        bg = torch.from_numpy(bg)[None,]
        output.append(bg)
        for _, mask_image in elements:
            img = np.array(mask_image).astype(np.float32) / 255.0
            img = torch.from_numpy(img)[None,]
            output.append(img)
        return (torch.cat(output, dim=0), torch.cat(text_masks, dim=0), text_label_list)