|
import albumentations as A |
|
import base64 |
|
import cv2 |
|
import gradio as gr |
|
import inspect |
|
import io |
|
import numpy as np |
|
import os |
|
|
|
from dataclasses import dataclass |
|
from loguru import logger |
|
from copy import deepcopy |
|
from functools import wraps |
|
from PIL import Image, ImageDraw |
|
from typing import get_type_hints, Optional |
|
from pydantic_core._pydantic_core import ValidationError |
|
|
|
|
|
from utils import is_not_supported_transform |
|
|
|
|
|
|
|
|
|
|
|
|
|
# HTML banner rendered at the top of the page via `gr.Markdown(HEADER)`.
# The logo image (alt="A") is immediately followed by "lbumentations", so the two
# read together as "Albumentations"; the installed library version is interpolated.
HEADER = f"""
<div align="center">
<p>
<img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
<span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
</p>
<p style="margin-top: -15px;">
<a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>

<a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
</p>
</div>
"""
|
|
|
# Transform pre-selected when the page is opened without a ``?transform=`` parameter.
DEFAULT_TRANSFORM = "Rotate"
# Transform selected when the requested one is not available (constant name kept as-is).
NO_OPERATION_TRANFORM = "NoOp"

# Demo image, loaded once at import time as a numpy array.
DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[:2]

# Boxes on the default image in pascal_voc order: (x_min, y_min, x_max, y_max).
DEFAULT_BOXES = [
    [265, 121, 326, 177],
    [192, 169, 401, 395],
]
|
|
|
# (x, y) pixel coordinates on the default image, grouped by the region named in
# each variable; the demo draws all of them together.
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = [*mask_keypoints, *pocket_keypoints, *arm_keypoints]
|
|
|
BASE64_DEFAULT_MASKS = [ |
|
{ |
|
"label": "Coverall", |
|
|
|
"color": (144, 238, 144), |
|
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqS
VPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==", |
|
}, |
|
{ |
|
"label": "Mask", |
|
|
|
"color": (173, 216, 230), |
|
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC", |
|
}, |
|
] |
|
|
|
|
|
# Names never offered in the dropdown: generic base classes plus a few
# transforms that are deliberately hidden from the demo.
_EXCLUDED_TRANSFORM_NAMES = frozenset(
    {"DualTransform", "ImageOnlyTransform", "ReferenceBasedTransform", "ToFloat", "Normalize"}
)

# Name -> class for every albumentations dual or image-only transform the demo
# offers: not flagged by `is_not_supported_transform`, not a "3D" variant, and
# not in the exclusion list above (checked last, so the support check still
# sees exactly the same classes).
transforms_map = {
    name: cls
    for name, cls in vars(A).items()
    if inspect.isclass(cls)
    and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
    and not is_not_supported_transform(cls)
    and not name.endswith("3D")
    and name not in _EXCLUDED_TRANSFORM_NAMES
}

# Dropdown choices in alphabetical order.
transforms_keys = sorted(transforms_map)
|
|
|
|
|
|
|
# Decode each embedded base64 PNG in place, replacing the string with a
# single-channel ("L" mode) numpy array of the mask pixels.
for entry in BASE64_DEFAULT_MASKS:
    png_bytes = base64.b64decode(entry["mask"])
    grayscale = Image.open(io.BytesIO(png_bytes)).convert("L")
    entry["mask"] = np.array(grayscale)
|
|
|
|
|
@dataclass
class RequestParams:
    """Values extracted from an incoming Gradio request by ``get_params``."""

    # Client host address, taken from ``request.client.host``.
    user_ip: str
    # Value of the ``transform`` query parameter, or ``None`` when absent.
    transform_name: Optional[str]
|
|
|
def track_event(event_name, user_id="unknown", properties=None):
    """Record an event by writing it to the application log.

    Args:
        event_name: Short identifier of the event, e.g. ``"app_opened"``.
        user_id: Identifier of the user that triggered the event. Accepted for
            API compatibility but not included in the log line.
        properties: Optional mapping of extra attributes; ``None`` is logged
            as an empty dict.
    """
    payload = properties if properties is not None else {}
    logger.info(f"Event tracked: {event_name} - {payload}")
|
|
|
|
|
def get_params(request: gr.Request) -> RequestParams:
    """Extract the client address and the requested transform from ``request``.

    Also logs an ``"app_opened"`` event through ``track_event``.

    Returns:
        A ``RequestParams`` whose ``transform_name`` is the ``transform`` query
        parameter, or ``None`` when it is missing.
    """
    params = RequestParams(
        user_ip=request.client.host,
        transform_name=request.query_params.get("transform", None),
    )
    track_event(
        "app_opened",
        user_id=params.user_ip,
        properties={"transform_name": params.transform_name},
    )
    return params
|
|
|
|
|
def run_with_retry(compose):
    """Wrap an ``A.Compose`` so unsupported annotation types are dropped instead of failing.

    When the pipeline raises ``NotImplementedError`` whose message mentions
    ``bbox``, ``keypoint`` or ``mask``, the matching keyword arguments (and, for
    boxes and keypoints, the matching entry of ``compose.processors``) are
    removed and the call is retried, so the returned result simply lacks that
    key. ``compose.processors`` is restored to its original contents however
    the wrapper exits.

    Raises:
        gr.Error: If the pipeline raises ``ValueError`` or ``ValidationError``.
        NotImplementedError: If the error names no removable annotation type,
            or the named annotation has already been removed.
    """

    @wraps(compose)
    def wrapper(*args, **kwargs):
        # Snapshot so the wrapped Compose is left unchanged on every exit path.
        processors = deepcopy(compose.processors)
        last_error = None
        try:
            # At most three annotation types can be dropped, so four attempts suffice.
            for _ in range(4):
                try:
                    return compose(*args, **kwargs)
                except NotImplementedError as e:
                    print(f"Caught NotImplementedError: {e}")
                    last_error = e
                    message = str(e)
                    before = (set(kwargs), set(compose.processors))
                    if "bbox" in message:
                        kwargs.pop("bboxes", None)
                        kwargs.pop("category_id", None)
                        compose.processors.pop("bboxes", None)
                    if "keypoint" in message:
                        kwargs.pop("keypoints", None)
                        compose.processors.pop("keypoints", None)
                    if "mask" in message:
                        kwargs.pop("mask", None)
                    if (set(kwargs), set(compose.processors)) == before:
                        # Nothing could be removed, so a retry would fail identically.
                        raise
                except (ValueError, ValidationError) as e:
                    raise gr.Error(str(e)) from e
            # Defensive: only reachable if every attempt failed while still removing something.
            raise last_error
        finally:
            compose.processors = processors

    return wrapper
|
|
|
|
|
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
    """Outline every box on ``image`` using PIL and return the result.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` pixel boxes.
        color: Outline colour.
        thickness: Outline width in pixels.

    Returns:
        A new numpy array with the outlines drawn.
    """
    canvas = Image.fromarray(image)
    painter = ImageDraw.Draw(canvas)
    for x_min, y_min, x_max, y_max in boxes:
        painter.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(canvas)
|
|
|
|
|
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Draw a filled dot at every keypoint on ``image`` using PIL.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        keypoints: Iterable of ``(x, y)`` pixel coordinates.
        color: Fill colour of each dot.
        radius: Dot radius in pixels.

    Returns:
        A new numpy array with the dots drawn.
    """
    canvas = Image.fromarray(image)
    painter = ImageDraw.Draw(canvas)
    for x, y in keypoints:
        bounding_box = [x - radius, y - radius, x + radius, y + radius]
        painter.ellipse(bounding_box, fill=color)
    return np.array(canvas)
|
|
|
|
|
def get_rgb_mask(masks):
    """Paint a list of binary masks into one RGB label image.

    Args:
        masks: Sequence of dicts, each with a 2-D ``"mask"`` array (pixels > 0
            are foreground) and an RGB ``"color"`` triple. All masks are
            expected to share the same height and width.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``. Where masks overlap,
        the later entry's colour wins. The canvas takes its size from the first
        mask, so it always matches the masks being painted; with no masks it
        falls back to the default image size.
    """
    if masks:
        height, width = masks[0]["mask"].shape[:2]
    else:
        height, width = DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH

    rgb_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for data in masks:
        rgb_mask[data["mask"] > 0] = np.array(data["color"])
    return rgb_mask
|
|
|
|
|
def draw_mask(image, mask):
    """Overlay ``mask`` on ``image`` by averaging them pixel-wise.

    Both inputs get weight 0.5 in ``cv2.addWeighted``, so they must be
    compatible for that call (same size and channel count, and here the same
    dtype since no output dtype is requested). Returns the blended array.
    """
    return cv2.addWeighted(image, 0.5, mask, 0.5, 0)
|
|
|
|
|
def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
    """Draw a red "transform not working" notice onto ``image``.

    The text is horizontally centred and starts at the vertical middle of the
    image actually passed in, so it stays visible after transforms that
    resize or crop the image.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        annotation_type: Annotation kind named in the notice (upper-cased).

    Returns:
        A new numpy array with the notice drawn.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    width, height = pil_image.size

    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    length = draw.textlength(text)
    draw.text(
        # Clamp at 0 so the start of the message stays on-canvas for narrow images.
        (max(0, width // 2 - length // 2), height // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)
|
|
|
|
|
def get_formatted_signature(function_or_class, indentation=4):
    """Render the call signature of ``function_or_class`` as an editable snippet.

    Each parameter goes on its own line as ``name=default,`` followed by a
    ``# annotation`` comment when an annotation is available. Required
    parameters become ``name=,`` (a placeholder the user must fill in), except
    names containing ``height``/``width``, which get ``300``. ``p`` is always
    rendered as ``1.0``. ``*args``/``**kwargs`` are omitted because they cannot
    be written as ``name=value``.

    Args:
        function_or_class: Callable whose ``inspect.signature`` is rendered.
        indentation: Spaces before each parameter line; the closing
            parenthesis is indented by ``indentation - 4`` spaces.

    Returns:
        A string such as ``"(\\n    p=1.0, # float\\n)"``.
    """
    signature = inspect.signature(function_or_class)
    try:
        type_hints = get_type_hints(function_or_class)
    except (NameError, TypeError):
        # Hints referencing names unavailable at runtime cannot be resolved;
        # fall back to the raw annotations recorded in the signature.
        type_hints = {}

    args = []
    for param in signature.parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        if param.name == "p":
            str_param = "p=1.0,"
        # `is`, not `==`: defaults such as numpy arrays overload `==` elementwise.
        elif param.default is inspect.Parameter.empty:
            if "height" in param.name or "width" in param.name:
                str_param = f"{param.name}=300,"
            else:
                str_param = f"{param.name}=,"
        elif isinstance(param.default, str):
            str_param = f'{param.name}="{param.default}",'
        else:
            str_param = f"{param.name}={param.default},"

        # Prefer the resolved hint so postponed (string) annotations print cleanly.
        annotation = type_hints.get(param.name, param.annotation)
        if annotation is not inspect.Parameter.empty:
            if isinstance(annotation, type):
                str_annotation = annotation.__name__
            else:
                str_annotation = str(annotation).replace("typing.", "")
            str_param += f" # {str_annotation}"

        args.append("\n" + " " * indentation + str_param)

    return "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
|
|
|
|
|
def get_formatted_transform(transform_name):
    """Return ``A.<ClassName>(...)`` source for the transform named ``transform_name``.

    The argument list comes from ``get_formatted_signature``. Also logs a
    ``"transform_selected"`` event through ``track_event``.

    Raises:
        KeyError: If ``transform_name`` is not in ``transforms_map``.
    """
    track_event("transform_selected", properties={"transform_name": transform_name})
    transform_cls = transforms_map[transform_name]
    signature_text = get_formatted_signature(transform_cls)
    return "A." + transform_cls.__name__ + signature_text
|
|
|
|
|
def get_formatted_transform_docs(transform_name):
    """Return the docstring of the transform named ``transform_name`` for display.

    Leading and trailing newlines are stripped. A class without a docstring
    (``__doc__`` is ``None``; class docstrings are not inherited) yields an
    empty string instead of raising ``AttributeError``.

    Raises:
        KeyError: If ``transform_name`` is not in ``transforms_map``.
    """
    transform = transforms_map[transform_name]
    return (transform.__doc__ or "").strip("\n")
|
|
|
|
|
def _render_annotation_views(image, mask, bboxes, keypoints):
    """Return gallery items showing ``image`` with each annotation type drawn on it.

    An annotation passed as ``None`` (the transform could not process it) is
    replaced by a notice image from ``draw_not_implemented_image``.
    """
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")

    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")

    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_augmented_images(image, code):
    """Apply the transform described by ``code`` to ``image`` and the default annotations.

    Args:
        image: Image array to augment.
        code: Python expression building one albumentations transform,
            e.g. ``"A.NoOp()"``.

    Returns:
        Gallery items ``[(image, caption), ...]`` for the mask, boxes and
        keypoints views, in that order.

    Raises:
        gr.Error: If there is no image, if ``code`` still contains unfilled
            ``name=,`` placeholders, if it cannot be evaluated, or if the
            transform rejects its inputs.
    """
    if image is None:
        raise gr.Error("No input image to augment.")

    # Unfilled placeholders emitted by get_formatted_signature look like `name=,`.
    if "=," in code:
        raise gr.Error("You have to fill in parameters to apply transform! See 'Code' section!")

    # SECURITY: `code` is user-editable text from the browser and eval() runs it
    # as arbitrary Python with access to this module's globals. Only acceptable
    # in an isolated demo deployment.
    try:
        augmentation = eval(code)
    except ValidationError as e:
        raise gr.Error(str(e))
    except Exception as e:
        logger.info(code)
        logger.error(e)
        # Surface mistakes in the user's code (syntax errors, unknown names, bad
        # arguments) in the UI rather than as an opaque server error.
        raise gr.Error(f"Could not create the transform: {e}") from e

    track_event("transform_applied", properties={"transform_name": augmentation.__class__.__name__, "code": code})

    compose = run_with_retry(
        A.Compose(
            [augmentation],
            bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
            keypoint_params=A.KeypointParams(format="xy"),
        )
    )

    bboxes = DEFAULT_BOXES
    augmented = compose(
        image=image,
        mask=get_rgb_mask(BASE64_DEFAULT_MASKS),
        keypoints=DEFAULT_KEYPOINTS,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )

    # Keys dropped by run_with_retry are missing from the result -> None.
    return _render_annotation_views(
        augmented["image"],
        augmented.get("mask", None),
        augmented.get("bboxes", None),
        augmented.get("keypoints", None),
    )
|
|
|
|
|
def update_image_info(image):
    """Summarise ``image``'s height x width, dtype and value range as text."""
    height, width = image.shape[:2]
    lowest, highest = image.min(), image.max()
    return (
        "Image info:\n"
        f"\t - shape: {height}x{width}\n"
        f"\t - dtype: {image.dtype}\n"
        f"\t - min/max: {lowest}/{highest}"
    )
|
|
|
|
|
def update_code_and_docs(select):
    """Return the ``(code, docs)`` pair displayed for the transform named ``select``."""
    return get_formatted_transform(select), get_formatted_transform_docs(select)
|
|
|
def update_code_and_docs_on_start(url_params: gr.Request):
    """Choose the dropdown's initial transform from the ``transform`` query parameter.

    A missing parameter selects ``DEFAULT_TRANSFORM``; an unknown name shows a
    warning and selects ``NO_OPERATION_TRANFORM``.

    Returns:
        A ``gr.update`` setting the dropdown value.
    """
    requested = get_params(url_params).transform_name
    if requested is None:
        transform_name = DEFAULT_TRANSFORM
    elif requested in transforms_map:
        transform_name = requested
    else:
        gr.Warning(f"Sorry, `{requested}` transform is not supported at the moment :(")
        transform_name = NO_OPERATION_TRANFORM
    return gr.update(value=transform_name)
|
|
|
# Gradio UI. Events wired at the bottom of the block:
#   * changing the dropdown regenerates the code snippet and documentation,
#   * "Apply!" evaluates the edited code on the image and refills the gallery,
#   * on page load, the dropdown value is set from the ``transform`` URL parameter.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="value",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    # Read-only docstring of the selected transform.
                    docs = gr.TextArea(
                        get_formatted_transform_docs(DEFAULT_TRANSFORM),
                        show_label=False,
                        interactive=False,
                    )
                # Editable `A.<Transform>(...)` expression, passed to update_augmented_images on "Apply!".
                code = gr.Code(
                    label="Code",
                    language="python",
                    value=get_formatted_transform(DEFAULT_TRANSFORM),
                    interactive=True,
                    lines=5,
                )
                # Size note for the default image; not an output of any event wired below.
                info = gr.TextArea(
                    value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                )
            button = gr.Button("Apply!")
        # Input image, handed to callbacks as a numpy array; `sources=[]` presumably
        # offers no upload/capture sources (confirm against the Gradio version in use).
        image = gr.Image(
            value=DEFAULT_IMAGE_PATH,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        # Pre-filled at build time with a NoOp render so results show before the first click.
        augmented_image = gr.Gallery(
            value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
            rows=1,
            columns=3,
            show_label=False,
        )

    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )
    demo.load(
        update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()
|
|