# Hugging Face Spaces page header residue (Space status: Sleeping)
import torch | |
import gradio as gr | |
import torchvision.transforms as T | |
from PIL import Image, ImageDraw, ImageOps | |
import numpy as np | |
from torchvision.models.detection import maskrcnn_resnet50_fpn | |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor | |
import os | |
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned checkpoint; fail early with a clear message if it is absent.
model_path = "mask_rcnn_lego.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(
        "The model file 'mask_rcnn_lego.pth' was not found in the directory."
    )

# Start from the stock Mask R-CNN architecture.
# NOTE(review): weights="DEFAULT" fetches pretrained weights that the
# load_state_dict() call below then overwrites — presumably only the
# architecture is needed here; confirm before changing.
model = maskrcnn_resnet50_fpn(weights="DEFAULT")

# Read the input widths of both prediction heads before replacing them.
in_features = model.roi_heads.box_predictor.cls_score.in_features
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256

# Swap in 2-class heads (background + LEGO) for boxes and for masks.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_features_mask, hidden_layer, 2
)

# Load the custom weights onto the chosen device and switch to inference mode.
# NOTE(review): torch.load unpickles the file — only ship checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device).eval()

# Pre-processing: convert a PIL image into a tensor.
transform = T.Compose([T.ToTensor()])
# Build a rectangular stand-in mask from a bounding box
def create_pseudo_mask(image, box):
    """Return an "L"-mode mask sized like ``image`` with ``box`` filled in.

    Args:
        image: PIL image; only its ``size`` is used.
        box: Rectangle coordinates in the form accepted by
            ``ImageDraw.rectangle`` (here ``[x1, y1, x2, y2]``).

    Returns:
        A new "L" image that is 255 inside the rectangle and 0 elsewhere.
    """
    canvas = Image.new(mode="L", size=image.size, color=0)
    ImageDraw.Draw(canvas).rectangle(box, fill=255)
    return canvas
# Function to run detection and render pseudo-masks and bounding boxes
def detect_legos(
    image, use_pseudo_masks=True, *, score_threshold=0.5, mask_alpha=110
):
    """Detect LEGO pieces in ``image`` and return an annotated copy.

    Args:
        image: PIL image to analyse. ``None`` (nothing uploaded) is accepted
            and short-circuits with an explanatory message.
        use_pseudo_masks: When true, tint the area of every kept bounding box
            blue as a stand-in for a segmentation mask.
        score_threshold: Minimum detection score for a box to be kept.
        mask_alpha: Opacity of the blue tint, 0 (invisible) to 255 (opaque).

    Returns:
        A ``(annotated_image, title)`` tuple: an RGB PIL image with the tint
        and yellow box outlines drawn on it, and a summary string containing
        the number of kept detections. ``annotated_image`` is ``None`` when
        no input image was given.
    """
    if image is None:
        return None, "No image provided."

    # convert() always returns a new image, so the caller's image is never
    # drawn on; forcing RGB also keeps RGBA/palette uploads at 3 channels.
    image_with_masks = image.convert("RGB")

    # Run the model on a single-image batch without tracking gradients.
    img_tensor = transform(image_with_masks).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)

    # Keep only the boxes whose score reaches the threshold.
    prediction = outputs[0]
    boxes = prediction["boxes"].cpu().numpy()
    scores = prediction["scores"].cpu().numpy()
    boxes = boxes[scores >= score_threshold]
    num_legos_detected = len(boxes)

    if use_pseudo_masks and num_legos_detected:
        # Merge all box masks into one "L" layer and use THAT as the paste
        # mask. Pasting an opaque RGBA image as its own mask (alpha is 255
        # everywhere) would paint the entire picture blue. A single merged
        # layer also stops overlapping boxes from being tinted twice.
        tint_mask = Image.new("L", image_with_masks.size, 0)
        for box in boxes:
            box_mask = create_pseudo_mask(image_with_masks, box.tolist())
            tint_mask.paste(int(mask_alpha), mask=box_mask)
        image_with_masks.paste("blue", mask=tint_mask)

    # Outlines go on last so they stay visible on top of the tint.
    draw = ImageDraw.Draw(image_with_masks)
    for box in boxes:
        draw.rectangle(box.tolist(), outline="yellow", width=3)

    title = f"Detected LEGO pieces: {num_legos_detected}"
    return image_with_masks, title
# Callback invoked by the Gradio UI for each submitted image
def gradio_interface(image):
    """Run ``detect_legos`` with pseudo-masks on and pass its result through.

    Args:
        image: The image supplied by the Gradio input component.

    Returns:
        The ``(image_with_masks, title)`` pair produced by ``detect_legos``.
    """
    return detect_legos(image, use_pseudo_masks=True)
# Wire the callback into a Gradio UI: one PIL image in, the annotated
# PIL image plus a text summary out.
interface = gr.Interface(
    gradio_interface,
    gr.Image(type="pil"),
    [
        gr.Image(type="pil"),
        gr.Textbox(label="Detection Summary"),
    ],
    title="LEGO Detection with Mask R-CNN",
    description="Upload an image to detect and count LEGO pieces with bounding boxes and simulated masks.",
)

# Start serving the app.
interface.launch()