Spaces:
Sleeping
Sleeping
import base64
import html
import os
import re
import zipfile
from io import BytesIO
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
# Load the trained YOLOv8 detector once at import time; every request reuses it.
# "best.pt" is resolved relative to the current working directory.
model = YOLO("best.pt")

# Folders for the original uploads and for the annotated prediction images.
uploaded_folder = Path('Uploaded_Picture')
predicted_folder = Path('Predicted_Picture')
uploaded_folder.mkdir(parents=True, exist_ok=True)
predicted_folder.mkdir(parents=True, exist_ok=True)

# HTML file used as a simple, human-readable patient prediction "database".
html_db_file = Path('patient_predictions.html')

# Create the database on first run. The document is written complete, closing
# tags included, so it is valid even before any patient row is added. Rows are
# later inserted in front of the trailing "</tbody></table></body></html>".
# UTF-8 is explicit (file encoding + meta charset) because names are free text.
if not html_db_file.exists():
    html_db_file.write_text(
        """<html>
<head><meta charset="utf-8"><title>Patient Prediction Database</title></head>
<body>
<h1>Patient Prediction Database</h1>
<table border="1" style="width:100%; border-collapse: collapse; text-align: center;">
<thead>
<tr>
<th>Name</th>
<th>Age</th>
<th>Medical Record</th>
<th>Sex</th>
<th>Result</th>
<th>Predicted Image</th>
</tr>
</thead>
<tbody>
</tbody></table></body></html>
""",
        encoding="utf-8",
    )
def _safe_filename_part(value):
    """Return *value* as text that is safe to embed in a single file name.

    Each run of characters that is not a word character, '.' or '-' becomes
    one '_'. This covers path separators, so user-entered patient details
    cannot redirect the saved images outside their folders.
    """
    return re.sub(r"[^\w.-]+", "_", str(value))


def predict_image(input_image, name, age, medical_record, sex):
    """Run the YOLO detector on one image, save the results, and log the patient.

    Only the highest-confidence detection is drawn and reported.
    - Class index 0 is labelled "Immature", 1 "Mature", anything else "Normal".
    - When nothing is detected the label is "Unknown".

    Side effects:
    - Saves the original and the annotated image (same file name) into
      ``uploaded_folder`` and ``predicted_folder``.
    - Appends a row to the HTML database via ``append_patient_info_to_html``.

    Args:
        input_image: PIL image from the Gradio upload, or None.
        name, age, medical_record, sex: Patient details from the form.

    Returns:
        ``(annotated PIL image, summary string)``. If no image was uploaded,
        returns ``(None, message)`` instead.
    """
    if input_image is None:
        return None, "Please upload an image for prediction."

    image_np = np.array(input_image)
    results = model(image_np)
    image_with_boxes = image_np.copy()
    label = "Unknown"

    boxes = results[0].boxes
    if boxes:
        best_result = max(boxes, key=lambda box: box.conf.item())
        # The model reports the class as a float tensor; normalise it to int.
        class_index = int(best_result.cls.item())
        # Colour tuples follow the array's channel order, which is RGB here
        # because the array was built from a PIL image.
        if class_index == 0:
            label, color = "Immature", (0, 255, 255)
        elif class_index == 1:
            label, color = "Mature", (255, 0, 0)
        else:
            label, color = "Normal", (0, 255, 0)
        confidence = best_result.conf.item()
        xmin, ymin, xmax, ymax = map(int, best_result.xyxy[0])

        cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
        font_scale, thickness = 1.0, 2
        # putText anchors at the text baseline. Drawing at ymin - 10 would put
        # the caption off-image for boxes touching the top edge, so clamp it.
        text_y = max(ymin - 10, 25)
        cv2.putText(
            image_with_boxes,
            f'{label} {confidence:.2f}',
            (xmin, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            thickness,
        )

    pil_image_with_boxes = Image.fromarray(image_with_boxes)

    # Build the file name from sanitised parts; raw user input must never
    # reach the filesystem path.
    image_name = "_".join(
        _safe_filename_part(part) for part in (name, age, medical_record, sex)
    ) + ".png"
    input_image.save(uploaded_folder / image_name)
    pil_image_with_boxes.save(predicted_folder / image_name)

    # Embed the annotated image in the HTML database as a base64 data URI.
    buffered = BytesIO()
    pil_image_with_boxes.save(buffered, format="PNG")
    predicted_image_base64 = base64.b64encode(buffered.getvalue()).decode()
    append_patient_info_to_html(name, age, medical_record, sex, label, predicted_image_base64)

    raw_prediction = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"
    return pil_image_with_boxes, raw_prediction
def append_patient_info_to_html(name, age, medical_record, sex, result, predicted_image_base64, db_file=None):
    """Insert one patient row into the HTML prediction database.

    The row goes just before the document's trailing
    ``</tbody></table></body></html>``. The file therefore keeps exactly one
    closer at the end, rather than gaining a new closer after every row. If
    the file does not end with that closer (e.g. a freshly created header-only
    file), the row and the closer are appended at the end.

    Args:
        name, age, medical_record, sex: Patient details as entered in the UI.
            They are HTML-escaped because they are untrusted user input.
        result: Predicted label for the patient.
        predicted_image_base64: Base64-encoded PNG of the annotated image.
        db_file: Optional path of the HTML file to update. Defaults to the
            module-level ``html_db_file``.

    Returns:
        The path of the updated HTML file as a string.
    """
    closing = "</tbody></table></body></html>"
    db_path = html_db_file if db_file is None else Path(db_file)

    cells = "".join(
        f"<td>{html.escape(str(value))}</td>\n"
        for value in (name, age, medical_record, sex, result)
    )
    html_entry = (
        "\n<tr>\n"
        f"{cells}"
        f'<td><img src="data:image/png;base64,{html.escape(str(predicted_image_base64))}" '
        'alt="Predicted Image" width="150"></td>\n'
        "</tr>\n"
    )

    # NOTE(review): this is a read-modify-write. It assumes calls are not
    # concurrent (e.g. Gradio's per-event queue) — confirm if concurrency is raised.
    content = db_path.read_text(encoding="utf-8") if db_path.exists() else ""
    stripped = content.rstrip()
    if stripped.endswith(closing):
        content = stripped[: -len(closing)]
    db_path.write_text(content + html_entry + closing + "\n", encoding="utf-8")
    return str(db_path)
def _zip_folder(folder, zip_path):
    """Zip the regular files directly inside *folder* into *zip_path*.

    Files are stored under their bare names, in sorted order (deterministic
    archives), with DEFLATE compression. Subdirectories are skipped.
    Returns *zip_path* as a string.
    """
    folder = Path(folder)
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for file in sorted(p for p in folder.iterdir() if p.is_file()):
            zf.write(file, arcname=file.name)
    return str(zip_path)


def download_uploaded_folder(folder=None, zip_path='uploaded_images.zip'):
    """Zip the uploaded (original) images and return the archive path.

    *folder* defaults to the module-level ``uploaded_folder``.
    """
    return _zip_folder(uploaded_folder if folder is None else folder, zip_path)


def download_predicted_folder(folder=None, zip_path='predicted_images.zip'):
    """Zip the annotated prediction images and return the archive path.

    *folder* defaults to the module-level ``predicted_folder``.
    """
    return _zip_folder(predicted_folder if folder is None else folder, zip_path)
# Build the Gradio interface.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")
    with gr.Column():
        name = gr.Textbox(label="Name")
        # precision=0 makes Gradio return an int, so ages are not saved as "42.0".
        age = gr.Number(label="Age", precision=0)
        # Text rather than a number so record IDs keep leading zeros / letters
        # and never gain a ".0" suffix in file names or the HTML database.
        medical_record = gr.Textbox(label="Medical Record")
        sex = gr.Radio(["Male", "Female"], label="Sex")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")
    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")
    with gr.Row():
        raw_result = gr.Textbox(label="Prediction Result")
    with gr.Row():
        download_html_btn = gr.Button("Download Patient Information (HTML)")
        download_uploaded_btn = gr.Button("Download Uploaded Images")
        download_predicted_btn = gr.Button("Download Predicted Images")

    # File outputs that receive the paths returned by the download handlers.
    patient_info_file = gr.File(label="Patient Information HTML File")
    uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
    predicted_folder_file = gr.File(label="Predicted Images Zip File")

    # Connect functions with components.
    submit_btn.click(fn=predict_image, inputs=[name, age, medical_record, sex, input_image], outputs=[output_image, raw_result])
    # Rows are already recorded by predict_image, so downloading only hands out
    # the existing file. The previous wiring passed 5 inputs to the 6-argument
    # append_patient_info_to_html, which raised a TypeError.
    download_html_btn.click(fn=lambda: str(html_db_file), outputs=patient_info_file)
    download_uploaded_btn.click(fn=download_uploaded_folder, outputs=uploaded_folder_file)
    download_predicted_btn.click(fn=download_predicted_folder, outputs=predicted_folder_file)

# Launch the Gradio app when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    demo.launch()