# Spaces: Sleeping
import html
import os
import shutil
import tempfile
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
# Detection model, loaded once at import time from the local weights file.
model = YOLO("best.pt")

# Output folders: one for the images as uploaded, one for the annotated
# results. Each is created right after it is named; existing folders are
# left untouched.
uploaded_folder = Path("Uploaded_Picture")
uploaded_folder.mkdir(parents=True, exist_ok=True)
predicted_folder = Path("Predicted_Picture")
predicted_folder.mkdir(parents=True, exist_ok=True)

# HTML fragments, one per saved patient entry, accumulated across calls.
patient_data = []
# Class index -> (label, RGB colour). The image is kept in RGB order from the
# PIL input through to the PIL output, so yellow is (255, 255, 0).
_CLASS_STYLES = {
    0: ("Immature", (255, 255, 0)),
    1: ("Mature", (255, 0, 0)),
}
_DEFAULT_CLASS_STYLE = ("Normal", (0, 255, 0))


def _filename_part(value):
    """Return ``value`` as text that is safe to embed in a single file name."""
    # Patient fields come straight from the form; anything other than letters,
    # digits, space, '-', '_' and '.' (notably path separators) becomes '_'.
    return "".join(
        ch if ch.isalnum() or ch in " -_." else "_" for ch in str(value)
    )


def predict_image(input_image, name, age, medical_record, sex):
    """Run the detector on an uploaded image and return the annotated result.

    Only the highest-confidence detection is drawn: a coloured bounding box, a
    white circle at the box centre and a caption with label and confidence.
    A patient summary is stamped on the image, and both the uploaded and the
    annotated image are saved under a file name built from the patient fields.

    Args:
        input_image: PIL image from the upload component, or ``None``.
        name: Patient name.
        age: Patient age.
        medical_record: Medical record number.
        sex: Patient sex.

    Returns:
        A ``(annotated PIL image, summary text)`` tuple, or
        ``(None, "Please Input The Image")`` when no image was supplied.
    """
    if input_image is None:
        return None, "Please Input The Image"

    # Convert the PIL image to a 3-channel RGB array.
    image_np = np.array(input_image)
    if image_np.ndim == 2:  # grayscale to RGB
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:  # RGBA to RGB
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    results = model(image_np)

    image_with_boxes = image_np.copy()
    raw_predictions = []
    # Bound up front so the summary stamp below also works when the model
    # finds nothing.
    label = "No detection"

    boxes = results[0].boxes
    if boxes is not None and len(boxes) > 0:
        best = max(boxes, key=lambda box: box.conf.item())
        label, color = _CLASS_STYLES.get(
            int(best.cls.item()), _DEFAULT_CLASS_STYLE
        )
        confidence = best.conf.item()
        xmin, ymin, xmax, ymax = map(int, best.xyxy[0])

        cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)

        # White circle at the box centre, sized relative to the box.
        box_width = xmax - xmin
        box_height = ymax - ymin
        center = (xmin + box_width // 2, ymin + box_height // 2)
        radius = int((box_width + box_height) / 2 / 12)
        cv2.circle(image_with_boxes, center, radius, (255, 255, 255), 2)

        # Caption: white text on a filled black background above the box.
        caption = f"{label} {confidence:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.0
        thickness = 2
        (text_width, text_height), baseline = cv2.getTextSize(
            caption, font, font_scale, thickness
        )
        # Keep the caption inside the frame when the box touches the top edge.
        caption_bottom = max(ymin, text_height + baseline)
        cv2.rectangle(
            image_with_boxes,
            (xmin, caption_bottom - text_height - baseline),
            (xmin + text_width, caption_bottom),
            (0, 0, 0),
            cv2.FILLED,
        )
        cv2.putText(
            image_with_boxes,
            caption,
            (xmin, caption_bottom - baseline),
            font,
            font_scale,
            (255, 255, 255),
            thickness,
        )

        raw_predictions.append(
            f"Label: {label}, Confidence: {confidence:.2f}, "
            f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )

    raw_predictions_str = "\n".join(raw_predictions) or label

    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(
        pil_image_with_boxes, name, age, medical_record, sex, label
    )

    # Each field is sanitised so the result is always a plain file name
    # inside the two output folders.
    image_name = "-".join(
        _filename_part(part) for part in (name, age, sex, medical_record)
    ) + ".png"
    input_image.save(uploaded_folder / image_name)
    pil_image_with_boxes.save(predicted_folder / image_name)

    return pil_image_with_boxes, raw_predictions_str
def add_text_and_watermark(image, name, age, medical_record, sex, label):
    """Stamp a one-line patient summary onto ``image`` and return it.

    The summary is drawn in white on a filled black rectangle near the
    top-left corner. ``image`` is drawn on in place; the same object is
    returned. ``font.ttf`` is used at size 24 when it can be opened,
    otherwise PIL's default font is used and a notice is printed.
    """
    canvas = ImageDraw.Draw(image)

    try:
        caption_font = ImageFont.truetype("font.ttf", size=24)
    except IOError:
        caption_font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    caption = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"

    # Measure the rendered caption so the backdrop fits it.
    left, top, right, bottom = canvas.textbbox((0, 0), caption, font=caption_font)
    origin_x, origin_y, margin = 20, 40, 10
    backdrop = [
        origin_x - margin,
        origin_y - margin,
        origin_x + (right - left) + margin,
        origin_y + (bottom - top) + margin,
    ]

    canvas.rectangle(backdrop, fill="black")
    canvas.text(
        (origin_x, origin_y),
        caption,
        fill=(255, 255, 255, 255),
        font=caption_font,
    )
    return image
def save_patient_info_to_html(name, age, medical_record, sex, result):
    """Append one patient entry to the report and write it out as HTML.

    The entry is added to the module-level ``patient_data`` list, so entries
    accumulate across calls and every written file contains all entries
    recorded so far.

    Args:
        name: Patient name.
        age: Patient age.
        medical_record: Medical record number.
        sex: Patient sex.
        result: Prediction text to record for this patient.

    Returns:
        Path of the written ``patient_info.html`` file in the system
        temporary directory.
    """
    # All values are supplied through the web form, so escape them before
    # embedding them in markup.
    fields = (
        ("Name", name),
        ("Age", age),
        ("Medical Record", medical_record),
        ("Sex", sex),
        ("Result", result),
    )
    entry = ", ".join(
        f"<strong>{title}:</strong> {html.escape(str(value))}"
        for title, value in fields
    )
    patient_data.append(f"<p>{entry}</p>")

    html_content = (
        "<html>\n"
        '<head><meta charset="utf-8"></head>\n'
        "<body>\n"
        "<h1>Patient Information</h1>\n"
        f"{''.join(patient_data)}\n"
        "</body>\n"
        "</html>\n"
    )

    html_file_path = os.path.join(tempfile.gettempdir(), 'patient_info.html')
    # Explicit UTF-8 so non-ASCII names are written the same on every
    # platform and match the declared charset.
    with open(html_file_path, 'w', encoding="utf-8") as f:
        f.write(html_content)
    return html_file_path
def download_folder(folder_path):
    """Zip the contents of ``folder_path`` and return the archive's path.

    The archive is written to the system temporary directory and named after
    the folder, e.g. ``Uploaded_Picture`` -> ``<tmp>/Uploaded_Picture.zip``.

    Args:
        folder_path: Folder to archive (string or ``Path``).

    Returns:
        Path of the created ``.zip`` file, as reported by
        ``shutil.make_archive``.
    """
    # make_archive appends the ".zip" suffix itself, so pass the bare base
    # name and return the path it reports.
    archive_base = os.path.join(tempfile.gettempdir(), Path(folder_path).name)
    return shutil.make_archive(archive_base, 'zip', folder_path)
# Gradio interface: patient form and image upload, prediction output, and
# download buttons for the HTML report and the two image folders.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")

    with gr.Column():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age")
        medical_record = gr.Number(label="Medical Record")
        sex = gr.Radio(["Male", "Female"], label="Sex")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")

    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")

    with gr.Row():
        raw_result = gr.Textbox(label="Prediction Result")

    with gr.Row():
        download_html_btn = gr.Button("Download Patient Information (HTML)")
        download_uploaded_btn = gr.Button("Download Uploaded Images")
        download_predicted_btn = gr.Button("Download Predicted Images")

    # File components that receive the generated report and zip archives.
    patient_info_file = gr.File(label="Patient Information HTML File")
    uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
    predicted_folder_file = gr.File(label="Predicted Images Zip File")

    # Inputs are passed positionally, so their order must match
    # predict_image(input_image, name, age, medical_record, sex).
    submit_btn.click(
        fn=predict_image,
        inputs=[input_image, name, age, medical_record, sex],
        outputs=[output_image, raw_result],
    )
    download_html_btn.click(
        fn=save_patient_info_to_html,
        inputs=[name, age, medical_record, sex, raw_result],
        outputs=patient_info_file,
    )
    # The folder to zip is fixed rather than read from a component, so it is
    # bound in the callback instead of being listed under `inputs`.
    download_uploaded_btn.click(
        fn=lambda: download_folder(str(uploaded_folder)),
        inputs=None,
        outputs=uploaded_folder_file,
    )
    download_predicted_btn.click(
        fn=lambda: download_folder(str(predicted_folder)),
        inputs=None,
        outputs=predicted_folder_file,
    )

# Launch Gradio app
demo.launch()