ariankhalfani commited on
Commit
e8fecc0
1 Parent(s): 17e2614

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -116
app.py CHANGED
@@ -1,118 +1,111 @@
1
- import gradio as gr
2
  import numpy as np
3
-
4
- # Define functions to wrap the private space's interfaces
5
- def cataract_analyzer(image):
6
- private_space_id = "EyeMedicalDiagnosis/cataract_analyzer"
7
- private_space = gr.Interface.load(private_space_id)
8
- result = private_space(image)
9
- return parse_cataract_result(result, image)
10
-
11
- def glaucoma_analyzer(image):
12
- private_space_id = "EyeMedicalDiagnosis/glaucoma_analyzer"
13
- private_space = gr.Interface.load(private_space_id)
14
- result = private_space(image)
15
- return parse_glaucoma_result(result, image)
16
-
17
- def rlfd_analyzer(image):
18
- private_space_id = "EyeMedicalDiagnosis/rlfd_analyzer"
19
- private_space = gr.Interface.load(private_space_id)
20
- result = private_space(image)
21
- return parse_rlfd_result(result, image)
22
-
23
- def parse_cataract_result(result, image):
24
- if isinstance(result, list) and len(result) == 7:
25
- return result
26
- else:
27
- return [image, "Error: Invalid result", image, 0, 0, 0, ""]
28
-
29
- def parse_glaucoma_result(result, image):
30
- if isinstance(result, list) and len(result) == 5:
31
- return result
32
- else:
33
- return [image, "Error: Invalid result", "0", "0", "0"]
34
-
35
- def parse_rlfd_result(result, image):
36
- if isinstance(result, list) and len(result) == 8:
37
- return result
38
- else:
39
- return [image, image, "Error: Invalid result", image, "0", image, "0", image, "0"]
40
-
41
- # Create the public interface with custom layout
42
- with gr.Blocks() as demo:
43
- gr.Markdown("## Medical Image Analyzer")
44
- gr.Markdown("Choose which analyzer you want to use:")
45
-
46
- with gr.Tab("Cataract Analyzer"):
47
- with gr.Row():
48
- image_input_cataract = gr.Image(type="numpy", label="Upload an Image")
49
- submit_btn_cataract = gr.Button("Submit")
50
-
51
- with gr.Row():
52
- result_image_cataract = gr.Image(type="numpy", label="Image with Prediction and Bounding Box", scale=2)
53
- cataract_label = gr.Textbox(label="Cataract Prediction", scale=1)
54
-
55
- with gr.Row():
56
- blended_image_cataract = gr.Image(type="numpy", label="Image with Masked Area", scale=2)
57
-
58
- with gr.Column():
59
- red_quantity_cataract = gr.Slider(label="Red Quantity", minimum=0, maximum=255, interactive=False)
60
- green_quantity_cataract = gr.Slider(label="Green Quantity", minimum=0, maximum=255, interactive=False)
61
- blue_quantity_cataract = gr.Slider(label="Blue Quantity", minimum=0, maximum=255, interactive=False)
62
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  with gr.Row():
64
- raw_response_cataract = gr.Textbox(label="Raw Response", scale=2)
65
-
66
- submit_btn_cataract.click(
67
- cataract_analyzer,
68
- inputs=image_input_cataract,
69
- outputs=[
70
- result_image_cataract, cataract_label, blended_image_cataract,
71
- red_quantity_cataract, green_quantity_cataract, blue_quantity_cataract, raw_response_cataract
72
- ]
73
- )
74
-
75
- with gr.Tab("Glaucoma Analyzer"):
76
- image_input_glaucoma = gr.Image(type="numpy", label="Input Image")
77
- submit_btn_glaucoma = gr.Button("Submit")
78
-
79
- segmented_image_glaucoma = gr.Image(type="numpy", label="Segmented Image")
80
- cup_area = gr.Textbox(label="Cup Area")
81
- disk_area = gr.Textbox(label="Disk Area")
82
- rim_area = gr.Textbox(label="Rim Area")
83
- rim_to_disk_ratio = gr.Textbox(label="Rim/Disk Ratio")
84
-
85
- submit_btn_glaucoma.click(
86
- glaucoma_analyzer,
87
- inputs=image_input_glaucoma,
88
- outputs=[segmented_image_glaucoma, cup_area, disk_area, rim_area, rim_to_disk_ratio]
89
- )
90
-
91
- with gr.Tab("RLFD Analyzer"):
92
- image_input_rlfd = gr.Image(type="pil", label="Upload an Image")
93
- submit_btn_rlfd = gr.Button("Submit")
94
-
95
- enhanced_image_with_line = gr.Image(type="numpy", label="Enhanced Image with Diagonal Lines")
96
- top_section_image = gr.Image(type="numpy", label="Top Section")
97
- top_section_score = gr.Textbox(label="Top Section Score")
98
- bottom_section_image = gr.Image(type="numpy", label="Bottom Section")
99
- bottom_section_score = gr.Textbox(label="Bottom Section Score")
100
- left_section_image = gr.Image(type="numpy", label="Left Section")
101
- left_section_score = gr.Textbox(label="Left Section Score")
102
- right_section_image = gr.Image(type="numpy", label="Right Section")
103
- right_section_score = gr.Textbox(label="Right Section Score")
104
-
105
- submit_btn_rlfd.click(
106
- rlfd_analyzer,
107
- inputs=image_input_rlfd,
108
- outputs=[
109
- enhanced_image_with_line,
110
- top_section_image, top_section_score,
111
- bottom_section_image, bottom_section_score,
112
- left_section_image, left_section_score,
113
- right_section_image, right_section_score
114
- ]
115
- )
116
-
117
- # Launch the public interface
118
- demo.launch()
 
1
+ from PIL import Image, ImageDraw, ImageFont
2
  import numpy as np
3
+ import cv2
4
+ import tensorflow as tf
5
+ import gradio as gr
6
+ import io
7
+
8
+ def load_model(model_path):
9
+ model = tf.keras.models.load_model(model_path)
10
+ model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.BinaryCrossentropy(), metrics=['accuracy'])
11
+ return model
12
+
13
+ def get_model_summary(model):
14
+ stream = io.StringIO()
15
+ model.summary(print_fn=lambda x: stream.write(x + "\n"))
16
+ summary_str = stream.getvalue()
17
+ stream.close()
18
+ return summary_str
19
+
20
+ def get_input_shape(model):
21
+ input_shape = model.input_shape[1:] # Skip the batch dimension
22
+ return input_shape
23
+
24
+ def preprocess_image(image, input_shape):
25
+ img = np.array(image)
26
+ num_channels = input_shape[-1]
27
+
28
+ if num_channels == 1: # Model expects grayscale
29
+ if len(img.shape) == 2: # Image is already grayscale
30
+ img = np.expand_dims(img, axis=-1)
31
+ elif img.shape[2] == 3: # Convert RGB to grayscale
32
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
33
+ img = np.expand_dims(img, axis=-1)
34
+ elif num_channels == 3: # Model expects RGB
35
+ if len(img.shape) == 2: # Convert grayscale to RGB
36
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
37
+ elif img.shape[2] == 1: # Convert single channel to RGB
38
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
39
+
40
+ img_resized = cv2.resize(img, (input_shape[0], input_shape[1]))
41
+ img_normalized = img_resized / 255.0
42
+ img_batch = np.expand_dims(img_normalized, axis=0)
43
+
44
+ return img_batch
45
+
46
+ def diagnose_image(image, model, input_shape):
47
+ img_batch = preprocess_image(image, input_shape)
48
+ prediction = model.predict(img_batch)
49
+ glaucoma_probability = prediction[0][0]
50
+ result_text = f"Probability of glaucoma: {glaucoma_probability:.2%}"
51
+
52
+ img_display = np.array(image)
53
+ if img_display.shape[2] == 1: # Convert to RGB for display
54
+ img_display = cv2.cvtColor(img_display.squeeze(), cv2.COLOR_GRAY2RGB)
55
+ image_pil = Image.fromarray(img_display)
56
+ draw = ImageDraw.Draw(image_pil)
57
+ font = ImageFont.load_default()
58
+
59
+ text = f"{glaucoma_probability:.2%}"
60
+ text_bbox = draw.textbbox((0, 0), text, font=font)
61
+ text_size = (text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1])
62
+
63
+ rect_width = 200
64
+ rect_height = 100
65
+ rect_x = (image_pil.width - rect_width) // 2
66
+ rect_y = (image_pil.height - rect_height) // 2
67
+
68
+ draw.rectangle([rect_x, rect_y, rect_x + rect_width, rect_y + rect_height], outline="red", width=3)
69
+
70
+ text_x = rect_x + (rect_width - text_size[0]) // 2
71
+ text_y = rect_y + (rect_height - text_size[1]) // 2
72
+
73
+ draw.text((text_x, text_y), text, fill="red", font=font)
74
+
75
+ return image_pil, result_text
76
+
77
+ def main():
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown("# Glaucoma Detection App")
80
+ gr.Markdown("Upload an eye image to detect the probability of glaucoma.")
81
+
82
  with gr.Row():
83
+ model_file = gr.File(label="Upload Model (.h5 or .keras)")
84
+ load_model_btn = gr.Button("Load Model")
85
+ model_info = gr.Markdown()
86
+
87
+ image = gr.Image(type="pil", label="Upload Image")
88
+ submit_btn = gr.Button("Diagnose")
89
+ result = gr.Textbox(label="Diagnosis Result")
90
+
91
+ def load_and_display_model_info(file):
92
+ model = load_model(file.name)
93
+ model_summary = get_model_summary(model)
94
+ input_shape = get_input_shape(model)
95
+ return model, model_summary, input_shape
96
+
97
+ model = gr.State(None)
98
+ input_shape = gr.State(None)
99
+
100
+ def diagnose_and_display(image, model, input_shape):
101
+ return diagnose_image(image, model, input_shape)
102
+
103
+ load_model_btn.click(fn=load_and_display_model_info, inputs=model_file, outputs=[model, model_info, input_shape])
104
+ submit_btn.click(fn=diagnose_and_display, inputs=[image, model, input_shape], outputs=[image, result])
105
+
106
+ gr.Markdown("### Glaucoma Analyzer V.1.0.0 by Thariq Arian")
107
+
108
+ demo.launch()
109
+
110
+ if __name__ == "__main__":
111
+ main()