"""Gradio app that runs a YOLOv8 model on an uploaded eye image.

Each prediction is annotated with bounding boxes, a patient banner and a
logo watermark, then stored together with the input image in a local
SQLite database (``results.db``).
"""
import sqlite3
import tempfile
from contextlib import closing

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

DB_PATH = 'results.db'
LOGO_PATH = 'image-logo.png'
FONT_PATH = "font.ttf"

# Class index -> (label, box colour). Colours are RGB because every array
# drawn on in this module is RGB (it originates from a PIL image).
_CLASS_STYLES = {
    0: ("Immature", (255, 255, 0)),  # yellow
    1: ("Mature", (255, 0, 0)),      # red
}
# Any other class index is reported as "Normal" (green).
_DEFAULT_STYLE = ("Normal", (0, 255, 0))

# Load YOLOv8 model
model = YOLO("best.pt")


def _to_rgb_array(image):
    """Return *image* (PIL image or numpy array) as an RGB numpy array."""
    if isinstance(image, Image.Image):
        # Let Pillow handle every mode (L, LA, P, RGBA, ...) uniformly.
        return np.array(image.convert("RGB"))
    array = np.asarray(image)
    if array.ndim == 2:  # grayscale to RGB
        return cv2.cvtColor(array, cv2.COLOR_GRAY2RGB)
    if array.shape[2] == 4:  # RGBA to RGB
        return cv2.cvtColor(array, cv2.COLOR_RGBA2RGB)
    return array


def _encode_png(image):
    """Encode *image* (PIL image or numpy array) as PNG bytes."""
    # cv2.imencode expects BGR channel order.
    bgr = cv2.cvtColor(_to_rgb_array(image), cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.png', bgr)
    if not ok:
        raise ValueError("Could not encode image as PNG.")
    return buffer.tobytes()


def _describe_blob(value):
    """Replace binary cell content with a short size summary for display."""
    if isinstance(value, (bytes, bytearray, memoryview)):
        return f"[{len(value)} bytes]"
    return value


def predict_image(input_image, name, patient_id):
    """Run the model on *input_image* and return the annotated result.

    Args:
        input_image: PIL image (or numpy array) to analyse, or ``None``.
        name: Patient name written onto the output image.
        patient_id: Patient identifier written onto the output image.

    Returns:
        ``(annotated_pil_image, predictions_text)``. When *input_image* is
        ``None`` the image is ``None`` and the text is a prompt message.
        ``predictions_text`` has one line per detection and is empty when
        nothing was detected.
    """
    if input_image is None:
        return None, "Please Input The Image"

    image_np = _to_rgb_array(input_image)

    # Perform prediction
    results = model(image_np)

    image_with_boxes = image_np.copy()
    raw_predictions = []
    label = "Unknown"  # Default label if no detection
    best_confidence = -1.0

    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2

    boxes = results[0].boxes
    for box in (boxes if boxes is not None else ()):
        class_index = int(box.cls.item())
        confidence = box.conf.item()
        box_label, color = _CLASS_STYLES.get(class_index, _DEFAULT_STYLE)

        # The banner label reflects the most confident detection rather
        # than whichever box happens to be iterated last.
        if confidence > best_confidence:
            best_confidence = confidence
            label = box_label

        xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
        cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)

        caption = f'{box_label} {confidence:.2f}'
        (text_width, text_height), baseline = cv2.getTextSize(
            caption, font_face, font_scale, thickness
        )
        # Keep the caption inside the image when the box touches the top edge.
        caption_bottom = max(ymin, text_height + baseline)
        cv2.rectangle(
            image_with_boxes,
            (xmin, caption_bottom - text_height - baseline),
            (xmin + text_width, caption_bottom),
            (0, 0, 0),
            cv2.FILLED,
        )
        # The text origin is its baseline, so the glyphs sit exactly inside
        # the black background drawn above.
        cv2.putText(
            image_with_boxes,
            caption,
            (xmin, caption_bottom - baseline),
            font_face,
            font_scale,
            (255, 255, 255),
            thickness,
        )

        raw_predictions.append(
            f"Label: {box_label}, Confidence: {confidence:.2f}, "
            f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )

    raw_predictions_str = "\n".join(raw_predictions)

    # Convert to PIL image for further processing
    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(
        pil_image_with_boxes, name, patient_id, label
    )

    return pil_image_with_boxes, raw_predictions_str


def add_watermark(image):
    """Paste the logo into the bottom-right corner of *image*.

    Best effort: if the logo cannot be loaded or composited, the error is
    printed and the image passed in is returned unchanged.

    Returns:
        A new RGB image with the logo, or *image* itself on failure.
    """
    try:
        with Image.open(LOGO_PATH) as logo_file:
            logo = logo_file.convert("RGBA")
        # Work on a separate RGBA copy so the fallback below can still
        # return the caller's image in its original mode.
        base = image.convert("RGBA")

        # Resize logo to a fixed width, keeping its aspect ratio.
        basewidth = 100
        wpercent = basewidth / float(logo.size[0])
        hsize = max(1, int(float(wpercent) * logo.size[1]))
        logo = logo.resize((basewidth, hsize), Image.LANCZOS)

        # Position logo with a 10 px margin from the bottom-right corner.
        position = (base.width - logo.width - 10, base.height - logo.height - 10)

        # Composite image; the logo's own alpha channel is used as the mask.
        transparent = Image.new('RGBA', (base.width, base.height), (0, 0, 0, 0))
        transparent.paste(base, (0, 0))
        transparent.paste(logo, position, mask=logo)

        return transparent.convert("RGB")
    except Exception as e:
        print(f"Error adding watermark: {e}")
        return image


def add_text_and_watermark(image, name, patient_id, label):
    """Draw a patient banner on *image* and add the logo watermark.

    The banner is drawn onto *image* in place; the returned image is the
    result of :func:`add_watermark`.
    """
    draw = ImageDraw.Draw(image)

    font_size = 48
    try:
        font = ImageFont.truetype(FONT_PATH, size=font_size)
    except IOError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    text = f"Name: {name}, ID: {patient_id}, Result: {label}"

    text_x = 20
    text_y = 40
    padding = 10

    # Measure the text where it will actually be drawn so the background
    # accounts for the font's internal offsets.
    left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
    draw.rectangle(
        [left - padding, top - padding, right + padding, bottom + padding],
        fill="black",
    )
    draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)

    return add_watermark(image)


def init_db():
    """Create the ``results`` table in the database if it does not exist."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        conn.execute('''CREATE TABLE IF NOT EXISTS results
                     (id INTEGER PRIMARY KEY, name TEXT, patient_id TEXT,
                      input_image BLOB, predicted_image BLOB, result TEXT)''')
        conn.commit()


def submit_result(name, patient_id, input_image, predicted_image, result):
    """Store one prediction, with both images PNG-encoded, in the database.

    Returns:
        A status message for display in the UI.
    """
    # Encode before opening the connection so an encoding failure never
    # leaves a connection behind.
    input_image_bytes = _encode_png(input_image)
    predicted_image_bytes = _encode_png(predicted_image)

    with closing(sqlite3.connect(DB_PATH)) as conn:
        conn.execute(
            "INSERT INTO results (name, patient_id, input_image, predicted_image, result) "
            "VALUES (?, ?, ?, ?, ?)",
            (name, patient_id, input_image_bytes, predicted_image_bytes, result),
        )
        conn.commit()

    return "Result submitted to database."


def view_database():
    """Return every stored result as a DataFrame (images as raw PNG bytes)."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(
            "SELECT id, name, patient_id, input_image, predicted_image, result "
            "FROM results"
        ).fetchall()

    return pd.DataFrame(
        rows,
        columns=["ID", "Name", "Patient ID", "Input Image", "Predicted Image", "Result"],
    )


def download_file(choice):
    """Return the path of the file to download for *choice*.

    ``"Database (.db)"`` yields the database path. Any other choice writes
    the most recently stored predicted image to a temporary PNG file.

    Raises:
        FileNotFoundError: If no image has been stored yet.
    """
    if choice == "Database (.db)":
        return DB_PATH

    with closing(sqlite3.connect(DB_PATH)) as conn:
        row = conn.execute(
            "SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1"
        ).fetchone()

    if row is None:
        raise FileNotFoundError("No images found in the database.")

    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
        temp_file.write(row[0])
    return temp_file.name


# Initialize the database
init_db()


def interface(name, patient_id, input_image):
    """Predict-button callback: predict, store, and report.

    Always returns three values, matching the three output components
    (predicted image, raw predictions, submit status).
    """
    if input_image is None:
        return None, "Please upload an image.", ""

    output_image, raw_result = predict_image(input_image, name, patient_id)
    submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)

    return output_image, raw_result, submit_status


def view_db_interface():
    """View-button callback: return the stored results for display.

    Image columns are shown as size summaries instead of raw bytes, which
    are not meaningful in a table.
    """
    df = view_database()
    for column in ("Input Image", "Predicted Image"):
        df[column] = df[column].map(_describe_blob)
    return df


def download_interface(choice):
    """Download-button callback.

    Returns:
        ``(file_path, message)`` for the File and Textbox outputs. On
        success both are the path; on failure the file is ``None`` and the
        message describes the error.
    """
    try:
        file_path = download_file(choice)
    except (sqlite3.Error, OSError) as e:
        return None, f"Error: {str(e)}"
    return file_path, file_path


# Build Gradio Interface
app = gr.Blocks()

with app:
    gr.Markdown("# Eye Condition Detection System")

    with gr.Row():
        with gr.Column():
            name = gr.Textbox(label="Name")
            patient_id = gr.Textbox(label="Patient ID")
            input_image = gr.Image(label="Input Image", tool="editor", type="pil")

        with gr.Column():
            output_image = gr.Image(label="Predicted Image")
            raw_result = gr.Textbox(label="Raw Predictions", lines=5)
            submit_status = gr.Textbox(label="Submit Status")

    predict_button = gr.Button("Predict")
    predict_button.click(
        fn=interface,
        inputs=[name, patient_id, input_image],
        outputs=[output_image, raw_result, submit_status],
    )

    with gr.Row():
        with gr.Column():
            view_button = gr.Button("View Database")
            download_choice = gr.Dropdown(
                label="Download Option",
                choices=["Database (.db)", "Predicted Image (.png)"],
            )
            download_button = gr.Button("Download")
            database_table = gr.Dataframe()
            download_output = gr.File()
            download_status = gr.Textbox()

    view_button.click(fn=view_db_interface, inputs=[], outputs=[database_table])
    download_button.click(
        fn=download_interface,
        inputs=[download_choice],
        outputs=[download_output, download_status],
    )

# Launch the Gradio app
app.launch()