|
import os
import tempfile
import uuid

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO
|
|
|
class TreeDetectionModel:
    """Wrap a YOLOv8 model for tree detection and read GPS tags from image EXIF data."""

    # EXIF tag id of the GPS information IFD.
    _GPS_IFD_TAG = 34853
    # Tag ids inside the GPS IFD that hold latitude, longitude and altitude.
    _GPS_LATITUDE_TAG = 2
    _GPS_LONGITUDE_TAG = 4
    _GPS_ALTITUDE_TAG = 6

    def __init__(self, model_path):
        """
        Initialize the YOLOv8 model for tree detection.

        Args:
            model_path (str): Path to the trained YOLOv8 model (.pt or .onnx).
        """
        self.model = YOLO(model_path)

    def detect(self, image_path):
        """
        Perform inference on an image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            list[dict]: One entry per detected box, in the order the model
            yields them, with keys 'class' (label looked up in the result's
            ``names``), 'confidence' (float) and 'bbox' (the first row of
            the box's ``xyxy`` coordinates as a list). Empty when nothing
            is detected.
        """
        return [
            {
                'class': result.names[int(box.cls)],
                'confidence': float(box.conf),
                'bbox': box.xyxy.tolist()[0],
            }
            for result in self.model(image_path)
            for box in result.boxes
        ]

    def extract_geotag(self, image_path):
        """
        Extract GPS coordinates from the image's EXIF data.

        The values are returned exactly as stored in the GPS IFD; the
        hemisphere/altitude reference tags are not read or applied.

        Args:
            image_path (str): Path to the input image.

        Returns:
            dict | None: ``{'lat': ..., 'lon': ..., 'alt': ...}`` when at
            least one of the three GPS tags is present (missing ones are
            None), otherwise None.
        """
        # The context manager releases the file handle Image.open keeps open.
        with Image.open(image_path) as img:
            # getexif() is the public accessor defined on the base Image
            # class; the private _getexif() exists only on some format
            # plugins. The GPS tags live in their own IFD, fetched here
            # while the file is still open.
            gps_info = img.getexif().get_ifd(self._GPS_IFD_TAG)

        geotag = {
            'lat': gps_info.get(self._GPS_LATITUDE_TAG),
            'lon': gps_info.get(self._GPS_LONGITUDE_TAG),
            'alt': gps_info.get(self._GPS_ALTITUDE_TAG),
        }
        # An image with EXIF but no GPS data has no geotag: report that as
        # None rather than as a dict of three Nones that callers would treat
        # as a found geotag.
        if all(value is None for value in geotag.values()):
            return None
        return geotag
|
|
|
|
|
|
|
# Location of the trained weights. Set the TREE_MODEL_PATH environment
# variable to run the app on another machine; when it is unset the original
# development path is used, so existing setups keep working.
model_path = os.environ.get(
    "TREE_MODEL_PATH",
    "/home/kalie/work/projects/Community-Tree/ai/yolov8_model/best.pt",
)

# Built once at import time; detect_trees uses this single instance for
# every call.
tree_detector = TreeDetectionModel(model_path)
|
|
|
def detect_trees(image):
    """
    Perform tree detection on the uploaded image and display results.

    Args:
        image (PIL.Image): Input image, or None when nothing was uploaded.

    Returns:
        tuple: (annotated_image, detections_table, geotag_info) where
        annotated_image is a PIL image with boxes and labels drawn on it,
        detections_table is a list of [class, confidence, bbox] rows and
        geotag_info is display text. For a None input the result is
        (None, [], "No image provided.").
    """
    # Triggering the handler without an upload passes None; answer with a
    # message instead of failing inside the array conversion.
    if image is None:
        return None, [], "No image provided."

    # Normalise the mode so grayscale/RGBA/palette inputs can be saved as
    # JPEG and annotated like ordinary RGB images.
    rgb_image = image.convert("RGB")

    # A unique temp file per call: a fixed file name would be overwritten by
    # concurrent requests and was never cleaned up.
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # Save through PIL and pass along the EXIF payload when the image
        # still carries one; cv2.imwrite writes no EXIF, which made the
        # geotag lookup on the temp file always come back empty.
        exif_bytes = image.info.get("exif")
        save_options = {"exif": exif_bytes} if exif_bytes else {}
        # High quality so re-encoding does not noticeably degrade the
        # detector's input.
        rgb_image.save(temp_image_path, format="JPEG", quality=95, **save_options)

        detections = tree_detector.detect(temp_image_path)
        geotag = tree_detector.extract_geotag(temp_image_path)
    finally:
        try:
            os.remove(temp_image_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not turn a
            # successful detection into an error.
            pass

    # np.array() copies the pixel data, so drawing does not touch the
    # caller's image.
    canvas = np.array(rgb_image)
    _draw_detections(canvas, detections)
    annotated_image = Image.fromarray(canvas)

    detections_table = [
        [detection['class'], detection['confidence'], detection['bbox']]
        for detection in detections
    ]

    return annotated_image, detections_table, _format_geotag(geotag)


def _draw_detections(canvas, detections):
    """Draw a green box and a "class (confidence)" label per detection onto canvas, in place."""
    # (0, 255, 0) is green in both RGB and BGR channel order, so the RGB
    # array can be drawn on directly without converting to BGR and back.
    green = (0, 255, 0)
    for detection in detections:
        x1, y1, x2, y2 = map(int, detection['bbox'])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), green, 2)
        label = f"{detection['class']} ({detection['confidence']:.2f})"
        # Clamp the text baseline so the label of a box touching the top
        # edge is still drawn inside the image.
        label_origin = (x1, max(y1 - 10, 25))
        cv2.putText(canvas, label, label_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.9, green, 2)


def _format_geotag(geotag):
    """Render a geotag dict as display text, or a fallback when it holds no coordinate."""
    # A dict whose values are all None carries no location either.
    if not geotag or all(value is None for value in geotag.values()):
        return "No geotag found."
    return f"Latitude: {geotag['lat']}\nLongitude: {geotag['lon']}\nAltitude: {geotag['alt']}"
|
|
|
|
|
# Gradio UI: upload and trigger button on the left, results on the right.
with gr.Blocks() as demo:
    # The title's emoji had been mis-decoded into "π³"; restored to the tree
    # emoji (U+1F333).
    gr.Markdown("# 🌳 Tree Detection App")
    gr.Markdown("Upload an image to detect trees and view results.")

    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand detect_trees a PIL image.
            image_input = gr.Image(label="Upload Image", type="pil")
            submit_button = gr.Button("Detect Trees")
        with gr.Column():
            image_output = gr.Image(label="Annotated Image")
            detections_output = gr.Dataframe(
                headers=["Class", "Confidence", "Bounding Box"],
                label="Detection Results",
            )
            geotag_output = gr.Textbox(label="Geotag Information")

    # The three outputs are filled, in order, from the tuple returned by
    # detect_trees.
    submit_button.click(
        detect_trees,
        inputs=image_input,
        outputs=[image_output, detections_output, geotag_output],
    )

# Starts the web server when this module is executed.
demo.launch()