"""Gradio app that counts leaves in an image with a YOLO leaf detector.

Pipeline: the input is normalised to 3-channel uint8 RGB, contrast-enhanced
with CLAHE on the LAB lightness channel, and run through the
``foduucom/plant-leaf-detection-and-classification`` model. The detections
are then post-processed:

1. Tiny boxes are dropped.
2. A strict second NMS pass merges near-duplicates.
3. The survivors are drawn onto the enhanced image and reported as the count.
"""

import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
from ultralyticsplus import YOLO, render_result

# System Configuration
print("\n" + "=" * 40)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
print("=" * 40 + "\n")

# Post-processing thresholds, applied after the model's own NMS.
MIN_BOX_SIDE = 20       # px; boxes this narrow/short or smaller are discarded
POST_CONF_THRES = 0.1   # score floor for the second NMS pass
POST_IOU_THRES = 0.15   # boxes overlapping a better one above this IoU are merged

# Candidate example images; only those present on disk are offered in the UI.
SAMPLE_IMAGES = ["sample_leaf1.jpg", "sample_leaf2.jpg"]

# Load model with optimized parameters for leaf counting
model = YOLO('foduucom/plant-leaf-detection-and-classification')

# Custom configuration for leaf counting
model.overrides.update({
    'conf': 0.15,          # Lower confidence threshold for better recall
    'iou': 0.25,           # Lower IoU threshold for overlapping leaves
    'imgsz': 1280,         # Higher resolution for small leaves
    'agnostic_nms': False,
    'max_det': 300,        # Higher maximum detections
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'classes': None,       # Detect all classes
    'half': torch.cuda.is_available(),
})


def _to_rgb_uint8(image):
    """Return *image* as a contiguous H x W x 3 uint8 RGB array.

    Grayscale (2-D or single-channel) and RGBA inputs are converted so that
    the LAB conversion in the contrast step always receives 3 channels.
    """
    arr = np.asarray(image)
    if arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 2 or arr.shape[2] == 1:
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    elif arr.shape[2] == 4:
        arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
    return np.ascontiguousarray(arr)


def _enhance_contrast(rgb):
    """Apply CLAHE to the L channel of *rgb* (in LAB space); return RGB."""
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    merged = cv2.merge((clahe.apply(l), a, b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)


def _nms(boxes, scores, iou_thres):
    """Greedy class-agnostic non-maximum suppression.

    Args:
        boxes: ``(N, 4)`` array of ``x1, y1, x2, y2`` boxes.
        scores: ``(N,)`` array of confidences.
        iou_thres: boxes whose IoU with a higher-scoring kept box exceeds
            this value are suppressed.

    Returns:
        Integer index array of kept boxes, highest score first (empty when
        ``N == 0``).
    """
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(best)
        xx1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[best] + areas[rest] - inter + 1e-9)
        order = rest[iou <= iou_thres]
    return np.asarray(keep, dtype=int)


def count_leaves(image):
    """Detect and count leaves in *image*.

    Args:
        image: RGB image as delivered by ``gr.Image`` (numpy array), or
            ``None`` when the user submits without an image.

    Returns:
        ``(debug_img, num_leaves)``: the contrast-enhanced image with green
        boxes drawn around each counted leaf, and the count. On any failure
        the input image and ``0`` are returned and the error is printed.
    """
    if image is None:
        return None, 0
    try:
        start_time = time.time()

        # Preprocessing - enhance contrast
        enhanced_img = _enhance_contrast(_to_rgb_uint8(image))

        # Ultralytics treats numpy sources as BGR, while Gradio supplies
        # RGB, so swap channels before inference.
        results = model.predict(
            source=cv2.cvtColor(enhanced_img, cv2.COLOR_RGB2BGR),
            augment=True,  # Test time augmentation
            verbose=False,
            agnostic_nms=False,
            overlap_mask=False,
        )

        boxes = results[0].boxes
        xyxy = boxes.xyxy.cpu().numpy().reshape(-1, 4)
        scores = boxes.conf.cpu().numpy().reshape(-1)

        # Drop boxes that are too small (adjust MIN_BOX_SIDE to leaf size)
        # or below the post-processing score floor.
        widths = xyxy[:, 2] - xyxy[:, 0]
        heights = xyxy[:, 3] - xyxy[:, 1]
        mask = (
            (widths > MIN_BOX_SIDE)
            & (heights > MIN_BOX_SIDE)
            & (scores >= POST_CONF_THRES)
        )
        xyxy, scores = xyxy[mask], scores[mask]

        # Stricter, class-agnostic NMS so that one physical leaf carrying
        # several near-duplicate boxes is counted once.
        final_boxes = xyxy[_nms(xyxy, scores, POST_IOU_THRES)]
        num_leaves = len(final_boxes)

        # Visual validation
        debug_img = enhanced_img.copy()
        for box in final_boxes:
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        print(f"Processing time: {time.time()-start_time:.2f}s")
        return debug_img, num_leaves
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing
        # the Gradio worker.
        print(f"Error: {str(e)}")
        return image, 0


# Gradio interface with visualization
interface = gr.Interface(
    fn=count_leaves,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Detection Visualization"),
        gr.Number(label="Estimated Leaf Count"),
    ],
    title="🍃 Advanced Leaf Counter",
    description="Specialized for overlapping leaves and dense foliage",
    # Gradio validates example paths, so only list files that exist.
    examples=[[p] for p in SAMPLE_IMAGES if os.path.isfile(p)] or None,
)

if __name__ == "__main__":
    interface.launch(
        server_port=7860,
        share=False,
    )