import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    VitPoseForPoseEstimation,
    VitPoseImageProcessor,
    pipeline,
)

# COCO-style keypoint indices produced by ViTPose.
KEYPOINT_LABEL_MAP = {
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
}


class InteractionDetector:
    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.depth_model = None
        self.segmentation_model = None
        # Maximum depth difference (in depth-map units) for a hand to count as touching.
        self.interaction_threshold = 2
        self.load_models()

    def load_models(self):
        """Load all required models."""
        # Person detection model
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")

        # Pose estimation model
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")

        # Depth estimation model
        self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

        # Semantic segmentation model (ADE20K label set)
        self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
        self.segmentation_id2label = self.segmentation_model.model.config.id2label
        self.segmentation_label2id = {v: k for k, v in self.segmentation_model.model.config.id2label.items()}

    def get_nearest_pixel_class(self, joint, depth_map, segmentation_map):
        """
        Find the non-person pixel closest in depth to a given joint coordinate,
        searching within a 50-pixel Manhattan radius around the joint.

        Args:
            joint: (x, y) coordinates of the joint
            depth_map: depth map of the full image
            segmentation_map: semantic segmentation map of the full image

        Returns:
            tuple: (class id of the nearest pixel, depth distance to that pixel)
        """
        PERSON_ID = 12  # "person" class id in the ADE20K label map
        grid_x, grid_y = np.meshgrid(np.arange(depth_map.shape[0]), np.arange(depth_map.shape[1]))
        dist_x = np.abs(grid_x.T - joint[1])
        dist_y = np.abs(grid_y.T - joint[0])
        dist_coord = dist_x + dist_y

        # Cast to a signed type so the uint8 depth differences do not wrap around.
        depth_map = depth_map.astype(np.int16)
        depth_dist = np.abs(depth_map - depth_map[joint[1], joint[0]])
        # Exclude person pixels and pixels farther than 50 px (Manhattan) from the joint.
        depth_dist[(segmentation_map == PERSON_ID) | (dist_coord > 50)] = 255
        min_dist = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
        return segmentation_map[min_dist], depth_dist[min_dist]

    def detect_persons(self, image: Image.Image):
        """Detect persons in the image."""
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)

        results = self.person_processor.post_process_object_detection(
            outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
        )
        # Keep only "person" detections (label 0 in the COCO label map).
        boxes = results[0]["boxes"][results[0]["labels"] == 0]
        scores = results[0]["scores"][results[0]["labels"] == 0]
        return boxes.cpu().numpy(), scores.cpu().numpy()

    def detect_keypoints(self, image: Image.Image):
        """Detect keypoints for every detected person."""
        boxes, scores = self.detect_persons(image)

        pixel_values = self.pose_processor(image, boxes=[boxes], return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)

        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[boxes])[0]
        return pose_results, boxes, scores

    def estimate_depth(self, image: Image.Image):
        """Estimate a depth map for the image."""
        with torch.no_grad():
            depth_map = np.array(self.depth_model(image)["depth"])
        return depth_map

    def segment_image(self, image: Image.Image):
        """Perform semantic segmentation on the image and return a class-id map."""
        with torch.no_grad():
            segmentation_map = self.segmentation_model(image)

        result = np.zeros(np.array(image).shape[:2], dtype=np.uint8)
        print("Found", [l["label"] for l in segmentation_map])
        # Paint larger masks first so smaller objects are not overwritten by bigger ones.
        for cls_item in sorted(segmentation_map, key=lambda l: np.sum(l["mask"]), reverse=True):
            result[np.array(cls_item["mask"]) > 0] = self.segmentation_label2id[cls_item["label"]]
        return result

    def detect_wall_interaction(self, image: Image.Image):
        """Detect whether each person's hands are touching a nearby (non-person) object."""
        # Gather all the per-pixel information we need.
        pose_results, boxes, scores = self.detect_keypoints(image)
        depth_map = self.estimate_depth(image)
        segmentation_map = self.segment_image(image)

        interactions = []
        for person_idx, pose_result in enumerate(pose_results):
            # Get hand (wrist) keypoints.
            right_hand = pose_result["keypoints"][10].numpy().astype(int)
            left_hand = pose_result["keypoints"][9].numpy().astype(int)

            # Find the nearest non-person pixel to each hand.
            right_cls, r_distance = self.get_nearest_pixel_class(right_hand[:2], depth_map, segmentation_map)
            left_cls, l_distance = self.get_nearest_pixel_class(left_hand[:2], depth_map, segmentation_map)

            # Check for interactions.
            right_touching = r_distance < self.interaction_threshold
            left_touching = l_distance < self.interaction_threshold

            interactions.append({
                "person_id": person_idx,
                "right_hand_touching_object": self.segmentation_id2label[right_cls],
                "left_hand_touching_object": self.segmentation_id2label[left_cls],
                "right_hand_touching": right_touching,
                "left_hand_touching": left_touching,
                "right_hand_distance": r_distance,
                "left_hand_distance": l_distance,
            })

        return interactions, pose_results, segmentation_map, depth_map

    def visualize_results(self, image: Image.Image, interactions, pose_results):
        """Visualize detection results."""
        # Create the base visualization from the original image.
        vis_image = np.array(image).copy()

        # Draw the pose skeletons.
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
        key_points = sv.KeyPoints(
            xy=torch.cat([pose_result["keypoints"].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
        )
        vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)

        # Mark hands that are touching an object.
        for interaction in interactions:
            person_id = interaction["person_id"]
            pose_result = pose_results[person_id]

            if interaction["right_hand_touching"]:
                cv2.circle(vis_image, tuple(map(int, pose_result["keypoints"][10][:2])), 10, (0, 0, 255), -1)
            if interaction["left_hand_touching"]:
                cv2.circle(vis_image, tuple(map(int, pose_result["keypoints"][9][:2])), 10, (0, 0, 255), -1)

        return Image.fromarray(vis_image)

    def process_image(self, input_image):
        """Process an image and return the visualization, segmentation, depth map, and report."""
        if input_image is None:
            # Match the number of Gradio outputs.
            return None, None, None, ""

        # Convert to PIL Image if necessary.
        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image
        # Work at a fixed resolution.
        image = image.resize((1280, 720))

        # Detect interactions.
        interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)

        # Visualize results.
        result_image = self.visualize_results(image, interactions, pose_results)

        # Build the interaction report.
        info_text = []
        for interaction in interactions:
            info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
            if interaction["right_hand_touching"]:
                info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
            if interaction["left_hand_touching"]:
                info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
            info_text.append(f"Right hand distance to nearest object: {interaction['right_hand_distance']:.2f}")
            info_text.append(f"Left hand distance to nearest object: {interaction['left_hand_distance']:.2f}")

        # Colorize the segmentation map for display.
        mask = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)
        colors = np.random.randint(0, 255, size=(100, 3))
        for cl_id in np.unique(segmentation_map):
            mask_array = np.array(segmentation_map == cl_id)
            color = colors[cl_id % len(colors)]
            mask[mask_array] = color

        return result_image, mask, depth_map, "\n".join(info_text)
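
# --- Illustrative sketch (not used by the app) ---
# get_nearest_pixel_class() treats a hand as "touching" something when a
# non-person pixel within a 50 px Manhattan radius has roughly the same depth
# as the hand pixel. On tiny synthetic arrays the heuristic reduces to:
#
#     detector = InteractionDetector()                # note: loads all models
#     depth = np.zeros((4, 4), dtype=np.uint8)        # flat depth plane
#     seg = np.zeros((4, 4), dtype=np.uint8)          # class 0 everywhere ("wall" in ADE20K)
#     seg[1, 1] = 12                                  # mark the joint pixel itself as "person"
#     cls_id, dist = detector.get_nearest_pixel_class((1, 1), depth, seg)
#     # dist == 0 < interaction_threshold, so the hand would count as touching class 0
#
# The synthetic arrays above are assumptions for illustration only; the real
# inputs come from the depth and segmentation pipelines.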


def create_gradio_interface():
    """Create the Gradio interface."""
    detector = InteractionDetector()

    with gr.Blocks() as interface:
        gr.Markdown("# Object Interaction Detection")
        gr.Markdown("Upload an image to detect when people are touching objects.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect Interactions")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                interaction_info = gr.Textbox(
                    label="Interaction Information",
                    lines=10,
                    placeholder="Interaction details will appear here..."
                )
                segmentation_im = gr.Image(label="Segmentation Results")
                depth_im = gr.Image(label="Depth Results")

        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, segmentation_im, depth_im, interaction_info]
        )

        gr.Examples(
            examples=[
                "images/1-8ea4418f.jpg",
                "images/276757975.jpg"
            ],
            inputs=input_image
        )

    return interface


interface = create_gradio_interface()

if __name__ == "__main__":
    interface.launch(debug=True)
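
# --- Example: running the detector without the Gradio UI (illustrative sketch) ---
# The image path below is hypothetical; process_image() returns the annotated
# image, the colorized segmentation mask, the depth map, and a text report.
#
#     detector = InteractionDetector()
#     image = Image.open("images/example.jpg")
#     result_image, seg_mask, depth, report = detector.process_image(image)
#     print(report)
#     result_image.save("result.jpg")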