ariankhalfani commited on
Commit
cfd68c2
1 Parent(s): 010d4d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import base64
7
+ from io import BytesIO
8
+ import tempfile
9
+ import os
10
+ from pathlib import Path
11
+ import shutil
12
+ from openpyxl import Workbook, load_workbook
13
+
14
+ # Load YOLOv8 model
15
+ model = YOLO("best.pt")
16
+
17
+ # Create directories if not present
18
+ uploaded_folder = Path('Uploaded_Picture')
19
+ predicted_folder = Path('Predicted_Picture')
20
+ uploaded_folder.mkdir(parents=True, exist_ok=True)
21
+ predicted_folder.mkdir(parents=True, exist_ok=True)
22
+
23
+ # Path for Excel database file
24
+ xlsx_db_file = Path('patient_predictions.xlsx')
25
+
26
+ # Initialize Excel database file if not present
27
+ if not xlsx_db_file.exists():
28
+ workbook = Workbook()
29
+ sheet = workbook.active
30
+ sheet.title = "Predictions"
31
+ sheet.append(["Name", "Age", "Medical Record", "Sex", "Result", "Image Path"])
32
+ workbook.save(xlsx_db_file)
33
+
34
+ def predict_image(input_image, name, age, medical_record, sex):
35
+ if input_image is None:
36
+ return None, "Please Input The Image"
37
+
38
+ # Convert Gradio input image (PIL Image) to numpy array
39
+ image_np = np.array(input_image)
40
+
41
+ # Ensure the image is in the correct format
42
+ if len(image_np.shape) == 2: # grayscale to RGB
43
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
44
+ elif image_np.shape[2] == 4: # RGBA to RGB
45
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
46
+
47
+ # Perform prediction
48
+ results = model(image_np)
49
+
50
+ # Draw bounding boxes on the image
51
+ image_with_boxes = image_np.copy()
52
+ raw_predictions = []
53
+
54
+ if results[0].boxes:
55
+ # Sort the results by confidence and take the highest confidence one
56
+ highest_confidence_result = max(results[0].boxes, key=lambda x: x.conf.item())
57
+
58
+ # Determine the label based on the class index
59
+ class_index = highest_confidence_result.cls.item()
60
+ if class_index == 1:
61
+ label = "Mature"
62
+ color = (255, 0, 0) # Red for Mature
63
+ else:
64
+ label = "Normal"
65
+ color = (0, 255, 0) # Green for Normal
66
+
67
+ confidence = highest_confidence_result.conf.item()
68
+ xmin, ymin, xmax, ymax = map(int, highest_confidence_result.xyxy[0])
69
+
70
+ # Calculate the average of box width and height
71
+ box_width = xmax - xmin
72
+ box_height = ymax - ymin
73
+ avg_dimension = (box_width + box_height) / 2
74
+
75
+ # Calculate the circle radius as 1/12 of the average dimension
76
+ radius = int(avg_dimension / 12)
77
+
78
+ # Calculate the center of the bounding box
79
+ center_x = int((xmin + xmax) / 2)
80
+ center_y = int((ymin + ymax) / 2)
81
+
82
+ # Draw the circle at the center of the bounding box with the color corresponding to the label
83
+ cv2.circle(image_with_boxes, (center_x, center_y), radius, color, 2)
84
+
85
+ # Enlarge font scale and thickness
86
+ font_scale = 1.0
87
+ thickness = 2
88
+
89
+ # Calculate label background size
90
+ (text_width, text_height), baseline = cv2.getTextSize(f'{label} {confidence:.2f}', cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
91
+ cv2.rectangle(image_with_boxes, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (0, 0, 0), cv2.FILLED)
92
+
93
+ # Put the label text with black background
94
+ cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
95
+
96
+ raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Circle Center: [{center_x}, {center_y}], Radius: {radius}")
97
+
98
+ raw_predictions_str = "\n".join(raw_predictions)
99
+
100
+ # Convert to PIL image for further processing
101
+ pil_image_with_boxes = Image.fromarray(image_with_boxes)
102
+
103
+ # Add text and watermark
104
+ pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, age, medical_record, sex, label)
105
+
106
+ # Save images to directories
107
+ image_name = f"{name}-{age}-{sex}-{medical_record}.png"
108
+ input_image.save(uploaded_folder / image_name)
109
+ pil_image_with_boxes.save(predicted_folder / image_name)
110
+
111
+ # Convert the predicted image to base64 for embedding in the XLSX file
112
+ buffered = BytesIO()
113
+ pil_image_with_boxes.save(buffered, format="PNG")
114
+ predicted_image_base64 = base64.b64encode(buffered.getvalue()).decode()
115
+
116
+ # Append the prediction to the XLSX database
117
+ append_patient_info_to_xlsx(name, age, medical_record, sex, label, image_name)
118
+
119
+ return pil_image_with_boxes, raw_predictions_str
120
+
121
+ def add_watermark(image):
122
+ try:
123
+ logo = Image.open('image-logo.png').convert("RGBA")
124
+ image = image.convert("RGBA")
125
+ basewidth = 100
126
+ wpercent = (basewidth / float(logo.size[0]))
127
+ hsize = int((float(wpercent) * logo.size[1]))
128
+ logo = logo.resize((basewidth, hsize), Image.LANCZOS)
129
+ position = (image.width - logo.width - 10, image.height - logo.height - 10)
130
+ transparent = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
131
+ transparent.paste(image, (0, 0))
132
+ transparent.paste(logo, position, mask=logo)
133
+ return transparent.convert("RGB")
134
+ except Exception as e:
135
+ print(f"Error adding watermark: {e}")
136
+ return image
137
+
138
+ def add_text_and_watermark(image, name, age, medical_record, sex, label):
139
+ draw = ImageDraw.Draw(image)
140
+ font_size = 24
141
+ try:
142
+ font = ImageFont.truetype("font.ttf", size=font_size)
143
+ except IOError:
144
+ font = ImageFont.load_default()
145
+ print("Error: cannot open resource, using default font.")
146
+
147
+ text = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"
148
+ text_bbox = draw.textbbox((0, 0), text, font=font)
149
+ text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
150
+ text_x = 20
151
+ text_y = 40
152
+ padding = 10
153
+ draw.rectangle([text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding], fill="black")
154
+ draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
155
+
156
+ image_with_watermark = add_watermark(image)
157
+ return image_with_watermark
158
+
159
+ def append_patient_info_to_xlsx(name, age, medical_record, sex, result, image_path):
160
+ if not xlsx_db_file.exists():
161
+ workbook = Workbook()
162
+ sheet = workbook.active
163
+ sheet.title = "Predictions"
164
+ sheet.append(["Name", "Age", "Medical Record", "Sex", "Result", "Image Path"])
165
+ workbook.save(xlsx_db_file)
166
+
167
+ workbook = load_workbook(xlsx_db_file)
168
+ sheet = workbook.active
169
+ sheet.append([name, age, medical_record, sex, result, str(image_path)])
170
+ workbook.save(xlsx_db_file)
171
+
172
+ return str(xlsx_db_file)
173
+
174
+ def download_folder(folder):
175
+ zip_path = os.path.join(tempfile.gettempdir(), f"{folder}.zip")
176
+ shutil.make_archive(zip_path.replace('.zip', ''), 'zip', folder)
177
+ return zip_path
178
+
179
+ def interface(name, age, medical_record, sex, input_image):
180
+ if input_image is None:
181
+ return None, "Please upload an image.", None
182
+
183
+ output_image, raw_result = predict_image(input_image, name, age, medical_record, sex)
184
+
185
+ return output_image, raw_result, str(xlsx_db_file)
186
+
187
+ def download_predicted_folder():
188
+ return download_folder(predicted_folder)
189
+
190
+ def download_uploaded_folder():
191
+ return download_folder(uploaded_folder)
192
+
193
+ with gr.Blocks() as demo:
194
+ with gr.Column():
195
+ gr.Markdown("# Cataract Detection System")
196
+ gr.Markdown("Upload an image to detect cataract and add patient details.")
197
+ gr.Markdown("This application uses YOLOv8 with mAP=0.981")
198
+
199
+ with gr.Column():
200
+ name = gr.Textbox(label="Name")
201
+ age = gr.Number(label="Age")
202
+ medical_record = gr.Number(label="Medical Record")
203
+ sex = gr.Radio(["Male", "Female"], label="Sex")
204
+ input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")
205
+
206
+ with gr.Column():
207
+ submit_btn = gr.Button("Submit")
208
+ output_image = gr.Image(type="pil", label="Predicted Image")
209
+ raw_result = gr.Textbox(label="Raw Result", interactive=False)
210
+
211
+ submit_btn.click(fn=interface, inputs=[name, age, medical_record, sex, input_image], outputs=[output_image, raw_result])
212
+
213
+ with gr.Row():
214
+ download_uploaded_btn = gr.Button("Download Uploaded Folder")
215
+ download_predicted_btn = gr.Button("Download Predicted Folder")
216
+
217
+ download_uploaded_btn.click(fn=download_uploaded_folder, inputs=[], outputs=gr.File())
218
+ download_predicted_btn.click(fn=download_predicted_folder, inputs=[], outputs=gr.File())
219
+
220
+ demo.launch()