|
import base64
import sqlite3
import tempfile
from contextlib import closing
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
|
|
|
|
|
# Load the trained YOLO detector once at import time; predict_image() reuses
# this module-level instance for every request.
_WEIGHTS_PATH = "best.pt"
model = YOLO(_WEIGHTS_PATH)
|
|
|
def _class_label_and_color(class_index):
    """Return the (label, RGB colour) pair used to display a YOLO class index."""
    if class_index == 0:
        return "Immature", (0, 255, 255)
    if class_index == 1:
        return "Mature", (255, 0, 0)
    # Any other class index is displayed as "Normal".
    return "Normal", (0, 255, 0)


def _draw_labeled_box(image, box, caption, color, font_scale=1.0, thickness=2):
    """Draw *box* plus a black caption bar holding white *caption* text onto *image* in place."""
    xmin, ymin, xmax, ymax = box
    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)

    (text_width, text_height), baseline = cv2.getTextSize(
        caption, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
    )
    # Sit the caption bar on top of the box, but push it down when the box
    # touches the top edge so the caption is not drawn off-image.
    bar_bottom = max(ymin, text_height + baseline)
    cv2.rectangle(
        image,
        (xmin, bar_bottom - text_height - baseline),
        (xmin + text_width, bar_bottom),
        (0, 0, 0),
        cv2.FILLED,
    )
    # putText's origin is the text baseline, so place it `baseline` px above
    # the bottom of the bar to keep descenders inside the bar.
    cv2.putText(
        image,
        caption,
        (xmin, bar_bottom - baseline),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        (255, 255, 255),
        thickness,
    )


def predict_image(input_image, name, patient_id):
    """Run cataract detection on one image and return an annotated copy.

    Only the highest-confidence detection is drawn. The annotated image is then
    stamped with the patient caption and logo via ``add_text_and_watermark``.

    Args:
        input_image: PIL image or numpy array (grayscale, RGB or RGBA), or None.
        name: Patient name rendered in the caption.
        patient_id: Patient ID rendered in the caption.

    Returns:
        ``(annotated PIL image, raw prediction text)``, or
        ``(None, "Please Input The Image")`` when *input_image* is None.
    """
    if input_image is None:
        return None, "Please Input The Image"

    # Normalise to a 3-channel RGB array.
    image_np = np.array(input_image)
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # Ultralytics interprets numpy inputs as BGR (OpenCV channel order) while
    # this array is RGB, so swap channels for inference only; drawing stays RGB.
    results = model(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

    image_with_boxes = image_np.copy()
    # Placeholders for the no-detection case, so the caption step below always
    # has a label to render.
    label = "No Detection"
    raw_predictions_str = "No detections found."

    boxes = results[0].boxes
    if boxes:
        best = max(boxes, key=lambda box: box.conf.item())
        label, color = _class_label_and_color(best.cls.item())
        confidence = best.conf.item()
        xmin, ymin, xmax, ymax = map(int, best.xyxy[0])

        _draw_labeled_box(
            image_with_boxes,
            (xmin, ymin, xmax, ymax),
            f"{label} {confidence:.2f}",
            color,
        )
        raw_predictions_str = (
            f"Label: {label}, Confidence: {confidence:.2f}, "
            f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )

    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, patient_id, label)
    return pil_image_with_boxes, raw_predictions_str
|
|
|
|
|
def add_watermark(image, logo_path='image-logo.png', logo_width=100, margin=10):
    """Paste a logo into the bottom-right corner of *image*.

    The logo is scaled to *logo_width* pixels wide (aspect ratio preserved) and
    composited using its own alpha channel as the mask.

    Args:
        image: PIL image to watermark; it is not modified.
        logo_path: Path to the logo image file.
        logo_width: Target logo width in pixels.
        margin: Gap in pixels between the logo and the right/bottom edges.

    Returns:
        A new RGB image with the logo applied. This is best-effort: on any
        error the message is printed and the caller's *image* is returned
        unchanged (same object, same mode).
    """
    try:
        logo = Image.open(logo_path).convert("RGBA")
        scale = logo_width / float(logo.width)
        # Clamp to 1 px so a very wide logo doesn't produce an invalid 0-height resize.
        logo_height = max(1, int(scale * logo.height))
        logo = logo.resize((logo_width, logo_height), Image.LANCZOS)

        # convert() returns a new image, so the caller's image is never modified.
        canvas = image.convert("RGBA")
        position = (canvas.width - logo.width - margin, canvas.height - logo.height - margin)
        # Using the logo as its own mask keeps its transparent areas see-through.
        canvas.paste(logo, position, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:
        # Deliberate best-effort boundary: a missing/broken logo must not
        # prevent the prediction from being returned.
        print(f"Error adding watermark: {e}")
        return image
|
|
|
|
|
def add_text_and_watermark(image, name, patient_id, label, font_path="font.ttf", font_size=48):
    """Stamp a patient caption onto a copy of *image*, then add the logo watermark.

    The caption reads ``"Name: <name>, ID: <patient_id>, Result: <label>"`` in
    white on a padded black rectangle near the top-left corner.

    Args:
        image: PIL image to annotate; it is not modified.
        name: Patient name.
        patient_id: Patient identifier.
        label: Prediction label to show.
        font_path: TrueType font file; falls back to Pillow's default font if it
            cannot be opened.
        font_size: Font size used with *font_path*.

    Returns:
        The annotated image as returned by ``add_watermark``.
    """
    # Work on a copy so the caller's image is left untouched.
    image = image.copy()
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    text = f"Name: {name}, ID: {patient_id}, Result: {label}"
    text_x, text_y, padding = 20, 40, 10

    # Measure at the real draw position: the bbox's top edge can sit below
    # text_y (the anchor is the ascender line), so measuring at (0, 0) and
    # shifting would misplace the background box relative to the glyphs.
    left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
    draw.rectangle(
        [left - padding, top - padding, right + padding, bottom + padding],
        fill="black",
    )
    draw.text((text_x, text_y), text, fill="white", font=font)

    return add_watermark(image)
|
|
|
|
|
def init_db(db_path='results.db'):
    """Create the ``results`` table in the SQLite database at *db_path* if absent.

    Columns: id (INTEGER PRIMARY KEY), name, patient_id (TEXT), input_image and
    predicted_image (BLOB), result (TEXT). Safe to call repeatedly.
    """
    # closing() guarantees the connection is released even if execute() raises;
    # the inner `with conn` commits on success and rolls back on error.
    with closing(sqlite3.connect(db_path)) as conn:
        with conn:
            conn.execute(
                '''CREATE TABLE IF NOT EXISTS results
                   (id INTEGER PRIMARY KEY, name TEXT, patient_id TEXT,
                    input_image BLOB, predicted_image BLOB, result TEXT)'''
            )
|
|
|
|
|
def _to_png_bytes(image):
    """Encode a PIL image or RGB/RGBA/grayscale uint8 array as PNG bytes."""
    array = np.asarray(image)
    if array.ndim == 3:
        # OpenCV expects BGR; for 4-channel input this also drops alpha.
        array = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.png', array)
    if not ok:
        raise ValueError("Failed to encode image as PNG.")
    return buffer.tobytes()


def submit_result(name, patient_id, input_image, predicted_image, result, db_path='results.db'):
    """Insert one prediction record into the ``results`` table.

    Both images are stored as PNG-encoded BLOBs; a ``None`` predicted image is
    stored as NULL.

    Args:
        name: Patient name.
        patient_id: Patient identifier.
        input_image: Uploaded image (PIL or RGB/RGBA/grayscale array).
        predicted_image: Annotated image, or None.
        result: Raw prediction text.
        db_path: SQLite database file (defaults to the app database).

    Returns:
        A status message for the UI.

    Raises:
        ValueError: if an image cannot be PNG-encoded.
    """
    # Encode before opening the DB so an encoding failure can't leave a
    # connection open.
    input_image_bytes = _to_png_bytes(input_image)
    predicted_image_bytes = None if predicted_image is None else _to_png_bytes(predicted_image)

    with closing(sqlite3.connect(db_path)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute(
                "INSERT INTO results (name, patient_id, input_image, predicted_image, result) "
                "VALUES (?, ?, ?, ?, ?)",
                (name, patient_id, input_image_bytes, predicted_image_bytes, result),
            )
    return "Result submitted to database."
|
|
|
|
|
def view_database(db_path='results.db'):
    """Return all stored records as a DataFrame, ordered by insertion id.

    Columns: "Name", "Patient ID", "Input Image", "Predicted Image"; the image
    columns hold the raw PNG BLOB bytes as stored.
    """
    # closing() releases the connection even if the query fails (e.g. missing table).
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT name, patient_id, input_image, predicted_image FROM results ORDER BY id"
        ).fetchall()

    return pd.DataFrame(rows, columns=["Name", "Patient ID", "Input Image", "Predicted Image"])
|
|
|
|
|
def download_file(choice, db_path='results.db'):
    """Return a filesystem path for the file the user asked to download.

    Args:
        choice: ``"Database (.db)"`` for the SQLite file; any other value means
            the most recently stored predicted image.
        db_path: SQLite database file (defaults to the app database).

    Returns:
        *db_path* itself, or the path of a new ``.png`` temp file holding the
        latest predicted image.

    Raises:
        FileNotFoundError: if no predicted image is stored.
    """
    if choice == "Database (.db)":
        return db_path

    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            "SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1"
        ).fetchone()

    if row is None or row[0] is None:
        raise FileNotFoundError("No images found in the database.")

    # delete=False: the file must outlive this function so it can be served;
    # it is not cleaned up here.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
        temp_file.write(row[0])
    return temp_file.name
|
|
|
|
|
# Ensure the `results` table exists (CREATE TABLE IF NOT EXISTS) as soon as the
# module is loaded, before any of the UI callbacks below read or write it.
init_db()
|
|
|
|
|
def interface(name, patient_id, input_image):
    """Gradio callback for the Submit button: predict, then persist the result.

    Returns:
        ``(annotated image, raw prediction text, submission status)``. This is
        always a 3-tuple, because the click handler maps the return value onto
        three output components.
    """
    if input_image is None:
        # Return one value per output component; a bare string here would not
        # match the three outputs.
        return None, "", "Please upload an image."

    output_image, raw_result = predict_image(input_image, name, patient_id)
    submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)
    return output_image, raw_result, submit_status
|
|
|
|
|
def view_db_interface():
    """Gradio callback for "View Database": return every stored record as a DataFrame."""
    return view_database()
|
|
|
|
|
def download_interface(choice):
    """Gradio callback for the Download button.

    Returns:
        The path of the file to serve. The click handler binds a single
        ``gr.File`` output, which takes a file path.

    Raises:
        gr.Error: shown in the UI when nothing was selected or no predicted
            image is stored yet.
    """
    if choice is None:
        raise gr.Error("Please choose a file to download.")
    try:
        return download_file(choice)
    except FileNotFoundError as e:
        raise gr.Error(str(e)) from e
|
|
|
|
|
# Build the Gradio UI. Components appear in the order they are created inside
# each `with` context, so statement order here defines the on-screen layout.
with gr.Blocks() as demo:
    # Header: title, usage hint, and a static image of the model's PR curve
    # loaded from PR_curve.png.
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Image("PR_curve.png", label="Model PR Curve")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")

    # Patient details and the image to analyse. type="pil" delivers a PIL
    # image to the callback; image_mode="RGB" requests RGB mode.
    with gr.Column():
        name = gr.Textbox(label="Name")
        patient_id = gr.Textbox(label="Patient ID")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")

    # Prediction output: annotated image plus text results side by side.
    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")

        with gr.Row():
            raw_result = gr.Textbox(label="Raw Result", lines=5)
            submit_status = gr.Textbox(label="Submission Status")

    # `interface` receives the three inputs in this order; its return values
    # are assigned, in order, to the three listed outputs.
    submit_btn.click(fn=interface, inputs=[name, patient_id, input_image], outputs=[output_image, raw_result, submit_status])

    # Database browser: shows the rows returned by view_db_interface().
    with gr.Column():
        view_db_btn = gr.Button("View Database")
        db_output = gr.Dataframe(label="Database Records")

    view_db_btn.click(fn=view_db_interface, inputs=[], outputs=[db_output])

    # Download section: the radio value is passed to download_interface as `choice`.
    with gr.Column():
        download_choice = gr.Radio(["Database (.db)", "Predicted Image (.png)"], label="Choose the file to download:")
        download_btn = gr.Button("Download")
        download_output = gr.File(label="Download File")

    download_btn.click(fn=download_interface, inputs=[download_choice], outputs=[download_output])

# Start the web server. This runs at module level (no __main__ guard), so it
# also runs if the module is imported.
demo.launch()