|
import gradio as gr |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import sqlite3 |
|
import pandas as pd |
|
|
|
|
|
# Detector loaded once, at import time, from the relative weights path "best.pt".
model = YOLO("best.pt")

# Class index -> label name used when decoding model predictions.
# NOTE(review): presumably matches the class order used to train best.pt — confirm.
label_mapping = {0: 'immature', 1: 'mature', 2: 'normal'}

# Label name -> class index. Derived from label_mapping so the two tables
# can never disagree when a class is added or renamed.
inverse_label_mapping = {label: idx for idx, label in label_mapping.items()}
|
|
|
|
|
def _box_color(label):
    """Return the colour tuple used to draw a detection with the given label.

    Values are written into the RGB array built by predict_image, so the
    tuple is interpreted in RGB channel order.
    """
    if label == 'normal':
        return (0, 255, 0)
    if label == 'immature':
        return (0, 255, 255)
    return (255, 0, 0)


def predict_image(input_image, name, patient_id, conf_threshold=0.5):
    """Run the YOLO detector on an image and draw the confident detections.

    Args:
        input_image: Image to analyse (the Gradio component supplies a PIL
            image; anything ``np.array`` accepts works). ``None`` yields a
            prompt message instead of a prediction.
        name: Patient name. Not used by the prediction itself.
        patient_id: Patient identifier. Not used by the prediction itself.
        conf_threshold: Minimum confidence for a detection to be drawn and
            reported. Defaults to 0.5, the value previously hard-coded.

    Returns:
        ``(image, text)`` where ``image`` is a PIL image (RGB) with boxes and
        captions drawn, or ``None`` when no input was given, and ``text``
        lists one reported detection per line, or explains why there is none.
    """
    if input_image is None:
        return None, "Please Input The Image"

    image_np = np.array(input_image)

    # Normalise to 3-channel RGB so drawing and inference share one layout.
    if image_np.ndim == 2 or image_np.shape[2] == 1:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # Ultralytics interprets numpy inputs as OpenCV-style BGR; handing it the
    # RGB array directly would swap the red and blue channels at inference.
    results = model(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

    image_with_boxes = image_np.copy()
    raw_predictions = []

    boxes = results[0].boxes
    if boxes is not None:
        detections = zip(boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist())
        for class_id, confidence, xyxy in detections:
            if confidence < conf_threshold:
                continue

            predicted_class = int(class_id)
            # Fall back to the raw id rather than crash on an unknown class.
            label = label_mapping.get(predicted_class, str(predicted_class))
            xmin, ymin, xmax, ymax = map(int, xyxy)
            color = _box_color(label)

            cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
            # Keep the caption inside the frame for boxes touching the top edge.
            text_origin = (xmin, max(ymin - 10, 20))
            cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            raw_predictions.append(
                f"Label: {label}, Confidence: {confidence:.2f}, "
                f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
            )

    if not raw_predictions:
        # Give the user explicit feedback instead of an empty text box.
        raw_predictions.append(f"No detections with confidence >= {conf_threshold:.2f}")

    return Image.fromarray(image_with_boxes), "\n".join(raw_predictions)
|
|
|
|
|
def interface(name, patient_id, input_image):
    """Handle the Submit button: validate the upload, then run the prediction.

    Args:
        name: Patient name from the "Name" textbox.
        patient_id: Patient identifier from the "Patient ID" textbox.
        input_image: Uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        ``(image, text)``, always two values, because the button is wired
        to two output components (annotated image and raw-result textbox).
    """
    if input_image is None:
        # Returning a lone string here made Gradio fail with an output-count
        # error instead of displaying the message.
        return None, "Please upload an image."

    return predict_image(input_image, name, patient_id)
|
|
|
|
|
with gr.Blocks() as demo:
    # Title and a one-line usage note.
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")

    # Patient details plus the image to analyse.
    with gr.Column():
        name = gr.Textbox(label="Name")
        patient_id = gr.Textbox(label="Patient ID")
        input_image = gr.Image(
            label="Upload an Image",
            type="pil",
            image_mode="RGB",
        )

    # Trigger button and the annotated prediction.
    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(label="Predicted Image", type="pil")

    # Plain-text listing of the reported detections.
    with gr.Row():
        raw_result = gr.Textbox(label="Raw Result", lines=5)

    # The handler receives (name, patient_id, image) and fills both outputs.
    submit_btn.click(
        interface,
        inputs=[name, patient_id, input_image],
        outputs=[output_image, raw_result],
    )

demo.launch()