# ariankhalfani's picture
# Update app.py
# a9d5577 verified
# raw
# history blame
# 4.17 kB
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import tensorflow as tf
import gradio as gr
import io
def load_model(model_path):
    """Load a saved Keras model from ``model_path`` and compile it.

    The model is compiled with the Adam optimizer, binary cross-entropy
    loss and an ``accuracy`` metric before being returned.
    """
    loaded = tf.keras.models.load_model(model_path)
    loaded.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=['accuracy'],
    )
    return loaded
def get_model_summary(model):
    """Return the text emitted by ``model.summary()`` as a single string.

    Every line handed to the ``print_fn`` callback is captured and
    terminated with a newline.
    """
    captured = []

    def _collect(x):
        # Same contract as the summary's print callback: one line per call.
        captured.append(x + "\n")

    model.summary(print_fn=_collect)
    return "".join(captured)
def get_input_shape(model):
    """Return ``model.input_shape`` without its leading (batch) entry."""
    shape_with_batch = model.input_shape
    return shape_with_batch[1:]
def preprocess_image(image, input_shape):
    """Convert ``image`` into a normalized single-item batch for the model.

    Args:
        image: Anything ``np.array`` can turn into an image array
            (the app passes a PIL image, whose channel order is RGB).
        input_shape: Model input shape without the batch dimension,
            read here as ``(height, width, channels)``.

    Returns:
        Array of shape ``(1, height, width, channels)`` scaled by 1/255.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided")

    img = np.array(image)
    num_channels = input_shape[-1]
    src_channels = 1 if img.ndim == 2 else img.shape[2]

    if num_channels == 1:  # Model expects grayscale
        if src_channels == 3:
            # The array comes from PIL, so it is RGB (not BGR) ordered.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif src_channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    elif num_channels == 3:  # Model expects RGB
        if src_channels == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif src_channels == 4:
            # Drop the alpha channel so the channel count matches the model.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # cv2.resize takes dsize as (width, height), the reverse of the
    # (height, width) order used by the model's input shape.
    target_height, target_width = input_shape[0], input_shape[1]
    img_resized = cv2.resize(img, (target_width, target_height))

    # cv2.resize returns a 2-D array for single-channel input; restore the
    # channel axis so the batch is always 4-D.
    if img_resized.ndim == 2:
        img_resized = np.expand_dims(img_resized, axis=-1)

    img_normalized = img_resized / 255.0
    img_batch = np.expand_dims(img_normalized, axis=0)
    return img_batch
def diagnose_image(image, model, input_shape):
    """Predict the glaucoma probability for ``image`` and annotate a copy.

    Args:
        image: Input image (the app passes a PIL image).
        model: Object exposing ``predict(batch)``; the value at
            ``prediction[0][0]`` is reported as the glaucoma probability.
        input_shape: Model input shape without the batch dimension.

    Returns:
        Tuple ``(annotated_pil_image, result_text)``. The annotated image
        has a red 200x100 rectangle centered on it with the probability
        written inside.

    Raises:
        ValueError: If ``image`` or ``model`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided for diagnosis")
    if model is None:
        raise ValueError("No model loaded for diagnosis")

    img_batch = preprocess_image(image, input_shape)
    prediction = model.predict(img_batch)
    glaucoma_probability = float(prediction[0][0])
    result_text = f"Probability of glaucoma: {glaucoma_probability:.2%}"

    img_display = np.array(image)
    # A grayscale image may be 2-D (no channel axis) or (H, W, 1); check the
    # rank first so a 2-D array does not raise IndexError on shape[2].
    if img_display.ndim == 2 or img_display.shape[2] == 1:
        gray = img_display.reshape(img_display.shape[:2])
        img_display = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    image_pil = Image.fromarray(img_display)
    draw = ImageDraw.Draw(image_pil)
    font = ImageFont.load_default()

    text = f"{glaucoma_probability:.2%}"
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_size = (text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1])

    # Fixed-size box centered on the image.
    rect_width = 200
    rect_height = 100
    rect_x = (image_pil.width - rect_width) // 2
    rect_y = (image_pil.height - rect_height) // 2
    draw.rectangle(
        [rect_x, rect_y, rect_x + rect_width, rect_y + rect_height],
        outline="red",
        width=3,
    )

    # Center the probability text inside the box.
    text_x = rect_x + (rect_width - text_size[0]) // 2
    text_y = rect_y + (rect_height - text_size[1]) // 2
    draw.text((text_x, text_y), text, fill="red", font=font)

    return image_pil, result_text
def main():
    """Build the Gradio interface for glaucoma detection and launch it.

    The UI lets the user upload a Keras model file, load it (showing its
    summary), upload an image, and run a diagnosis whose annotated image
    and probability text are shown back to the user.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Glaucoma Detection App")
        gr.Markdown("Upload an fundus eye image to detect the probability of glaucoma.")

        # NOTE(review): the original indentation was lost; the row is assumed
        # to hold the file picker and its button - confirm the intended layout.
        with gr.Row():
            model_file = gr.File(label="Upload Model (.h5 or .keras)")
            load_model_btn = gr.Button("Load Model")

        model_info = gr.Markdown()
        image = gr.Image(type="pil", label="Upload Image")
        submit_btn = gr.Button("Diagnose")
        result = gr.Textbox(label="Diagnosis Result")

        # Per-session holders for the loaded model and its input shape.
        model = gr.State(None)
        input_shape = gr.State(None)

        def load_and_display_model_info(file):
            """Load the uploaded model and return it, its summary and input shape."""
            if file is None:
                raise gr.Error("Please upload a model file before loading.")
            # Presumably the upload is either a path string or an object with
            # a ``name`` path attribute, depending on the Gradio version.
            model_path = getattr(file, "name", file)
            loaded_model = load_model(model_path)
            model_summary = get_model_summary(loaded_model)
            loaded_shape = get_input_shape(loaded_model)
            # Fence the summary so Markdown keeps the table's line breaks
            # and column alignment instead of reflowing it.
            summary_md = f"```\n{model_summary}\n```"
            return loaded_model, summary_md, loaded_shape

        def diagnose_and_display(image, model, input_shape):
            """Validate the inputs, then delegate to ``diagnose_image``."""
            if model is None or input_shape is None:
                raise gr.Error("Please load a model before running a diagnosis.")
            if image is None:
                raise gr.Error("Please upload an image before running a diagnosis.")
            return diagnose_image(image, model, input_shape)

        load_model_btn.click(
            fn=load_and_display_model_info,
            inputs=model_file,
            outputs=[model, model_info, input_shape],
        )
        submit_btn.click(
            fn=diagnose_and_display,
            inputs=[image, model, input_shape],
            outputs=[image, result],
        )

        gr.Markdown("### Glaucoma Analyzer V.1.0.0 by Thariq Arian")

    demo.launch()
# Build and launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()