|
import gradio as gr |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import sqlite3 |
|
import base64 |
|
from io import BytesIO |
|
import tempfile |
|
import pandas as pd |
|
|
|
|
|
# YOLO detector shared by every prediction request; the weights file is read
# once, when this module is first executed.
_MODEL_WEIGHTS = "best.pt"
model = YOLO(_MODEL_WEIGHTS)
|
|
|
# Class index -> (display label, box colour as an RGB triple, since the image
# array is RGB). Any index not listed falls back to "Normal", matching the
# original if/elif/else chain.
_CLASS_STYLES = {
    0: ("Immature", (0, 255, 255)),
    1: ("Mature", (255, 0, 0)),
}
_FALLBACK_STYLE = ("Normal", (0, 255, 0))
_NO_DETECTION_LABEL = "No detection"


def _to_rgb_array(image):
    """Return ``image`` as a 3-channel NumPy array.

    Grayscale (2-D) input is expanded to RGB and 4-channel (RGBA) input has its
    alpha dropped; 3-channel input is returned unchanged (assumed to already be
    RGB, as PIL produces).
    """
    array = np.array(image)
    if array.ndim == 2:
        return cv2.cvtColor(array, cv2.COLOR_GRAY2RGB)
    if array.shape[2] == 4:
        return cv2.cvtColor(array, cv2.COLOR_RGBA2RGB)
    return array


def _draw_detection(canvas, label, confidence, box, color, font_scale=1.0, thickness=2):
    """Draw ``box`` and a "<label> <confidence>" caption onto ``canvas`` in place.

    Returns:
        The box corners converted to ints: ``(xmin, ymin, xmax, ymax)``.
    """
    xmin, ymin, xmax, ymax = map(int, box)
    cv2.rectangle(canvas, (xmin, ymin), (xmax, ymax), color, 2)

    caption = f"{label} {confidence:.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_width, text_height), baseline = cv2.getTextSize(caption, font, font_scale, thickness)

    # Place the caption directly above the box; when the box touches the top
    # edge, push it down so the background and text stay inside the image.
    caption_bottom = max(ymin, text_height + baseline)
    cv2.rectangle(
        canvas,
        (xmin, caption_bottom - text_height - baseline),
        (xmin + text_width, caption_bottom),
        (0, 0, 0),
        cv2.FILLED,
    )
    # The text origin is its baseline, so the descenders fill the remaining
    # ``baseline`` pixels of the black background.
    cv2.putText(canvas, caption, (xmin, caption_bottom - baseline), font, font_scale,
                (255, 255, 255), thickness)
    return xmin, ymin, xmax, ymax


def predict_image(input_image, name, patient_id):
    """Run the YOLO model on one image and return an annotated copy.

    Only the single highest-confidence detection is drawn (box plus caption).
    A banner with the patient details and result, and a logo watermark, are
    then added by ``add_text_and_watermark``.

    Args:
        input_image: Image from the Gradio input (a PIL image in this app), or None.
        name: Patient name shown in the banner.
        patient_id: Patient ID shown in the banner.

    Returns:
        ``(annotated_image, raw_result_text)``. If ``input_image`` is None, returns
        ``(None, "Please Input The Image")``. If the model detects nothing, the
        banner reports "No detection" (previously this case raised ``NameError``
        because ``label`` was never assigned).
    """
    if input_image is None:
        return None, "Please Input The Image"

    image_np = _to_rgb_array(input_image)
    results = model(image_np)
    boxes = results[0].boxes

    image_with_boxes = image_np.copy()
    if boxes is not None and len(boxes) > 0:
        best = max(boxes, key=lambda box: box.conf.item())
        label, color = _CLASS_STYLES.get(int(best.cls.item()), _FALLBACK_STYLE)
        confidence = best.conf.item()
        xmin, ymin, xmax, ymax = _draw_detection(
            image_with_boxes, label, confidence, best.xyxy[0], color
        )
        raw_predictions_str = (
            f"Label: {label}, Confidence: {confidence:.2f}, "
            f"Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )
    else:
        label = _NO_DETECTION_LABEL
        raw_predictions_str = "No detections found."

    annotated = Image.fromarray(image_with_boxes)
    annotated = add_text_and_watermark(annotated, name, patient_id, label)
    return annotated, raw_predictions_str
|
|
|
|
|
def add_watermark(image, logo_path='image-logo.png', logo_width=100, margin=10):
    """Composite a logo into the bottom-right corner of ``image``.

    The logo is scaled to ``logo_width`` pixels wide (aspect ratio preserved)
    and placed ``margin`` pixels in from the right and bottom edges.
    Watermarking is best-effort: if anything fails (for example, the logo file
    is missing or unreadable), the error is printed and ``image`` is returned
    unchanged.

    Args:
        image: PIL image to watermark. It is not modified.
        logo_path: Path of the logo image file.
        logo_width: Target logo width in pixels.
        margin: Gap in pixels between the logo and the image edges.

    Returns:
        A new RGB PIL image with the logo applied, or the original ``image``
        object if watermarking failed.
    """
    try:
        # Load inside ``with`` so the logo's file handle is closed promptly.
        with Image.open(logo_path) as logo_file:
            logo = logo_file.convert("RGBA")

        # Keep the aspect ratio; clamp the height to >= 1 so a very wide logo
        # cannot request a zero-pixel resize.
        scaled_height = max(1, int(logo_width / float(logo.width) * logo.height))
        logo = logo.resize((logo_width, scaled_height), Image.LANCZOS)

        position = (image.width - logo.width - margin,
                    image.height - logo.height - margin)

        # Paste onto an RGBA copy so the logo's own alpha channel is the mask.
        canvas = image.convert("RGBA")
        canvas.paste(logo, position, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:  # best-effort: a missing logo must not break a prediction
        print(f"Error adding watermark: {e}")
        # Return the caller's image untouched, not a half-converted RGBA copy.
        return image
|
|
|
|
|
def add_text_and_watermark(image, name, patient_id, label, font_path="font.ttf", font_size=48):
    """Draw a patient/result banner on ``image``, then add the logo watermark.

    The banner text "Name: <name>, ID: <patient_id>, Result: <label>" is drawn
    in white on a black box near the top-left corner. Note: the banner is drawn
    onto ``image`` in place before ``add_watermark`` is applied.

    Args:
        image: PIL image to annotate (modified in place by the banner drawing).
        name: Patient name.
        patient_id: Patient identifier.
        label: Prediction label shown as the result.
        font_path: TrueType font file; PIL's default font is used if it cannot be opened.
        font_size: Point size for the TrueType font.

    Returns:
        The image returned by ``add_watermark``.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    text = f"Name: {name}, ID: {patient_id}, Result: {label}"
    text_x, text_y = 20, 40
    padding = 10

    # Measure at the real draw position: the box textbbox returns is generally
    # offset from the anchor point, so measuring at (0, 0) and using only the
    # width/height would misalign the background with the drawn glyphs.
    left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
    draw.rectangle(
        [left - padding, top - padding, right + padding, bottom + padding],
        fill="black",
    )
    draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)

    return add_watermark(image)
|
|
|
|
|
def init_db(db_path='results.db'):
    """Create the ``results`` table in the SQLite database if it does not exist.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``).

    Args:
        db_path: Path of the SQLite database file.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS results "
            "(id INTEGER PRIMARY KEY, name TEXT, patient_id TEXT, "
            "input_image BLOB, predicted_image BLOB, result TEXT)"
        )
        conn.commit()
    finally:
        # Close even if the statement fails, so the file handle is not leaked.
        conn.close()
|
|
|
|
|
def _encode_png(image):
    """Encode an RGB, RGBA or grayscale image (PIL or array) as PNG bytes.

    OpenCV expects BGR channel order, so colour input is converted from RGB
    first. Raises ValueError if OpenCV reports that encoding failed.
    """
    array = np.array(image)
    if array.ndim == 3:
        code = cv2.COLOR_RGBA2BGR if array.shape[2] == 4 else cv2.COLOR_RGB2BGR
        array = cv2.cvtColor(array, code)
    ok, buffer = cv2.imencode('.png', array)
    if not ok:
        raise ValueError("Could not encode image as PNG.")
    return buffer.tobytes()


def submit_result(name, patient_id, input_image, predicted_image, result, db_path='results.db'):
    """Store one prediction, with both images as PNG bytes, in the ``results`` table.

    Args:
        name: Patient name.
        patient_id: Patient identifier.
        input_image: Original image (RGB/RGBA/grayscale, PIL or array).
        predicted_image: Annotated image (RGB/RGBA/grayscale, PIL or array).
        result: Raw prediction text.
        db_path: Path of the SQLite database file.

    Returns:
        A status message for the UI. If either image is missing (e.g. saving
        before running a prediction), nothing is written and a hint is returned.
    """
    if input_image is None or predicted_image is None:
        return "Nothing to submit: run a prediction first."

    input_image_bytes = _encode_png(input_image)
    predicted_image_bytes = _encode_png(predicted_image)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "INSERT INTO results (name, patient_id, input_image, predicted_image, result) "
            "VALUES (?, ?, ?, ?, ?)",
            (name, patient_id, input_image_bytes, predicted_image_bytes, result),
        )
        conn.commit()
    finally:
        conn.close()
    return "Result submitted to database."
|
|
|
|
|
def _decode_png(image_blob):
    """Decode a stored PNG BLOB into an RGB NumPy array.

    Returns None for a NULL BLOB or for bytes OpenCV cannot decode
    (``cv2.imdecode`` returns None in that case).
    """
    if image_blob is None:
        return None
    bgr = cv2.imdecode(np.frombuffer(image_blob, dtype=np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    # submit_result converted RGB -> BGR before encoding; undo that here so the
    # arrays have the same channel order as the rest of the app.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def view_database(db_path='results.db'):
    """Load every stored result for display.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        A DataFrame with columns "Name", "Patient ID", "Input Image",
        "Predicted Image" and "Raw Result". Image cells hold RGB NumPy arrays,
        or None when the stored BLOB is NULL or undecodable.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name, patient_id, input_image, predicted_image, result FROM results"
        ).fetchall()
    finally:
        conn.close()

    df = pd.DataFrame(
        rows, columns=["Name", "Patient ID", "Input Image", "Predicted Image", "Raw Result"]
    )
    for column in ("Input Image", "Predicted Image"):
        df[column] = df[column].apply(_decode_png)
    return df
|
|
|
|
|
def _html_text_cell(value):
    """Return ``value`` as HTML-escaped text ('' for NULL), keeping line breaks."""
    from html import escape

    if value is None:
        return ""
    return escape(str(value)).replace("\n", "<br>")


def _html_png_cell(blob):
    """Return an <img> tag embedding PNG bytes as a base64 data URI ('' for NULL)."""
    if blob is None:
        return ""
    encoded = base64.b64encode(blob).decode("ascii")
    return f'<img src="data:image/png;base64,{encoded}" style="max-width:256px">'


def download_file(choice, db_path='results.db'):
    """Return the path of a downloadable file for the selected export option.

    Args:
        choice: Radio value: "Database (.db)", "Database (.html)" or
            "Predicted Image (.png)" (None when nothing is selected).
        db_path: Path of the SQLite database file.

    Returns:
        ``db_path`` for "Database (.db)"; the path of a new temporary HTML file
        (caller/Gradio owns its cleanup) for "Database (.html)"; otherwise None.
    """
    if choice == "Database (.db)":
        return db_path

    if choice == "Database (.html)":
        conn = sqlite3.connect(db_path)
        try:
            rows = conn.execute(
                "SELECT name, patient_id, input_image, predicted_image, result FROM results"
            ).fetchall()
        finally:
            conn.close()

        df = pd.DataFrame(
            [
                (_html_text_cell(n), _html_text_cell(pid), _html_png_cell(inp),
                 _html_png_cell(pred), _html_text_cell(res))
                for n, pid, inp, pred, res in rows
            ],
            columns=["Name", "Patient ID", "Input Image", "Predicted Image", "Raw Result"],
        )
        # escape=False lets the <img> tags render; every text cell was escaped
        # above, so user-entered names cannot inject markup. pandas applies
        # display.max_colwidth when formatting cells, which would cut the long
        # data URIs short, so lift that limit for this call.
        with pd.option_context("display.max_colwidth", None):
            table_html = df.to_html(escape=False)

        page = (
            '<!DOCTYPE html>\n<html><head><meta charset="utf-8">'
            "<title>Results</title></head>\n<body>\n"
            + table_html
            + "\n</body></html>\n"
        )
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", delete=False, suffix=".html"
        ) as f:
            f.write(page)
        return f.name

    # NOTE(review): "Predicted Image (.png)" is offered by the UI but no export
    # is implemented for it; like the original, it (and no selection) yields None.
    return None
|
|
|
|
|
# Ensure the results table exists before any UI callback touches the database.
init_db()

with gr.Blocks() as demo:
    with gr.Tab("YOLOv8 Inference"):
        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Row():
            name = gr.Textbox(label="Patient Name")
            patient_id = gr.Textbox(label="Patient ID")
        with gr.Row():
            submit_button = gr.Button("Submit")
        predicted_image = gr.Image(label="Predicted Image")
        with gr.Row():
            result = gr.Textbox(label="Raw Result", lines=5)
        # Nothing previously called submit_result, so the database (and the
        # "View Database" / "Download" features) always stayed empty.
        with gr.Row():
            save_button = gr.Button("Save to Database")
            save_status = gr.Textbox(label="Database Status")

        submit_button.click(
            predict_image,
            inputs=[input_image, name, patient_id],
            outputs=[predicted_image, result],
        )
        save_button.click(
            submit_result,
            inputs=[name, patient_id, input_image, predicted_image, result],
            outputs=save_status,
        )

    with gr.Tab("View Database"):
        view_button = gr.Button("View Database")
        database_output = gr.DataFrame(label="Database Records")
        view_button.click(view_database, outputs=database_output)

        download_choice = gr.Radio(
            ["Database (.db)", "Database (.html)", "Predicted Image (.png)"],
            label="Choose the file to download:",
        )
        download_button = gr.Button("Download")
        download_output = gr.File(label="Downloaded File")
        download_button.click(download_file, inputs=download_choice, outputs=download_output)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()