|
import base64
import html
import os
import zipfile
from io import BytesIO
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
|
|
|
|
|
# YOLO model loaded from the local "best.pt" weights file; predict_image runs it.
model = YOLO("best.pt")

# Storage folders: raw uploads and annotated predictions (created if missing).
uploaded_folder, predicted_folder = Path('Uploaded_Picture'), Path('Predicted_Picture')
for _output_dir in (uploaded_folder, predicted_folder):
    _output_dir.mkdir(parents=True, exist_ok=True)
del _output_dir
|
|
|
|
|
# Single-file patient "database": an HTML document whose table receives one
# row per prediction (rows are written by append_patient_info_to_html).
html_db_file = Path('patient_predictions.html')

# Create the document once with its header and an open <tbody>; the closing
# tags are written by the append step. UTF-8 is both declared in the document
# and used for writing, so non-ASCII patient names survive on any platform
# (the previous code used the locale's default encoding and declared none).
if not html_db_file.exists():
    html_db_file.write_text(
        """<html>
<head><meta charset="utf-8"><title>Patient Prediction Database</title></head>
<body>
<h1>Patient Prediction Database</h1>
<table border="1" style="width:100%; border-collapse: collapse; text-align: center;">
<thead>
<tr>
<th>Name</th>
<th>Age</th>
<th>Medical Record</th>
<th>Sex</th>
<th>Result</th>
<th>Predicted Image</th>
</tr>
</thead>
<tbody>
""",
        encoding="utf-8",
    )
|
|
|
def _safe_filename_part(value):
    """Return ``str(value)`` with every character that is not a letter, digit,
    ``-``, ``_`` or ``.`` replaced by ``_``.

    Patient fields come straight from the web form; without this, input such
    as ``../../x`` or names containing ``/`` or ``:`` could escape the storage
    folders or produce an invalid file name.
    """
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))


def predict_image(input_image, name, age, medical_record, sex):
    """Detect cataract in an uploaded image and record the result.

    The single most confident detection (if any) is drawn on a copy of the
    image. The upload and the annotated image are saved under the same file
    name in ``uploaded_folder`` / ``predicted_folder``, and a row is appended
    to the HTML patient database.

    Args:
        input_image: PIL image from ``gr.Image(type="pil")``, or ``None``
            when nothing was uploaded.
        name: Patient name (free text).
        age: Patient age.
        medical_record: Medical record number.
        sex: Patient sex.

    Returns:
        ``(annotated_pil_image, result_text)``, or
        ``(None, message)`` when no image was provided.
    """
    if input_image is None:
        return None, "Please upload an image for prediction."

    # Guarantee a 3-channel RGB image regardless of the uploaded file's mode.
    rgb_image = input_image.convert("RGB")

    # Pass the PIL image itself: Ultralytics treats PIL input as RGB but plain
    # NumPy arrays as BGR (OpenCV order), so feeding np.array(rgb_image) would
    # swap the red and blue channels the model sees.
    results = model(rgb_image)

    # Writable RGB array to draw on; the colour triples below are therefore RGB.
    image_with_boxes = np.array(rgb_image)
    label = "Unknown"

    boxes = results[0].boxes
    if boxes is not None and len(boxes) > 0:
        # Only the single most confident detection is reported and drawn.
        best_result = max(boxes, key=lambda box: box.conf.item())
        class_index = int(best_result.cls.item())

        if class_index == 0:
            label, color = "Immature", (0, 255, 255)
        elif class_index == 1:
            label, color = "Mature", (255, 0, 0)
        else:
            label, color = "Normal", (0, 255, 0)

        confidence = best_result.conf.item()
        xmin, ymin, xmax, ymax = map(int, best_result.xyxy[0])

        cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
        font_scale, thickness = 1.0, 2
        # Keep the caption on-canvas when the box touches the top edge.
        text_origin = (xmin, max(ymin - 10, 25))
        cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', text_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)

    pil_image_with_boxes = Image.fromarray(image_with_boxes)

    # Same (sanitized) file name in both folders so an upload and its
    # prediction can be paired.
    image_name = "_".join(
        _safe_filename_part(part) for part in (name, age, medical_record, sex)
    ) + ".png"
    rgb_image.save(uploaded_folder / image_name)
    pil_image_with_boxes.save(predicted_folder / image_name)

    # Embed the annotated image in the HTML database as base64-encoded PNG.
    buffered = BytesIO()
    pil_image_with_boxes.save(buffered, format="PNG")
    predicted_image_base64 = base64.b64encode(buffered.getvalue()).decode("ascii")

    append_patient_info_to_html(name, age, medical_record, sex, label, predicted_image_base64)

    raw_prediction = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"
    return pil_image_with_boxes, raw_prediction
|
|
|
def append_patient_info_to_html(name, age, medical_record, sex, result, predicted_image_base64, db_file=None):
    """Add one patient row to the HTML prediction database.

    The file is kept as a single well-formed document. If it already ends
    with the closing ``</tbody></table></body></html>`` tags, those are cut
    off, the new row is written, and the closing tags are written again. This
    keeps every row inside the table (previously each call appended its own
    closing tags, leaving later rows after ``</html>``).

    Args:
        name: Patient name (user input; HTML-escaped).
        age: Patient age (HTML-escaped).
        medical_record: Medical record number (HTML-escaped).
        sex: Patient sex (HTML-escaped).
        result: Prediction label (HTML-escaped).
        predicted_image_base64: Base64-encoded PNG of the annotated image,
            embedded as a ``data:`` URI.
        db_file: Optional path of the HTML database; defaults to the
            module-level ``html_db_file``.

    Returns:
        The database path as a string.
    """
    db_path = Path(html_db_file if db_file is None else db_file)
    footer = b"</tbody></table></body></html>"

    # Escape every text cell: the values come from the form and would
    # otherwise be interpreted as markup when the file is opened.
    cells = "".join(
        f"<td>{html.escape(str(value))}</td>\n"
        for value in (name, age, medical_record, sex, result)
    )
    row = (
        "<tr>\n"
        + cells
        + f'<td><img src="data:image/png;base64,{predicted_image_base64}" alt="Predicted Image" width="150"></td>\n'
        + "</tr>\n"
    )

    db_path.touch(exist_ok=True)
    with open(db_path, "r+b") as f:
        end = f.seek(0, os.SEEK_END)
        # Drop the closing tags left by the previous append, if present, so
        # the new row lands inside <tbody>; only the tail is read, never the
        # whole (image-heavy) file.
        if end >= len(footer):
            f.seek(end - len(footer))
            if f.read() == footer:
                end -= len(footer)
        f.seek(end)
        f.truncate()
        f.write(row.encode("utf-8") + footer)

    return str(db_path)
|
|
|
def _zip_folder(folder, zip_path):
    """Zip every regular file directly inside *folder* into *zip_path*.

    Entries are stored under their bare file names (no directory prefix),
    in sorted order so the archive layout is deterministic. Subdirectories
    are skipped. Returns *zip_path*.
    """
    with zipfile.ZipFile(zip_path, 'w') as archive:
        for file in sorted(folder.iterdir()):
            if file.is_file():
                archive.write(file, arcname=file.name)
    return zip_path


def download_uploaded_folder():
    """Zip all saved uploads into ``uploaded_images.zip`` and return that path."""
    return _zip_folder(uploaded_folder, 'uploaded_images.zip')


def download_predicted_folder():
    """Zip all annotated predictions into ``predicted_images.zip`` and return that path."""
    return _zip_folder(predicted_folder, 'predicted_images.zip')
|
|
|
|
|
def download_patient_info_html():
    """Return the HTML patient database path for the download button.

    predict_image already appends a row on every submission, so this only
    hands the existing file to the ``gr.File`` output. (The previous handler
    called append_patient_info_to_html with too few arguments, which raised a
    TypeError; had it succeeded, it would have added a duplicate row.)
    """
    return str(html_db_file)


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")

    with gr.Column():
        name = gr.Textbox(label="Name")
        # precision=0 makes Gradio return ints, so records and file names
        # show e.g. "45" rather than "45.0".
        age = gr.Number(label="Age", precision=0)
        medical_record = gr.Number(label="Medical Record", precision=0)
        sex = gr.Radio(["Male", "Female"], label="Sex")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")

    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")

    with gr.Row():
        raw_result = gr.Textbox(label="Prediction Result")

    with gr.Row():
        download_html_btn = gr.Button("Download Patient Information (HTML)")
        download_uploaded_btn = gr.Button("Download Uploaded Images")
        download_predicted_btn = gr.Button("Download Predicted Images")

    patient_info_file = gr.File(label="Patient Information HTML File")
    uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
    predicted_folder_file = gr.File(label="Predicted Images Zip File")

    submit_btn.click(fn=predict_image, inputs=[name, age, medical_record, sex, input_image], outputs=[output_image, raw_result])
    download_html_btn.click(fn=download_patient_info_html, outputs=patient_info_file)
    download_uploaded_btn.click(fn=download_uploaded_folder, outputs=uploaded_folder_file)
    download_predicted_btn.click(fn=download_predicted_folder, outputs=predicted_folder_file)

demo.launch()