ariankhalfani commited on
Commit
925fada
1 Parent(s): c5f9028

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -108
app.py CHANGED
@@ -4,13 +4,14 @@ import cv2
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
6
  import sqlite3
 
 
7
  import tempfile
8
  import pandas as pd
9
 
10
  # Load YOLOv8 model
11
  model = YOLO("best.pt")
12
 
13
- # Function to perform prediction
14
  def predict_image(input_image, name, patient_id):
15
  if input_image is None:
16
  return None, "Please Input The Image"
@@ -30,51 +31,50 @@ def predict_image(input_image, name, patient_id):
30
  # Draw bounding boxes on the image
31
  image_with_boxes = image_np.copy()
32
  raw_predictions = []
33
- label = "Unknown" # Default label if no detection
34
 
35
  if results[0].boxes:
36
- for box in results[0].boxes:
37
- # Get class index and confidence for each detection
38
- class_index = box.cls.item()
39
- confidence = box.conf.item()
40
-
41
- # Determine the label based on the class index
42
- if class_index == 0:
43
- label = "Immature"
44
- color = (0, 255, 255) # Yellow for Immature
45
- elif class_index == 1:
46
- label = "Mature"
47
- color = (255, 0, 0) # Red for Mature
48
- else:
49
- label = "Normal"
50
- color = (0, 255, 0) # Green for Normal
51
-
52
- xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
53
-
54
- # Draw the bounding box
55
- cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
56
-
57
- # Enlarge font scale and thickness
58
- font_scale = 1.0
59
- thickness = 2
60
-
61
- # Calculate label background size
62
- (text_width, text_height), baseline = cv2.getTextSize(f'{label} {confidence:.2f}', cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
63
- cv2.rectangle(image_with_boxes, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (0, 0, 0), cv2.FILLED)
64
-
65
- # Put the label text with black background
66
- cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
67
-
68
- raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]")
69
 
70
  raw_predictions_str = "\n".join(raw_predictions)
71
-
72
  # Convert to PIL image for further processing
73
  pil_image_with_boxes = Image.fromarray(image_with_boxes)
74
 
75
  # Add text and watermark
76
  pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, patient_id, label)
77
-
78
  return pil_image_with_boxes, raw_predictions_str
79
 
80
  # Function to add watermark
@@ -82,21 +82,21 @@ def add_watermark(image):
82
  try:
83
  logo = Image.open('image-logo.png').convert("RGBA")
84
  image = image.convert("RGBA")
85
-
86
  # Resize logo
87
  basewidth = 100
88
  wpercent = (basewidth / float(logo.size[0]))
89
  hsize = int((float(wpercent) * logo.size[1]))
90
  logo = logo.resize((basewidth, hsize), Image.LANCZOS)
91
-
92
  # Position logo
93
  position = (image.width - logo.width - 10, image.height - logo.height - 10)
94
-
95
  # Composite image
96
  transparent = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
97
  transparent.paste(image, (0, 0))
98
  transparent.paste(logo, position, mask=logo)
99
-
100
  return transparent.convert("RGB")
101
  except Exception as e:
102
  print(f"Error adding watermark: {e}")
@@ -105,7 +105,7 @@ def add_watermark(image):
105
  # Function to add text and watermark
106
  def add_text_and_watermark(image, name, patient_id, label):
107
  draw = ImageDraw.Draw(image)
108
-
109
  # Load a larger font (adjust the size as needed)
110
  font_size = 48 # Example font size
111
  try:
@@ -113,9 +113,9 @@ def add_text_and_watermark(image, name, patient_id, label):
113
  except IOError:
114
  font = ImageFont.load_default()
115
  print("Error: cannot open resource, using default font.")
116
-
117
  text = f"Name: {name}, ID: {patient_id}, Result: {label}"
118
-
119
  # Calculate text bounding box
120
  text_bbox = draw.textbbox((0, 0), text, font=font)
121
  text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
@@ -128,13 +128,13 @@ def add_text_and_watermark(image, name, patient_id, label):
128
  [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
129
  fill="black"
130
  )
131
-
132
  # Draw text on top of the rectangle
133
  draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
134
 
135
  # Add watermark to the image
136
  image_with_watermark = add_watermark(image)
137
-
138
  return image_with_watermark
139
 
140
  # Function to initialize the database
@@ -150,16 +150,16 @@ def init_db():
150
  def submit_result(name, patient_id, input_image, predicted_image, result):
151
  conn = sqlite3.connect('results.db')
152
  c = conn.cursor()
153
-
154
  input_image_np = np.array(input_image)
155
  _, input_buffer = cv2.imencode('.png', cv2.cvtColor(input_image_np, cv2.COLOR_RGB2BGR))
156
  input_image_bytes = input_buffer.tobytes()
157
-
158
  predicted_image_np = np.array(predicted_image)
159
  predicted_image_rgb = cv2.cvtColor(predicted_image_np, cv2.COLOR_RGB2BGR) # Ensure correct color conversion
160
  _, predicted_buffer = cv2.imencode('.png', predicted_image_rgb)
161
  predicted_image_bytes = predicted_buffer.tobytes()
162
-
163
  c.execute("INSERT INTO results (name, patient_id, input_image, predicted_image, result) VALUES (?, ?, ?, ?, ?)",
164
  (name, patient_id, input_image_bytes, predicted_image_bytes, result))
165
  conn.commit()
@@ -170,24 +170,28 @@ def submit_result(name, patient_id, input_image, predicted_image, result):
170
  def view_database():
171
  conn = sqlite3.connect('results.db')
172
  c = conn.cursor()
173
- c.execute("SELECT * FROM results")
174
  rows = c.fetchall()
175
  conn.close()
176
-
177
- # Convert to pandas DataFrame
178
- df = pd.DataFrame(rows, columns=["ID", "Name", "Patient ID", "Input Image", "Predicted Image", "Result"])
179
-
180
  return df
181
 
182
  # Function to download database or image
183
  def download_file(choice):
184
- conn = sqlite3.connect('results.db')
185
- c = conn.cursor()
186
-
187
  if choice == "Database (.db)":
188
- conn.close()
189
  return 'results.db'
 
 
 
 
 
190
  else:
 
 
191
  c.execute("SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1")
192
  row = c.fetchone()
193
  conn.close()
@@ -198,7 +202,6 @@ def download_file(choice):
198
  temp_file.flush() # Ensure all data is written before closing
199
  return temp_file.name
200
  else:
201
- conn.close()
202
  raise FileNotFoundError("No images found in the database.")
203
 
204
  # Initialize the database
@@ -211,52 +214,37 @@ def interface(name, patient_id, input_image):
211
 
212
  output_image, raw_result = predict_image(input_image, name, patient_id)
213
  submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)
214
-
215
- return output_image, raw_result, submit_status
216
-
217
- # View Database Function
218
- def view_db_interface():
219
- df = view_database()
220
- return df
221
-
222
- # Download Function
223
- def download_interface(choice):
224
- try:
225
- file_path = download_file(choice)
226
- with open(file_path, "rb") as file:
227
- return file.read(), file_path
228
- except Exception as e:
229
- return f"Error: {str(e)}", None
230
-
231
- # Build Gradio Interface
232
- app = gr.Blocks()
233
-
234
- with app:
235
- gr.Markdown("# Eye Condition Detection System")
236
-
237
- with gr.Row():
238
- with gr.Column():
239
- name = gr.Textbox(label="Name")
240
- patient_id = gr.Textbox(label="Patient ID")
241
- input_image = gr.Image(label="Input Image", tool="editor", type="pil")
242
-
243
- with gr.Column():
244
- output_image = gr.Image(label="Predicted Image")
245
- raw_result = gr.Textbox(label="Raw Predictions", lines=5)
246
- submit_status = gr.Textbox(label="Submit Status")
247
-
248
- predict_button = gr.Button("Predict")
249
-
250
- predict_button.click(fn=interface, inputs=[name, patient_id, input_image], outputs=[output_image, raw_result, submit_status])
251
-
252
- with gr.Row():
253
- with gr.Column():
254
- view_button = gr.Button("View Database")
255
- download_choice = gr.Dropdown(label="Download Option", choices=["Database (.db)", "Predicted Image (.png)"])
256
- download_button = gr.Button("Download")
257
-
258
- view_button.click(fn=view_db_interface, inputs=[], outputs=[gr.Dataframe()])
259
- download_button.click(fn=download_interface, inputs=[download_choice], outputs=[gr.File(), gr.Textbox()])
260
-
261
- # Launch the Gradio app
262
- app.launch()
 
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
6
  import sqlite3
7
+ import base64
8
+ from io import BytesIO
9
  import tempfile
10
  import pandas as pd
11
 
12
  # Load YOLOv8 model
13
  model = YOLO("best.pt")
14
 
 
15
  def predict_image(input_image, name, patient_id):
16
  if input_image is None:
17
  return None, "Please Input The Image"
 
31
  # Draw bounding boxes on the image
32
  image_with_boxes = image_np.copy()
33
  raw_predictions = []
 
34
 
35
  if results[0].boxes:
36
+ # Sort the results by confidence and take the highest confidence one
37
+ highest_confidence_result = max(results[0].boxes, key=lambda x: x.conf.item())
38
+
39
+ # Determine the label based on the class index
40
+ class_index = highest_confidence_result.cls.item()
41
+ if class_index == 0:
42
+ label = "Immature"
43
+ color = (0, 255, 255) # Yellow for Immature
44
+ elif class_index == 1:
45
+ label = "Mature"
46
+ color = (255, 0, 0) # Red for Mature
47
+ else:
48
+ label = "Normal"
49
+ color = (0, 255, 0) # Green for Normal
50
+
51
+ confidence = highest_confidence_result.conf.item()
52
+ xmin, ymin, xmax, ymax = map(int, highest_confidence_result.xyxy[0])
53
+
54
+ # Draw the bounding box
55
+ cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
56
+
57
+ # Enlarge font scale and thickness
58
+ font_scale = 1.0
59
+ thickness = 2
60
+
61
+ # Calculate label background size
62
+ (text_width, text_height), baseline = cv2.getTextSize(f'{label} {confidence:.2f}', cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
63
+ cv2.rectangle(image_with_boxes, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (0, 0, 0), cv2.FILLED)
64
+
65
+ # Put the label text with black background
66
+ cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
67
+
68
+ raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]")
69
 
70
  raw_predictions_str = "\n".join(raw_predictions)
71
+
72
  # Convert to PIL image for further processing
73
  pil_image_with_boxes = Image.fromarray(image_with_boxes)
74
 
75
  # Add text and watermark
76
  pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, patient_id, label)
77
+
78
  return pil_image_with_boxes, raw_predictions_str
79
 
80
  # Function to add watermark
 
82
  try:
83
  logo = Image.open('image-logo.png').convert("RGBA")
84
  image = image.convert("RGBA")
85
+
86
  # Resize logo
87
  basewidth = 100
88
  wpercent = (basewidth / float(logo.size[0]))
89
  hsize = int((float(wpercent) * logo.size[1]))
90
  logo = logo.resize((basewidth, hsize), Image.LANCZOS)
91
+
92
  # Position logo
93
  position = (image.width - logo.width - 10, image.height - logo.height - 10)
94
+
95
  # Composite image
96
  transparent = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
97
  transparent.paste(image, (0, 0))
98
  transparent.paste(logo, position, mask=logo)
99
+
100
  return transparent.convert("RGB")
101
  except Exception as e:
102
  print(f"Error adding watermark: {e}")
 
105
  # Function to add text and watermark
106
  def add_text_and_watermark(image, name, patient_id, label):
107
  draw = ImageDraw.Draw(image)
108
+
109
  # Load a larger font (adjust the size as needed)
110
  font_size = 48 # Example font size
111
  try:
 
113
  except IOError:
114
  font = ImageFont.load_default()
115
  print("Error: cannot open resource, using default font.")
116
+
117
  text = f"Name: {name}, ID: {patient_id}, Result: {label}"
118
+
119
  # Calculate text bounding box
120
  text_bbox = draw.textbbox((0, 0), text, font=font)
121
  text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
 
128
  [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
129
  fill="black"
130
  )
131
+
132
  # Draw text on top of the rectangle
133
  draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
134
 
135
  # Add watermark to the image
136
  image_with_watermark = add_watermark(image)
137
+
138
  return image_with_watermark
139
 
140
  # Function to initialize the database
 
150
  def submit_result(name, patient_id, input_image, predicted_image, result):
151
  conn = sqlite3.connect('results.db')
152
  c = conn.cursor()
153
+
154
  input_image_np = np.array(input_image)
155
  _, input_buffer = cv2.imencode('.png', cv2.cvtColor(input_image_np, cv2.COLOR_RGB2BGR))
156
  input_image_bytes = input_buffer.tobytes()
157
+
158
  predicted_image_np = np.array(predicted_image)
159
  predicted_image_rgb = cv2.cvtColor(predicted_image_np, cv2.COLOR_RGB2BGR) # Ensure correct color conversion
160
  _, predicted_buffer = cv2.imencode('.png', predicted_image_rgb)
161
  predicted_image_bytes = predicted_buffer.tobytes()
162
+
163
  c.execute("INSERT INTO results (name, patient_id, input_image, predicted_image, result) VALUES (?, ?, ?, ?, ?)",
164
  (name, patient_id, input_image_bytes, predicted_image_bytes, result))
165
  conn.commit()
 
170
  def view_database():
171
  conn = sqlite3.connect('results.db')
172
  c = conn.cursor()
173
+ c.execute("SELECT name, patient_id FROM results")
174
  rows = c.fetchall()
175
  conn.close()
176
+
177
+ # Convert to pandas DataFrame for better display in Gradio
178
+ df = pd.DataFrame(rows, columns=["Name", "Patient ID"])
179
+
180
  return df
181
 
182
  # Function to download database or image
183
  def download_file(choice):
 
 
 
184
  if choice == "Database (.db)":
185
+ # Provide the path to the database file
186
  return 'results.db'
187
+ elif choice == "Database (.html)":
188
+ df = view_database()
189
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as temp_file:
190
+ df.to_html(temp_file.name)
191
+ return temp_file.name
192
  else:
193
+ conn = sqlite3.connect('results.db')
194
+ c = conn.cursor()
195
  c.execute("SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1")
196
  row = c.fetchone()
197
  conn.close()
 
202
  temp_file.flush() # Ensure all data is written before closing
203
  return temp_file.name
204
  else:
 
205
  raise FileNotFoundError("No images found in the database.")
206
 
207
  # Initialize the database
 
214
 
215
  output_image, raw_result = predict_image(input_image, name, patient_id)
216
  submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)
217
+ return output_image, submit_status
218
+
219
+ inputs = [
220
+ gr.Textbox(label="Name"),
221
+ gr.Textbox(label="Patient ID"),
222
+ gr.Image(type="pil", label="Input Image")
223
+ ]
224
+
225
+ outputs = [
226
+ gr.Image(label="Output Image"),
227
+ gr.Textbox(label="Status")
228
+ ]
229
+
230
+ # File download interface
231
+ download_inputs = gr.Radio(["Database (.db)", "Database (.html)", "Image (.png)"], label="Download Type")
232
+ download_output = gr.File(label="Download File")
233
+
234
+ app = gr.Interface(
235
+ fn=interface,
236
+ inputs=inputs,
237
+ outputs=outputs,
238
+ title="AI Cataract Detector",
239
+ description="Upload an image, enter the patient's name and ID, and receive a prediction."
240
+ )
241
+
242
+ download_app = gr.Interface(
243
+ fn=download_file,
244
+ inputs=download_inputs,
245
+ outputs=download_output,
246
+ title="Download Results"
247
+ )
248
+
249
+ # Combine both interfaces in one layout
250
+ gr.TabbedInterface([app, download_app], ["Prediction", "Download"]).launch()