"""Gradio demo that applies a single Albumentations transform to a sample image.

The selected transform is run on a default image together with default masks,
bounding boxes and keypoints; the three annotated results are shown in a
gallery.
"""

import base64
import inspect
import io
from copy import deepcopy
from functools import wraps
from typing import get_type_hints

import albumentations as A
import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw

from utils import is_not_supported_transform

# NOTE(review): the stray space in "A lbumentations" looks like markup that was
# lost from the header — confirm the intended text before changing it.
HEADER = f"""

A lbumentations Demo ({A.__version__})

"""

DEFAULT_TRANSFORM = "Rotate"
DEFAULT_IMAGE = "images/doctor.webp"
DEFAULT_IMAGE_HEIGHT = 400
DEFAULT_IMAGE_WIDTH = 600
DEFAULT_BOXES = [
    [265, 121, 326, 177],  # Mask
    [192, 169, 401, 395],  # Coverall
]
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints

# Each entry starts with a base64-encoded image under "mask"; the decoding loop
# further down replaces that string in place with a grayscale numpy array.
BASE64_DEFAULT_MASKS = [
    {
        "label": "Coverall",
        # light green color
        "color": (144, 238, 144),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
    },
    {
        "label": "Mask",
        # light blue color
        "color": (173, 216, 230),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
    },
]

# Get all the transforms from the albumentations library
transforms_map = {
    name: cls
    for name, cls in vars(A).items()
    if (
        inspect.isclass(cls)
        and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
        and not is_not_supported_transform(cls)
    )
}
# Drop the base classes themselves; only concrete transforms are selectable.
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_keys = sorted(transforms_map)

# Decode the masks (in place: the base64 string becomes a grayscale array)
for mask in BASE64_DEFAULT_MASKS:
    mask["mask"] = np.array(
        Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
    )


def run_with_retry(compose):
    """Wrap ``compose`` so unsupported annotation targets are dropped and retried.

    When the wrapped call raises ``NotImplementedError``, the error message is
    inspected for "bbox", "keypoint" and "mask"; the matching keyword arguments
    (and the matching entry of ``compose.processors``) are removed and the call
    is attempted again, up to four attempts in total.

    Args:
        compose: Callable pipeline object exposing a ``processors`` mapping.

    Returns:
        A wrapper with the same call signature as ``compose``.

    Raises:
        NotImplementedError: If every attempt fails with that error; the last
            one is re-raised.
        Exception: Any other exception from ``compose`` propagates unchanged.
    """

    @wraps(compose)
    def wrapper(*args, **kwargs):
        # Snapshot the processors so whatever is popped during the retries can
        # be put back, whichever way the call ends.
        processors = deepcopy(compose.processors)
        last_error = None
        try:
            for _ in range(4):
                try:
                    return compose(*args, **kwargs)
                except NotImplementedError as e:
                    last_error = e
                    print(f"Caught NotImplementedError: {e}")
                    message = str(e)
                    if "bbox" in message:
                        kwargs.pop("bboxes", None)
                        kwargs.pop("category_id", None)
                        # Default avoids a KeyError if the same target is
                        # reported again on a later attempt.
                        compose.processors.pop("bboxes", None)
                    if "keypoint" in message:
                        kwargs.pop("keypoints", None)
                        compose.processors.pop("keypoints", None)
                    if "mask" in message:
                        kwargs.pop("mask", None)
            # All attempts failed: surface the real cause instead of falling
            # through with no result bound.
            raise last_error
        finally:
            compose.processors = processors

    return wrapper


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
    """Draw boxes with PIL.

    Args:
        image: Image array accepted by ``Image.fromarray``.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        color: Outline color.
        thickness: Outline width in pixels.

    Returns:
        A new array with the rectangles drawn on it.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(pil_image)


def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2) -> np.ndarray:
    """Draw keypoints with PIL.

    Args:
        image: Image array accepted by ``Image.fromarray``.
        keypoints: Iterable of ``(x, y)`` points.
        color: Fill color of each dot.
        radius: Dot radius in pixels.

    Returns:
        A new array with one filled circle per keypoint.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for keypoint in keypoints:
        x, y = keypoint
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
    return np.array(pil_image)


def get_rgb_mask(masks) -> np.ndarray:
    """Get the RGB mask from the binary masks.

    Args:
        masks: Iterable of dicts with a "mask" array and a "color" triple.

    Returns:
        A ``uint8`` array of shape (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3)
        where every non-zero mask pixel is painted with that mask's color.
        Later entries overwrite earlier ones where they overlap.
    """
    rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
    for data in masks:
        mask = data["mask"]
        rgb_mask[mask > 0] = np.array(data["color"])
    return rgb_mask


def draw_mask(image, mask):
    """Draw the mask on the image as a 50/50 blend of the two arrays."""
    image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
    return image_with_mask


def draw_not_implemented_image(image: np.ndarray, annotation_type: str) -> np.ndarray:
    """Draw a red "not working" notice in the middle of the image.

    Args:
        image: Image array accepted by ``Image.fromarray``.
        annotation_type: Name of the annotation kind; shown upper-cased.

    Returns:
        A new array with the notice drawn on it.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    length = draw.textlength(text)
    # Center on the actual image: a transform may have changed its size, so the
    # default dimensions are not reliable here.
    width, height = pil_image.size
    draw.text(
        (width // 2 - length // 2, height // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)


def get_formatted_signature(function_or_class, indentation=4):
    """Format a callable's parameters as an editable, one-per-line argument list.

    Each parameter becomes ``name=default,`` followed by a ``# type`` comment
    when an annotation is available. ``p`` is always rendered as ``p=1.0,`` and
    parameters without a default are rendered as ``name=,`` for the user to
    fill in.

    Args:
        function_or_class: Callable accepted by ``inspect.signature``.
        indentation: Number of spaces in front of each parameter line.

    Returns:
        The parenthesized, multi-line argument list as a string.
    """
    signature = inspect.signature(function_or_class)
    # NOTE(review): for a class, get_type_hints reads class-level annotations
    # rather than those of __init__, so most lookups below fall back to the raw
    # parameter annotation — confirm whether that is intended.
    try:
        type_hints = get_type_hints(function_or_class)
    except (NameError, TypeError):
        # Unresolvable forward references: fall back to the raw annotations.
        type_hints = {}

    args = []
    for param in signature.parameters.values():
        if param.name == "p":
            str_param = "p=1.0,"
        # Identity check: ``==`` misbehaves for array-like defaults.
        elif param.default is inspect.Parameter.empty:
            str_param = f"{param.name}=,"
        elif isinstance(param.default, str):
            str_param = f'{param.name}="{param.default}",'
        else:
            str_param = f"{param.name}={param.default},"

        annotation = type_hints.get(param.name, param.annotation)
        if annotation is inspect.Parameter.empty:
            # No annotation at all: skip the comment instead of printing the
            # name of the sentinel class.
            pass
        elif isinstance(annotation, type):
            str_param += f" # {annotation.__name__}"
        else:
            str_annotation = str(annotation).replace("typing.", "")
            str_param += f" # {str_annotation}"

        args.append("\n" + " " * indentation + str_param)

    return "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"


def get_formatted_transform(transform_number):
    """Return the ``A.<Name>(...)`` code template for the transform at this index.

    Args:
        transform_number: Index into ``transforms_keys``.
    """
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    return f"A.{transform.__name__}{get_formatted_signature(transform)}"


def get_formatted_transform_docs(transform_number):
    """Return the docstring of the transform at this index, or "" if it has none.

    Args:
        transform_number: Index into ``transforms_keys``.
    """
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    # ``__doc__`` is None for undocumented classes.
    return (transform.__doc__ or "").strip("\n")


def update_augmented_images(image, code):
    """Apply the transform described by ``code`` and render the annotated results.

    Args:
        image: Input image as a numpy array.
        code: Python expression that evaluates to a transform.

    Returns:
        Three ``(image, caption)`` pairs for the mask, boxes and keypoints
        views. A view whose annotation did not come back from the pipeline is
        replaced by a placeholder notice.
    """
    # SECURITY: ``code`` comes straight from the editable code box and is
    # evaluated as-is, so anyone who can reach the UI can run arbitrary Python.
    # Only expose this demo to trusted users or run it sandboxed.
    augmentation = eval(code)
    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    compose = run_with_retry(compose)  # to prevent NotImplementedError

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
    augmented = compose(
        image=image,
        # NOTE(review): purpose of this extra key is not visible here — confirm
        # it is still needed.
        not_implemented_image=image.copy(),
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=list(range(len(bboxes))),
    )
    image = augmented["image"]
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    # Draw the augmented images (or replace by placeholder if not implemented)
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")
    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")
    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_image_info(image):
    """Return a short text summary (shape, dtype, min/max) of an image array."""
    h, w = image.shape[:2]
    dtype = image.dtype
    max_, min_ = image.max(), image.min()
    return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"


def update_code_and_docs(select):
    """Return the code template and docs for the selected transform index."""
    code = get_formatted_transform(select)
    docs = get_formatted_transform_docs(select)
    return code, docs


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="index",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(
                            transforms_keys.index(DEFAULT_TRANSFORM)
                        ),
                        show_label=False,
                        interactive=False,
                    )
            code = gr.Code(
                label="Code",
                language="python",
                value=get_formatted_transform(
                    transforms_keys.index(DEFAULT_TRANSFORM)
                ),
                interactive=True,
                lines=5,
            )
            info = gr.TextArea(
                value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                show_label=False,
                lines=1,
                max_lines=1,
            )
            button = gr.Button("Apply!")
        image = gr.Image(
            value=DEFAULT_IMAGE,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(rows=1, columns=3)

    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )

if __name__ == "__main__":
    demo.launch()