# Spaces: Build error
import gradio | |
import gradio_image_annotation | |
import gradio_imageslider | |
import spaces | |
import torch | |
import src.SegmentAnything2Assist as SegmentAnything2Assist | |
# Annotation pre-loaded into the "Image Segmentation" tab.
# Label convention (as parsed by __post_process_annotator_inputs):
#   '+' -> positive point prompt, '-' -> negative point prompt
#          (points are stored as zero-size boxes, xmin == xmax, ymin == ymax)
#   ''  -> the bounding-box prompt
example_image_annotation = {
    "image": "assets/cars.jpg",
    "boxes": [
        {"label": label, "color": color, "xmin": x, "ymin": y, "xmax": x, "ymax": y}
        for label, color, x, y in (
            ("+", (0, 255, 0), 886, 551),
            ("-", (255, 0, 0), 1239, 576),
            ("-", (255, 0, 0), 610, 574),
        )
    ] + [
        {"label": "", "color": (0, 0, 255), "xmin": 254, "ymin": 466, "xmax": 1347, "ymax": 1047},
    ],
}
# When True, the callbacks print progress messages to stdout.
VERBOSE = True

# SAM2 wrapper shared by every callback.  Created once at import time with the
# same checkpoint/device the UI dropdowns default to; __change_base_model
# replaces it with a new instance.
segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
    model_name = "sam2_hiera_tiny",
    device = torch.device("cuda"),
)

# Prompts parsed from the image annotator by __post_process_annotator_inputs
# (None when no prompt of that kind is present).
__image_point_coords = __image_point_labels = __image_box = None
# Mask / segment produced by the last __generate_mask run (None = not cached).
__current_mask = __current_segment = None
def __change_base_model(model_name, device):
    """Replace the global SAM2 wrapper with one using another checkpoint/device.

    Bound to the "Model Choice" and "Device Choice" dropdowns.  The global
    ``segment_anything2assist`` is only reassigned after the new instance was
    created successfully, so a failed load keeps the previous model.

    Args:
        model_name: checkpoint name, one of the "Model Choice" dropdown values.
        device: device string passed to ``torch.device`` ("cpu" or "cuda" in
            the "Device Choice" dropdown).

    Raises:
        gradio.Error: when the model could not be created; Gradio shows the
            message to the user.
    """
    global segment_anything2assist
    try:
        new_model = SegmentAnything2Assist.SegmentAnything2Assist(
            model_name = model_name,
            device = torch.device(device),
        )
    except Exception as exc:
        # gradio.Error is an exception class: it has to be *raised* for Gradio
        # to display it -- merely constructing it has no visible effect.
        raise gradio.Error("Model could not be changed", duration = 5) from exc
    segment_anything2assist = new_model
    gradio.Info(f"Model changed to {model_name} on {device}", duration = 5)
def __post_process_annotator_inputs(value):
    """Parse the image-annotator value into the SAM2 prompt globals.

    Boxes labelled ``'+'`` / ``'-'`` become positive / negative point prompts
    located at the (integer) centre of the drawn box.  The first box with an
    empty label becomes the bounding-box prompt; further unlabelled boxes are
    ignored.  Boxes with any other label are ignored.  Any cached mask/segment
    is discarded because it no longer matches the annotation.

    Args:
        value: annotator value -- a dict whose ``"boxes"`` list holds dicts with
            ``label``, ``xmin``, ``ymin``, ``xmax`` and ``ymax`` keys -- or
            ``None`` when the annotator is empty.

    Side effects:
        Sets ``__image_point_coords`` (list of ``[x, y]`` or ``None``),
        ``__image_point_labels`` (list of 1 / 0 or ``None``) and
        ``__image_box`` (``[xmin, ymin, xmax, ymax]`` or ``None``); resets
        ``__current_mask`` and ``__current_segment`` to ``None``.
    """
    global __image_point_coords, __image_point_labels, __image_box
    global __current_mask, __current_segment
    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")
    __current_mask, __current_segment = None, None

    point_coords = []
    point_labels = []
    box_prompt = None
    # An empty annotator (e.g. the image was cleared) simply yields no prompts.
    boxes = (value or {}).get("boxes") or []
    for box in boxes:
        label = box.get("label")
        if label == '':
            # Only one bounding-box prompt is passed on: keep the first one.
            if box_prompt is None:
                box_prompt = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        elif label in ('+', '-'):
            # Collapse the drawn box to its centre point.
            x = int((box['xmin'] + box['xmax']) / 2)
            y = int((box['ymin'] + box['ymax']) / 2)
            point_coords.append([x, y])
            point_labels.append(1 if label == '+' else 0)

    # Absent prompts are represented by None rather than empty lists.
    __image_point_coords = point_coords or None
    __image_point_labels = point_labels or None
    __image_box = box_prompt

    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")
def __generate_mask(value, mask_threshold, max_hole_area, max_sprinkle_area, image_output_mode):
    """Segment the annotated image with SAM2 and return the ImageSlider contents.

    The annotation is re-parsed first so the prompts match what is currently
    drawn.  The first returned mask is turned into a mask image and a masked
    segment, both cached in ``__current_mask`` / ``__current_segment`` so that
    ``__change_output_mode`` can switch between them without re-running SAM2.

    Args:
        value: image-annotator value (dict with ``"image"`` and ``"boxes"``),
            or ``None`` when the annotator is empty.
        mask_threshold: "SAM Mask Threshold" slider value.
        max_hole_area: "SAM Max Hole Area" slider value.
        max_sprinkle_area: "SAM Max Sprinkle Area" slider value.
        image_output_mode: "Mask" or "Segment" ("Output Mode" radio).

    Returns:
        ``[input_image, mask_or_segment]`` for the ImageSlider, or an empty
        ImageSlider update (with a warning) when there is no image or the
        output mode is unknown.
    """
    global __current_mask, __current_segment

    # Without an image there is nothing to segment: warn instead of crashing
    # with a TypeError/KeyError while reading the annotation.
    if not value or value.get("image") is None:
        gradio.Warning("Please upload and annotate an image first.", duration=5)
        return gradio_imageslider.ImageSlider(render = True)

    # Force post processing of annotated image
    __post_process_annotator_inputs(value)

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Called.")
    mask_chw, _mask_iou = segment_anything2assist.generate_masks_from_image(
        value["image"],
        __image_point_coords,
        __image_point_labels,
        __image_box,
        mask_threshold,
        max_hole_area,
        max_sprinkle_area
    )
    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")

    __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(value["image"], mask_chw[0])
    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")

    if image_output_mode == "Mask":
        return [value["image"], __current_mask]
    if image_output_mode == "Segment":
        return [value["image"], __current_segment]
    gradio.Warning("This is an issue, please report the problem!", duration=5)
    return gradio_imageslider.ImageSlider(render = True)
def __change_output_mode(image_input, radio):
    """Switch the ImageSlider between the cached mask and the cached segment.

    Bound to the "Output Mode" radio of the "Image Segmentation" tab.  Only the
    results cached by ``__generate_mask`` are re-used; SAM2 is not re-run.

    Args:
        image_input: image-annotator value (dict with an ``"image"`` entry), or
            ``None`` when the annotator is empty.
        radio: selected output mode, "Mask" or "Segment".

    Returns:
        ``[input_image, mask_or_segment]`` for the ImageSlider, or an empty
        ImageSlider update (with a warning) when nothing is cached or the mode
        is unknown.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__change_output_mode::Called.")
    # The cache is cleared whenever the annotation changes (see
    # __post_process_annotator_inputs), so a missing entry means the mask has
    # to be generated again.
    if image_input is None or __current_mask is None or __current_segment is None:
        gradio.Warning("Configuration was changed, generate the mask again", duration=5)
        return gradio_imageslider.ImageSlider(render = True)
    if radio == "Mask":
        return [image_input["image"], __current_mask]
    if radio == "Segment":
        return [image_input["image"], __current_segment]
    gradio.Warning("This is an issue, please report the problem!", duration=5)
    return gradio_imageslider.ImageSlider(render = True)
def __generate_multi_mask_output(image, auto_list, auto_mode, auto_bbox_mode):
    """Build the auto-segmentation ImageSlider pair for the selected masks.

    Args:
        image: input image of the "Auto Segmentation" tab.
        auto_list: selected "Mask List" labels.  ``__generate_auto_mask``
            creates them as ``str(i) for i in range(len(masks))``, so each label
            is already a 0-based mask index.
        auto_mode: "Mask" shows the mask image; any other value shows the
            segment image.
        auto_bbox_mode: when true, the left image is the ``image_with_bbox``
            result of ``apply_auto_mask_to_image`` ("Show Bounding Box"),
            otherwise the plain input image.

    Returns:
        ``[left_image, right_image]`` for the ImageSlider.
    """
    # Labels are 0-based already, so use them as indices directly.  Shifting
    # them by -1 would map label "0" to index -1 (presumably the *last* mask
    # if the indices address a Python list).
    mask_indices = [int(label) for label in auto_list]
    image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(image, mask_indices)
    output_1 = image_with_bbox if auto_bbox_mode else image
    output_2 = mask if auto_mode == "Mask" else segment
    return [output_1, output_2]
def __generate_auto_mask(
    image,
    points_per_side,
    points_per_batch,
    pred_iou_thresh,
    stability_score_thresh,
    stability_score_offset,
    mask_threshold,
    box_nms_thresh,
    crop_n_layers,
    crop_nms_thresh,
    crop_overlay_ratio,
    crop_n_points_downscale_factor,
    min_mask_region_area,
    use_m2m,
    multimask_output,
    output_mode
):
    """Run automatic mask generation and fill the "Auto Segmentation" outputs.

    Every argument between ``image`` and ``output_mode`` comes from the
    "Advanced Options" widgets and is forwarded, unchanged and in order, to
    ``segment_anything2assist.generate_automatic_masks``.

    Args:
        image: input image of the "Auto Segmentation" tab.
        output_mode: "Mask" or "Segment"; selects what the slider shows.

    Returns:
        A 3-tuple of updates for the ImageSlider, the "Mask List" checkbox
        group and the "Show Bounding Box" checkbox.  When no mask was found
        the list and checkbox are disabled (with a warning); otherwise one
        label per mask ("0", "1", ...) is offered, "0" is pre-selected and the
        slider is rendered via ``__generate_multi_mask_output``.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")

    generator_settings = (
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
        stability_score_offset,
        mask_threshold,
        box_nms_thresh,
        crop_n_layers,
        crop_nms_thresh,
        crop_overlay_ratio,
        crop_n_points_downscale_factor,
        min_mask_region_area,
        use_m2m,
        multimask_output,
    )
    auto_masks = segment_anything2assist.generate_automatic_masks(image, *generator_settings)
    mask_count = len(auto_masks)

    if mask_count == 0:
        gradio.Warning("No masks generated, please tweak the advanced parameters.", duration = 5)
        slider_update = gradio_imageslider.ImageSlider()
        list_update = gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False)
        bbox_update = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
        return slider_update, list_update, bbox_update

    mask_labels = list(map(str, range(mask_count)))
    slider_update = __generate_multi_mask_output(image, ["0"], output_mode, False)
    list_update = gradio.CheckboxGroup(mask_labels, value = ["0"], label = "Mask List", interactive = True)
    bbox_update = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = True)
    return slider_update, list_update, bbox_update
# Gradio UI: a model/device selector shared by two tabs -- prompt-based
# "Image Segmentation" and "Auto Segmentation".
with gradio.Blocks() as base_app:
    gradio.Markdown("# SegmentAnything2Assist")

    # Changing either dropdown reloads the model via __change_base_model.
    with gradio.Row():
        with gradio.Column():
            base_model_choice = gradio.Dropdown(
                ['sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_base_plus', 'sam2_hiera_tiny'],
                value = 'sam2_hiera_tiny',
                label = "Model Choice"
            )
        with gradio.Column():
            base_gpu_choice = gradio.Dropdown(
                ['cpu', 'cuda'],
                value = 'cuda',
                label = "Device Choice"
            )
    base_model_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
    base_gpu_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])

    with gradio.Tab(label = "Image Segmentation", id = "image_tab") as image_tab:
        gradio.Markdown("Image Segmentation", render = True)
        with gradio.Column():
            with gradio.Accordion("Image Annotation Documentation", open = False):
                # gradio.Markdown dedents its value, so the indentation of this
                # literal does not reach the rendered text.
                gradio.Markdown("""
                Image annotation allows you to mark specific regions of an image with labels.
                In this app, you can annotate an image by drawing boxes and assigning labels to them.
                A box labelled '+' or '-' is used as a positive or negative point prompt placed at the centre of the box,
                and a box with an empty label is used as the bounding-box prompt.
                To annotate an image, simply click and drag to draw a box around the desired region.
                You can add multiple boxes with different labels.
                Once you have annotated the image, click the 'Generate Mask' button to generate a mask based on the annotations.
                The mask can be either a binary mask or a segmented mask, depending on the selected output mode.
                You can switch between the output modes using the radio buttons; this re-uses the last generated mask.
                If you make any changes to the annotations, you need to regenerate the mask by clicking the button again.
                Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
                These options control the sensitivity and accuracy of the segmentation process.
                Experiment with different settings to achieve the desired results.
                """)
            image_input = gradio_image_annotation.image_annotator(example_image_annotation)
            with gradio.Accordion("Advanced Options", open = False):
                image_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "SAM Mask Threshold")
                image_generate_SAM_max_hole_area = gradio.Slider(0, 1000, 0, label = "SAM Max Hole Area")
                image_generate_SAM_max_sprinkle_area = gradio.Slider(0, 1000, 0, label = "SAM Max Sprinkle Area")
            image_generate_mask_button = gradio.Button("Generate Mask")
            image_output = gradio_imageslider.ImageSlider()
            image_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode")

        image_input.change(__post_process_annotator_inputs, inputs = [image_input])
        image_generate_mask_button.click(
            __generate_mask,
            inputs = [
                image_input,
                image_generate_SAM_mask_threshold,
                image_generate_SAM_max_hole_area,
                image_generate_SAM_max_sprinkle_area,
                image_output_mode
            ],
            outputs = [image_output]
        )
        image_output_mode.change(__change_output_mode, inputs = [image_input, image_output_mode], outputs = [image_output])

    with gradio.Tab(label = "Auto Segmentation", id = "auto_tab"):
        gradio.Markdown("Auto Segmentation", render = True)
        with gradio.Column():
            with gradio.Accordion("Auto Annotation Documentation", open = False):
                gradio.Markdown("""
                """)
            auto_input = gradio.Image("assets/cars.jpg")
            # `step` is keyword-only in gradio.Slider (minimum, maximum and
            # value are the only positional parameters), so it must be passed
            # by name -- a fourth positional argument raises a TypeError.
            with gradio.Accordion("Advanced Options", open = False):
                auto_generate_SAM_points_per_side = gradio.Slider(1, 64, 32, step = 1, label = "Points Per Side", interactive = True)
                auto_generate_SAM_points_per_batch = gradio.Slider(1, 64, 32, step = 1, label = "Points Per Batch", interactive = True)
                # A step of 1 on a 0-1 range would only allow 0 or 1; use a
                # fine step so the 0.8 default is reachable.
                auto_generate_SAM_pred_iou_thresh = gradio.Slider(0.0, 1.0, 0.8, step = 0.01, label = "Pred IOU Threshold", interactive = True)
                auto_generate_SAM_stability_score_thresh = gradio.Slider(0.0, 1.0, 0.95, label = "Stability Score Threshold", interactive = True)
                auto_generate_SAM_stability_score_offset = gradio.Slider(0.0, 1.0, 1.0, label = "Stability Score Offset", interactive = True)
                auto_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "Mask Threshold", interactive = True)
                auto_generate_SAM_box_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Box NMS Threshold", interactive = True)
                auto_generate_SAM_crop_n_layers = gradio.Slider(0, 10, 0, step = 1, label = "Crop N Layers", interactive = True)
                auto_generate_SAM_crop_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Crop NMS Threshold", interactive = True)
                auto_generate_SAM_crop_overlay_ratio = gradio.Slider(0.0, 1.0, 512 / 1500, label = "Crop Overlay Ratio", interactive = True)
                # Integer downscale factor: restrict the slider to whole numbers.
                auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(1, 10, 1, step = 1, label = "Crop N Points Downscale Factor", interactive = True)
                auto_generate_SAM_min_mask_region_area = gradio.Slider(0, 1000, 0, label = "Min Mask Region Area", interactive = True)
                auto_generate_SAM_use_m2m = gradio.Checkbox(label = "Use M2M", interactive = True)
                auto_generate_SAM_multimask_output = gradio.Checkbox(value = True, label = "Multi Mask Output", interactive = True)
            auto_generate_button = gradio.Button("Generate Auto Mask")
            with gradio.Row():
                with gradio.Column():
                    auto_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode", interactive = True)
                    auto_output_list = gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False)
                    auto_output_bbox = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
                with gradio.Column(scale = 3):
                    auto_output = gradio_imageslider.ImageSlider()

        auto_generate_button.click(
            __generate_auto_mask,
            inputs = [
                auto_input,
                auto_generate_SAM_points_per_side,
                auto_generate_SAM_points_per_batch,
                auto_generate_SAM_pred_iou_thresh,
                auto_generate_SAM_stability_score_thresh,
                auto_generate_SAM_stability_score_offset,
                auto_generate_SAM_mask_threshold,
                auto_generate_SAM_box_nms_thresh,
                auto_generate_SAM_crop_n_layers,
                auto_generate_SAM_crop_nms_thresh,
                auto_generate_SAM_crop_overlay_ratio,
                auto_generate_SAM_crop_n_points_downscale_factor,
                auto_generate_SAM_min_mask_region_area,
                auto_generate_SAM_use_m2m,
                auto_generate_SAM_multimask_output,
                auto_output_mode
            ],
            outputs = [
                auto_output,
                auto_output_list,
                auto_output_bbox
            ]
        )
        # Changing the selection, the bounding-box toggle or the output mode
        # re-renders the slider from the already generated automatic masks.
        auto_view_inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox]
        for auto_view_control in (auto_output_list, auto_output_bbox, auto_output_mode):
            auto_view_control.change(__generate_multi_mask_output, inputs = auto_view_inputs, outputs = [auto_output])
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported.
    base_app.launch()