# Spaces:
# Sleeping
# Sleeping
import base64
import html
import os
import sqlite3
import tempfile
from contextlib import closing
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
# Load the trained YOLOv8 weights once at import time; every prediction reuses this model.
model = YOLO(model="best.pt")
# Drawing style per YOLO class index: (label, colour). The frame being drawn on is RGB
# (it is built from a PIL image), so colours are given in RGB channel order.
_CLASS_STYLES = {
    0: ("Immature", (255, 255, 0)),  # yellow
    1: ("Mature", (255, 0, 0)),  # red
}
# Any other class index is reported as "Normal" in green.
_DEFAULT_CLASS_STYLE = ("Normal", (0, 255, 0))
# Result shown in the patient banner when the model detects nothing.
_NO_DETECTION_LABEL = "No Detection"


def _to_rgb(image_np):
    """Return a 3-channel RGB version of a grayscale, RGB or RGBA array."""
    if image_np.ndim == 2:
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


def _draw_labelled_box(image, box, caption, color, font_scale=1.0, thickness=2):
    """Draw *box* on *image* in place, with *caption* in white on a black tag above it."""
    xmin, ymin, xmax, ymax = box
    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)
    (text_width, text_height), baseline = cv2.getTextSize(
        caption, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
    )
    # Bottom edge of the tag: normally the top of the box, but pushed down just enough to
    # keep the whole tag inside the image when the box touches the top edge.
    tag_bottom = max(ymin, text_height + baseline)
    cv2.rectangle(
        image,
        (xmin, tag_bottom - text_height - baseline),
        (xmin + text_width, tag_bottom),
        (0, 0, 0),
        cv2.FILLED,
    )
    # putText's origin is the text baseline; placing it `baseline` px above the tag bottom
    # makes the glyphs (descenders included) fill exactly the black tag.
    cv2.putText(
        image,
        caption,
        (xmin, tag_bottom - baseline),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        (255, 255, 255),
        thickness,
    )


def predict_image(input_image, name, age, medical_record, sex):
    """Detect cataract in *input_image*; return an annotated copy and a text summary.

    Only the single highest-confidence detection is drawn and reported.

    Args:
        input_image: PIL image from the Gradio upload, or None.
        name, age, medical_record, sex: patient details stamped onto the output image.

    Returns:
        ``(annotated PIL image, summary string)``. If *input_image* is None the image is
        None and the string asks for an image. If nothing is detected the image still
        carries the patient banner (result "No Detection") and the summary says so.
    """
    if input_image is None:
        return None, "Please Input The Image"

    # Convert the PIL image to a 3-channel RGB numpy array.
    image_np = _to_rgb(np.array(input_image))

    # Ultralytics treats numpy inputs as BGR (OpenCV order). This array is RGB, so convert
    # before inference; otherwise the model sees swapped red/blue channels.
    results = model(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

    image_with_boxes = image_np.copy()
    boxes = results[0].boxes
    if boxes is not None and len(boxes) > 0:
        best = max(boxes, key=lambda box: box.conf.item())
        label, color = _CLASS_STYLES.get(int(best.cls.item()), _DEFAULT_CLASS_STYLE)
        confidence = best.conf.item()
        xmin, ymin, xmax, ymax = map(int, best.xyxy[0])
        _draw_labelled_box(
            image_with_boxes, (xmin, ymin, xmax, ymax), f"{label} {confidence:.2f}", color
        )
        raw_predictions_str = (
            f"Label: {label}, Confidence: {confidence:.2f}, "
            f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )
    else:
        # No detection means no class label; use a placeholder so the banner still renders.
        label = _NO_DETECTION_LABEL
        raw_predictions_str = "No detection found."

    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(
        pil_image_with_boxes, name, age, medical_record, sex, label
    )
    return pil_image_with_boxes, raw_predictions_str
# Function to add watermark
def add_watermark(image, logo_path="image-logo.png", logo_width=100, margin=10):
    """Return an RGB copy of *image* with a logo pasted in the bottom-right corner.

    The logo is scaled to *logo_width* pixels wide, keeping its aspect ratio, and placed
    *margin* pixels from the right and bottom edges. Its alpha channel is used as the
    paste mask.

    This is best-effort: if the logo cannot be loaded or composited, the error is printed
    and the original *image* is returned unchanged.
    """
    try:
        # Load the logo fully, then release the file handle.
        with Image.open(logo_path) as source:
            logo = source.convert("RGBA")
        scale = logo_width / float(logo.width)
        # Guard against a 0-pixel height for very wide, short logos.
        logo = logo.resize((logo_width, max(1, int(scale * logo.height))), Image.LANCZOS)

        canvas = image.convert("RGBA")
        position = (canvas.width - logo.width - margin, canvas.height - logo.height - margin)
        canvas.paste(logo, position, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:
        print(f"Error adding watermark: {e}")
        # Return the caller's image itself, never a partially converted intermediate.
        return image
# Function to add text and watermark
def add_text_and_watermark(
    image, name, age, medical_record, sex, label, font_path="font.ttf", font_size=48
):
    """Return a copy of *image* with a patient-details banner and the logo watermark.

    The banner reads "Name: ..., Age: ..., Medical Record: ..., Sex: ..., Result: ..." in
    white on a black box near the top-left corner. If that single line would run past the
    right edge of the image, each field is put on its own line instead.

    Args:
        image: PIL image to annotate; it is not modified.
        name, age, medical_record, sex, label: values written into the banner.
        font_path: TrueType font file; Pillow's default font is used if it cannot be opened.
        font_size: size passed to ``ImageFont.truetype``.

    Returns:
        The annotated image, after ``add_watermark`` has been applied.
    """
    # convert() always returns a new image, so drawing below never touches the caller's image.
    annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)

    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    fields = [
        f"Name: {name}",
        f"Age: {age}",
        f"Medical Record: {medical_record}",
        f"Sex: {sex}",
        f"Result: {label}",
    ]
    text_x, text_y, padding = 20, 40, 10

    text = ", ".join(fields)
    # Wrap to one field per line when the single-line banner does not fit the image width.
    if draw.textbbox((text_x, text_y), text, font=font)[2] + padding > annotated.width:
        text = "\n".join(fields)

    # Measure at the real draw position. textbbox includes the font's top bearing, so a box
    # measured at (0, 0) and shifted by hand would not line up with the drawn glyphs.
    left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
    draw.rectangle(
        [left - padding, top - padding, right + padding, bottom + padding], fill="black"
    )
    draw.text((text_x, text_y), text, fill=(255, 255, 255), font=font)

    return add_watermark(annotated)
# Function to initialize the database
def init_db(db_path="results.db"):
    """Create the ``results`` table in the SQLite database at *db_path* if it is missing.

    Columns: id, name, age, medical_record, sex, input_image (BLOB), predicted_image
    (BLOB), result. Safe to call repeatedly.
    """
    # closing() guarantees the connection is released. The inner `conn` context commits
    # on success and rolls back on error.
    with closing(sqlite3.connect(db_path)) as conn, conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS results "
            "(id INTEGER PRIMARY KEY, name TEXT, age INTEGER, medical_record INTEGER, "
            "sex TEXT, input_image BLOB, predicted_image BLOB, result TEXT)"
        )
# Function to submit result to the database
def _image_to_png_bytes(image):
    """Encode a PIL image (or image-like array) as PNG bytes; ``None`` is passed through."""
    if image is None:
        return None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def submit_result(
    name, age, medical_record, sex, input_image, predicted_image, result, db_path="results.db"
):
    """Insert one prediction record into the ``results`` table.

    Both images are stored as PNG bytes; a missing image is stored as NULL.

    Returns:
        A status message for the UI.
    """
    # Encode before opening the database so an encoding error cannot leave a connection open.
    input_image_bytes = _image_to_png_bytes(input_image)
    predicted_image_bytes = _image_to_png_bytes(predicted_image)

    # closing() releases the connection; the inner `conn` context commits or rolls back.
    with closing(sqlite3.connect(db_path)) as conn, conn:
        conn.execute(
            "INSERT INTO results (name, age, medical_record, sex, input_image, predicted_image, result) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, age, medical_record, sex, input_image_bytes, predicted_image_bytes, result),
        )
    return "Result submitted to database."
# Function to load and view database in HTML format
def view_database(db_path="results.db"):
    """Render every stored result as an HTML table with inline image thumbnails.

    Text values are HTML-escaped, because name, sex and result come from user input.
    Image BLOBs are embedded directly as base64 PNG data URIs; a NULL or empty image
    yields an empty cell.
    """
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT name, age, medical_record, sex, input_image, predicted_image, result FROM results"
        ).fetchall()

    def text_cell(value):
        return f"<td>{html.escape(str(value))}</td>"

    def image_cell(png_bytes):
        if not png_bytes:
            return "<td></td>"
        encoded = base64.b64encode(png_bytes).decode("utf-8")
        return f"<td><img src='data:image/png;base64,{encoded}' width='100'></td>"

    parts = [
        "<table border='1'><tr><th>Name</th><th>Age</th><th>Medical Record</th><th>Sex</th>"
        "<th>Input Image</th><th>Predicted Image</th><th>Result</th></tr>"
    ]
    for name, age, medical_record, sex, input_png, predicted_png, result in rows:
        parts.append(
            "<tr>"
            + text_cell(name)
            + text_cell(age)
            + text_cell(medical_record)
            + text_cell(sex)
            + image_cell(input_png)
            + image_cell(predicted_png)
            + text_cell(result)
            + "</tr>"
        )
    parts.append("</table>")
    return "".join(parts)
# Function to download database or HTML file
def download_file(choice, db_path="results.db", output_dir=None):
    """Return the path of a downloadable file: the raw database or an HTML export of it.

    Args:
        choice: "Database (.db)" for the SQLite file itself, or "Database (.html)" for the
            ``results`` table rendered as HTML, with image BLOBs embedded as thumbnails.
        db_path: database to export. Defaults to the same relative "results.db" that the
            rest of the app reads and writes.
        output_dir: directory for the HTML export; defaults to the system temp directory.

    Returns:
        Path of the requested file.

    Raises:
        ValueError: *choice* is not one of the two options, or the database has no
            ``results`` table.
        FileNotFoundError: the database file does not exist.
    """
    if choice not in ("Database (.db)", "Database (.html)"):
        raise ValueError("Invalid choice. Please select a valid format.")

    # init_db / submit_result / view_database all open "results.db" relative to the
    # working directory, so resolve that same file here.
    db_file_path = os.path.abspath(db_path)
    if not os.path.isfile(db_file_path):
        raise FileNotFoundError(f"Database file not found at path: {db_file_path}")
    if choice == "Database (.db)":
        return db_file_path

    with closing(sqlite3.connect(db_file_path)) as conn:
        try:
            df = pd.read_sql_query("SELECT * FROM results", conn)
        except pd.errors.DatabaseError as e:
            raise ValueError("Table 'results' does not exist in the database.") from e

    def render_cell(value):
        # BLOB columns hold image bytes: show thumbnails instead of raw byte reprs.
        if isinstance(value, (bytes, bytearray, memoryview)):
            encoded = base64.b64encode(bytes(value)).decode("utf-8")
            return f"<img src='data:image/png;base64,{encoded}' width='100'>"
        # Text is user-supplied; escape it because the table is written with escape=False.
        if isinstance(value, str):
            return html.escape(value)
        return value

    for column in df.columns:
        df[column] = df[column].map(render_cell)

    html_file_path = os.path.join(output_dir or tempfile.gettempdir(), "results.html")
    # Disable the column-width limit so long data-URI cells are never truncated.
    with pd.option_context("display.max_colwidth", None):
        df.to_html(html_file_path, index=False, escape=False)
    return html_file_path
# Initialize the database: create the `results` table (if it does not already exist)
# at import time, before any of the Gradio handlers below read or write it.
init_db()
# Gradio Interface
def interface(name, age, medical_record, sex, input_image):
    """Gradio handler for Submit: run detection, store the record, and report status.

    Returns a 3-tuple (annotated image, prediction text, submission status). When no
    image was uploaded, nothing is stored and the tuple is (None, prompt, None).
    """
    if input_image is not None:
        annotated, prediction_text = predict_image(input_image, name, age, medical_record, sex)
        status = submit_result(
            name, age, medical_record, sex, input_image, annotated, prediction_text
        )
        return annotated, prediction_text, status
    return None, "Please upload an image.", None
# View Database Function (Updated)
def view_db_interface():
    """Gradio handler for "View Database": return all stored results as an HTML table."""
    return view_database()
# Download Function
def download_interface(choice):
    """Gradio handler for Download: return the export file path for the gr.File output.

    Raises:
        gr.Error: wraps FileNotFoundError / ValueError from download_file so that the
            message is shown to the user as a UI alert.
    """
    try:
        return download_file(choice)
    except (FileNotFoundError, ValueError) as e:
        # A gr.File output interprets any returned string as a file path, so returning the
        # message would surface as a broken download; gr.Error displays it instead.
        raise gr.Error(str(e)) from e
# Gradio Blocks: patient form and image upload, prediction output, database viewer, export.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")

    with gr.Column():
        name = gr.Textbox(label="Name")
        # precision=0 makes Gradio return ints, matching the INTEGER database columns and
        # printing "45" rather than "45.0" on the annotated image.
        age = gr.Number(label="Age", precision=0)
        medical_record = gr.Number(label="Medical Record", precision=0)
        sex = gr.Radio(["Male", "Female"], label="Sex")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")

    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")
        with gr.Row():
            raw_result = gr.Textbox(label="Raw Result", lines=5)
            submit_status = gr.Textbox(label="Submission Status")
    submit_btn.click(
        fn=interface,
        inputs=[name, age, medical_record, sex, input_image],
        outputs=[output_image, raw_result, submit_status],
    )

    with gr.Column():
        view_db_btn = gr.Button("View Database")
        db_output = gr.HTML(label="Database Records")
    view_db_btn.click(fn=view_db_interface, inputs=[], outputs=[db_output])

    with gr.Column():
        # Pre-select a format so clicking Download without choosing one does not fail.
        download_choice = gr.Radio(
            ["Database (.db)", "Database (.html)"],
            value="Database (.db)",
            label="Choose the file to download:",
        )
        download_btn = gr.Button("Download")
        download_output = gr.File(label="Download File")
    # Send the result to the File component declared above. Creating gr.File() inline here
    # would render a second component and leave download_output permanently empty.
    download_btn.click(fn=download_interface, inputs=[download_choice], outputs=[download_output])

# Launch the Gradio app
demo.launch()