ariankhalfani commited on
Commit
22a68d2
1 Parent(s): 822a6f7

Update app2.py

Browse files
Files changed (1) hide show
  1. app2.py +113 -48
app2.py CHANGED
@@ -3,10 +3,12 @@ from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
 
 
 
6
  import os
7
  from pathlib import Path
8
  import shutil
9
- import tempfile
10
 
11
  # Load YOLOv8 model
12
  model = YOLO("best.pt")
@@ -17,8 +19,13 @@ predicted_folder = Path('Predicted_Picture')
17
  uploaded_folder.mkdir(parents=True, exist_ok=True)
18
  predicted_folder.mkdir(parents=True, exist_ok=True)
19
 
20
- # Global patient data list to accumulate HTML data
21
- patient_data = []
 
 
 
 
 
22
 
23
  def predict_image(input_image, name, age, medical_record, sex):
24
  if input_image is None:
@@ -36,41 +43,40 @@ def predict_image(input_image, name, age, medical_record, sex):
36
  # Perform prediction
37
  results = model(image_np)
38
 
39
- # Draw bounding boxes and white circle on the image
40
  image_with_boxes = image_np.copy()
41
  raw_predictions = []
42
 
43
  if results[0].boxes:
 
44
  highest_confidence_result = max(results[0].boxes, key=lambda x: x.conf.item())
 
 
45
  class_index = highest_confidence_result.cls.item()
46
  if class_index == 0:
47
  label = "Immature"
48
- color = (0, 255, 255)
49
  elif class_index == 1:
50
  label = "Mature"
51
- color = (255, 0, 0)
52
  else:
53
  label = "Normal"
54
- color = (0, 255, 0)
55
 
56
  confidence = highest_confidence_result.conf.item()
57
  xmin, ymin, xmax, ymax = map(int, highest_confidence_result.xyxy[0])
58
 
59
  # Draw the bounding box
60
  cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
61
-
62
- # Draw the white circle in the center of the bounding box
63
- box_width = xmax - xmin
64
- box_height = ymax - ymin
65
- center_x = xmin + box_width // 2
66
- center_y = ymin + box_height // 2
67
- radius = int((box_width + box_height) / 2 / 12)
68
- cv2.circle(image_with_boxes, (center_x, center_y), radius, (255, 255, 255), 2)
69
 
70
  # Enlarge font scale and thickness
71
  font_scale = 1.0
72
  thickness = 2
73
 
 
 
 
 
74
  # Put the label text with black background
75
  cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
76
 
@@ -89,13 +95,47 @@ def predict_image(input_image, name, age, medical_record, sex):
89
  input_image.save(uploaded_folder / image_name)
90
  pil_image_with_boxes.save(predicted_folder / image_name)
91
 
 
 
 
 
 
 
 
 
92
  return pil_image_with_boxes, raw_predictions_str
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Function to add text and watermark
95
  def add_text_and_watermark(image, name, age, medical_record, sex, label):
96
  draw = ImageDraw.Draw(image)
97
 
98
- font_size = 24
 
99
  try:
100
  font = ImageFont.truetype("font.ttf", size=font_size)
101
  except IOError:
@@ -103,48 +143,75 @@ def add_text_and_watermark(image, name, age, medical_record, sex, label):
103
  print("Error: cannot open resource, using default font.")
104
 
105
  text = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"
106
- text_x, text_y = 20, 40
 
 
 
 
 
107
  padding = 10
108
 
109
  # Draw a filled rectangle for the background
110
  draw.rectangle(
111
- [text_x - padding, text_y - padding, text_x + 500, text_y + 30 + padding],
112
  fill="black"
113
  )
114
 
115
  # Draw text on top of the rectangle
116
  draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
117
 
118
- return image
119
-
120
- # Function to save patient info in HTML and accumulate data
121
- def save_patient_info_to_html(name, age, medical_record, sex, result):
122
- global patient_data
123
- new_data = f"<p><strong>Name:</strong> {name}, <strong>Age:</strong> {age}, <strong>Medical Record:</strong> {medical_record}, <strong>Sex:</strong> {sex}, <strong>Result:</strong> {result}</p>"
124
- patient_data.append(new_data)
125
-
126
- html_content = f"""
127
- <html>
128
- <body>
129
- <h1>Patient Information</h1>
130
- {''.join(patient_data)}
131
- </body>
132
- </html>
133
- """
134
 
135
- html_file_path = os.path.join(tempfile.gettempdir(), 'patient_info.html')
136
- with open(html_file_path, 'w') as f:
137
- f.write(html_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- return html_file_path
 
140
 
141
  # Function to download the folders
142
- def download_folder(folder_path):
143
- zip_path = os.path.join(tempfile.gettempdir(), f"{Path(folder_path).name}.zip")
144
- shutil.make_archive(zip_path.replace('.zip', ''), 'zip', folder_path)
 
 
 
145
  return zip_path
146
 
147
  # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  with gr.Blocks() as demo:
149
  with gr.Column():
150
  gr.Markdown("# Cataract Detection System")
@@ -170,18 +237,16 @@ with gr.Blocks() as demo:
170
  download_uploaded_btn = gr.Button("Download Uploaded Images")
171
  download_predicted_btn = gr.Button("Download Predicted Images")
172
 
 
173
  patient_info_file = gr.File(label="Patient Information HTML File")
174
  uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
175
  predicted_folder_file = gr.File(label="Predicted Images Zip File")
176
 
177
- # Use gr.State to hold folder paths
178
- uploaded_folder_state = gr.State(str(uploaded_folder))
179
- predicted_folder_state = gr.State(str(predicted_folder))
180
-
181
- submit_btn.click(fn=predict_image, inputs=[name, age, medical_record, sex, input_image], outputs=[output_image, raw_result])
182
  download_html_btn.click(fn=save_patient_info_to_html, inputs=[name, age, medical_record, sex, raw_result], outputs=patient_info_file)
183
- download_uploaded_btn.click(fn=download_folder, inputs=[uploaded_folder_state], outputs=uploaded_folder_file)
184
- download_predicted_btn.click(fn=download_folder, inputs=[predicted_folder_state], outputs=predicted_folder_file)
185
 
186
  # Launch Gradio app
187
  demo.launch()
 
3
  import cv2
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
6
+ import base64
7
+ from io import BytesIO
8
+ import tempfile
9
  import os
10
  from pathlib import Path
11
  import shutil
 
12
 
13
  # Load YOLOv8 model
14
  model = YOLO("best.pt")
 
19
  uploaded_folder.mkdir(parents=True, exist_ok=True)
20
  predicted_folder.mkdir(parents=True, exist_ok=True)
21
 
22
+ # Path for HTML database file
23
+ html_db_file = Path('patient_predictions.html')
24
+
25
+ # Initialize HTML database file if not present
26
+ if not html_db_file.exists():
27
+ with open(html_db_file, 'w') as f:
28
+ f.write("<html><body><h1>Patient Prediction Database</h1>")
29
 
30
  def predict_image(input_image, name, age, medical_record, sex):
31
  if input_image is None:
 
43
  # Perform prediction
44
  results = model(image_np)
45
 
46
+ # Draw bounding boxes on the image
47
  image_with_boxes = image_np.copy()
48
  raw_predictions = []
49
 
50
  if results[0].boxes:
51
+ # Sort the results by confidence and take the highest confidence one
52
  highest_confidence_result = max(results[0].boxes, key=lambda x: x.conf.item())
53
+
54
+ # Determine the label based on the class index
55
  class_index = highest_confidence_result.cls.item()
56
  if class_index == 0:
57
  label = "Immature"
58
+ color = (0, 255, 255) # Yellow for Immature
59
  elif class_index == 1:
60
  label = "Mature"
61
+ color = (255, 0, 0) # Red for Mature
62
  else:
63
  label = "Normal"
64
+ color = (0, 255, 0) # Green for Normal
65
 
66
  confidence = highest_confidence_result.conf.item()
67
  xmin, ymin, xmax, ymax = map(int, highest_confidence_result.xyxy[0])
68
 
69
  # Draw the bounding box
70
  cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
 
 
 
 
 
 
 
 
71
 
72
  # Enlarge font scale and thickness
73
  font_scale = 1.0
74
  thickness = 2
75
 
76
+ # Calculate label background size
77
+ (text_width, text_height), baseline = cv2.getTextSize(f'{label} {confidence:.2f}', cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
78
+ cv2.rectangle(image_with_boxes, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (0, 0, 0), cv2.FILLED)
79
+
80
  # Put the label text with black background
81
  cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
82
 
 
95
  input_image.save(uploaded_folder / image_name)
96
  pil_image_with_boxes.save(predicted_folder / image_name)
97
 
98
+ # Convert the predicted image to base64 for embedding in HTML
99
+ buffered = BytesIO()
100
+ pil_image_with_boxes.save(buffered, format="PNG")
101
+ predicted_image_base64 = base64.b64encode(buffered.getvalue()).decode()
102
+
103
+ # Append the prediction to the HTML database
104
+ append_patient_info_to_html(name, age, medical_record, sex, label, predicted_image_base64)
105
+
106
  return pil_image_with_boxes, raw_predictions_str
107
 
108
+ # Function to add watermark
109
+ def add_watermark(image):
110
+ try:
111
+ logo = Image.open('image-logo.png').convert("RGBA")
112
+ image = image.convert("RGBA")
113
+
114
+ # Resize logo
115
+ basewidth = 100
116
+ wpercent = (basewidth / float(logo.size[0]))
117
+ hsize = int((float(wpercent) * logo.size[1]))
118
+ logo = logo.resize((basewidth, hsize), Image.LANCZOS)
119
+
120
+ # Position logo
121
+ position = (image.width - logo.width - 10, image.height - logo.height - 10)
122
+
123
+ # Composite image
124
+ transparent = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
125
+ transparent.paste(image, (0, 0))
126
+ transparent.paste(logo, position, mask=logo)
127
+
128
+ return transparent.convert("RGB")
129
+ except Exception as e:
130
+ print(f"Error adding watermark: {e}")
131
+ return image
132
+
133
  # Function to add text and watermark
134
  def add_text_and_watermark(image, name, age, medical_record, sex, label):
135
  draw = ImageDraw.Draw(image)
136
 
137
+ # Load a larger font (adjust the size as needed)
138
+ font_size = 24 # Example font size
139
  try:
140
  font = ImageFont.truetype("font.ttf", size=font_size)
141
  except IOError:
 
143
  print("Error: cannot open resource, using default font.")
144
 
145
  text = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"
146
+
147
+ # Calculate text bounding box
148
+ text_bbox = draw.textbbox((0, 0), text, font=font)
149
+ text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
150
+ text_x = 20
151
+ text_y = 40
152
  padding = 10
153
 
154
  # Draw a filled rectangle for the background
155
  draw.rectangle(
156
+ [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
157
  fill="black"
158
  )
159
 
160
  # Draw text on top of the rectangle
161
  draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
162
 
163
+ # Add watermark to the image
164
+ image_with_watermark = add_watermark(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ return image_with_watermark
167
+
168
+ # Function to append patient info and image to HTML database
169
+ def append_patient_info_to_html(name, age, medical_record, sex, result, predicted_image_base64):
170
+ html_entry = f"""
171
+ <div>
172
+ <h2>Patient Information</h2>
173
+ <p><strong>Name:</strong> {name}</p>
174
+ <p><strong>Age:</strong> {age}</p>
175
+ <p><strong>Medical Record:</strong> {medical_record}</p>
176
+ <p><strong>Sex:</strong> {sex}</p>
177
+ <p><strong>Result:</strong> {result}</p>
178
+ <p><strong>Predicted Image:</strong></p>
179
+ <img src="data:image/png;base64,{predicted_image_base64}" alt="Predicted Image" width="300">
180
+ </div>
181
+ <hr>
182
+ """
183
 
184
+ with open(html_db_file, 'a') as f:
185
+ f.write(html_entry)
186
 
187
  # Function to download the folders
188
+ def download_folder(folder):
189
+ zip_path = os.path.join(tempfile.gettempdir(), f"{folder}.zip")
190
+
191
+ # Zip the folder
192
+ shutil.make_archive(zip_path.replace('.zip', ''), 'zip', folder)
193
+
194
  return zip_path
195
 
196
  # Gradio Interface
197
+ def interface(name, age, medical_record, sex, input_image):
198
+ if input_image is None:
199
+ return None, "Please upload an image.", None
200
+
201
+ output_image, raw_result = predict_image(input_image, name, age, medical_record, sex)
202
+
203
+ # Return the current state of the HTML file with all predictions
204
+ return output_image, raw_result, str(html_db_file)
205
+
206
+ # Download Functions
207
+ def download_predicted_folder():
208
+ return download_folder(predicted_folder)
209
+
210
+ def download_uploaded_folder():
211
+ return download_folder(uploaded_folder)
212
+
213
+ # Launch Gradio Interface
214
+
215
  with gr.Blocks() as demo:
216
  with gr.Column():
217
  gr.Markdown("# Cataract Detection System")
 
237
  download_uploaded_btn = gr.Button("Download Uploaded Images")
238
  download_predicted_btn = gr.Button("Download Predicted Images")
239
 
240
+ # Add file download output components for the uploaded and predicted images
241
  patient_info_file = gr.File(label="Patient Information HTML File")
242
  uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
243
  predicted_folder_file = gr.File(label="Predicted Images Zip File")
244
 
245
+ # Connect functions with components
246
+ submit_btn.click(fn=interface, inputs=[name, age, medical_record, sex, input_image], outputs=[output_image, raw_result])
 
 
 
247
  download_html_btn.click(fn=save_patient_info_to_html, inputs=[name, age, medical_record, sex, raw_result], outputs=patient_info_file)
248
+ download_uploaded_btn.click(fn=download_uploaded_folder, outputs=uploaded_folder_file)
249
+ download_predicted_btn.click(fn=download_predicted_folder, outputs=predicted_folder_file)
250
 
251
  # Launch Gradio app
252
  demo.launch()