ariankhalfani commited on
Commit
dc36253
1 Parent(s): 5a80261

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -50
app.py CHANGED
@@ -4,15 +4,14 @@ import cv2
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
6
  import sqlite3
 
 
 
7
  import pandas as pd
8
 
9
- # Load YOLOv10n model
10
  model = YOLO("best.pt")
11
 
12
- # Define label mappings
13
- label_mapping = {0: 'immature', 1: 'mature', 2: 'normal'}
14
- inverse_label_mapping = {'immature': 0, 'mature': 1, 'normal': 2}
15
-
16
  # Function to perform prediction
17
  def predict_image(input_image, name, patient_id):
18
  if input_image is None:
@@ -27,73 +26,240 @@ def predict_image(input_image, name, patient_id):
27
  elif image_np.shape[2] == 4: # RGBA to RGB
28
  image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
29
 
30
- # Perform inference with YOLOv10n model
31
  results = model(image_np)
32
 
33
  # Draw bounding boxes on the image
34
  image_with_boxes = image_np.copy()
35
  raw_predictions = []
 
36
 
37
  if results[0].boxes:
38
- # Iterate through each detected object
39
- for i in range(len(results[0].boxes)):
40
- box = results[0].boxes[i]
41
- predicted_class = int(box.cls.item())
42
  confidence = box.conf.item()
43
 
44
- # Apply confidence threshold
45
- if confidence >= 0.5:
46
- # Map the predicted class to the label
47
- label = label_mapping[predicted_class]
48
-
49
- # Get the bounding box coordinates
50
- xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
51
-
52
- # Assign color for the label
53
- color = (0, 255, 0) if label == 'normal' else (0, 255, 255) if label == 'immature' else (255, 0, 0)
54
-
55
- # Draw the bounding box
56
- cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
57
-
58
- # Draw the label with confidence
59
- cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
60
-
61
- raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]")
62
-
63
- # Convert to PIL image for Gradio output
 
 
 
 
 
 
 
 
 
 
 
 
64
  pil_image_with_boxes = Image.fromarray(image_with_boxes)
65
 
66
- return pil_image_with_boxes, "\n".join(raw_predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # Gradio Interface
69
  def interface(name, patient_id, input_image):
70
  if input_image is None:
71
  return "Please upload an image."
72
 
73
- # Run prediction
74
  output_image, raw_result = predict_image(input_image, name, patient_id)
75
-
76
- return output_image, raw_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Gradio Blocks
79
  with gr.Blocks() as demo:
80
- with gr.Column():
81
- gr.Markdown("# Cataract Detection System")
82
- gr.Markdown("Upload an image to detect cataract and add patient details.")
83
-
84
- with gr.Column():
85
- name = gr.Textbox(label="Name")
86
- patient_id = gr.Textbox(label="Patient ID")
87
- input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- with gr.Column():
90
- submit_btn = gr.Button("Submit")
91
- output_image = gr.Image(type="pil", label="Predicted Image")
92
 
93
- with gr.Row():
94
- raw_result = gr.Textbox(label="Raw Result", lines=5)
95
 
96
- submit_btn.click(fn=interface, inputs=[name, patient_id, input_image], outputs=[output_image, raw_result])
97
 
98
- # Launch the Gradio app
99
  demo.launch()
 
4
  import numpy as np
5
  from PIL import Image, ImageDraw, ImageFont
6
  import sqlite3
7
+ import base64
8
+ from io import BytesIO
9
+ import tempfile
10
  import pandas as pd
11
 
12
+ # Load YOLOv8 model
13
  model = YOLO("best.pt")
14
 
 
 
 
 
15
  # Function to perform prediction
16
  def predict_image(input_image, name, patient_id):
17
  if input_image is None:
 
26
  elif image_np.shape[2] == 4: # RGBA to RGB
27
  image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
28
 
29
+ # Perform prediction
30
  results = model(image_np)
31
 
32
  # Draw bounding boxes on the image
33
  image_with_boxes = image_np.copy()
34
  raw_predictions = []
35
+ label = "Unknown" # Default label if no detection
36
 
37
  if results[0].boxes:
38
+ for box in results[0].boxes:
39
+ # Get class index and confidence for each detection
40
+ class_index = box.cls.item()
 
41
  confidence = box.conf.item()
42
 
43
+ # Determine the label based on the class index
44
+ if class_index == 0:
45
+ label = "Mature"
46
+ color = (255, 0, 0) # Red for Mature
47
+ elif class_index == 1:
48
+ label = "Immature"
49
+ color = (0, 255, 255) # Yellow for Immature
50
+ else:
51
+ label = "Normal"
52
+ color = (0, 255, 0) # Green for Normal
53
+
54
+ xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
55
+
56
+ # Draw the bounding box
57
+ cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
58
+
59
+ # Enlarge font scale and thickness
60
+ font_scale = 1.0
61
+ thickness = 2
62
+
63
+ # Calculate label background size
64
+ (text_width, text_height), baseline = cv2.getTextSize(f'{label} {confidence:.2f}', cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
65
+ cv2.rectangle(image_with_boxes, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (0, 0, 0), cv2.FILLED)
66
+
67
+ # Put the label text with black background
68
+ cv2.putText(image_with_boxes, f'{label} {confidence:.2f}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
69
+
70
+ raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]")
71
+
72
+ raw_predictions_str = "\n".join(raw_predictions)
73
+
74
+ # Convert to PIL image for further processing
75
  pil_image_with_boxes = Image.fromarray(image_with_boxes)
76
 
77
+ # Add text and watermark
78
+ pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, patient_id, label)
79
+
80
+ return pil_image_with_boxes, raw_predictions_str
81
+
82
+ # Function to add watermark
83
+ def add_watermark(image):
84
+ try:
85
+ logo = Image.open('image-logo.png').convert("RGBA")
86
+ image = image.convert("RGBA")
87
+
88
+ # Resize logo
89
+ basewidth = 100
90
+ wpercent = (basewidth / float(logo.size[0]))
91
+ hsize = int((float(wpercent) * logo.size[1]))
92
+ logo = logo.resize((basewidth, hsize), Image.LANCZOS)
93
+
94
+ # Position logo
95
+ position = (image.width - logo.width - 10, image.height - logo.height - 10)
96
+
97
+ # Composite image
98
+ transparent = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
99
+ transparent.paste(image, (0, 0))
100
+ transparent.paste(logo, position, mask=logo)
101
+
102
+ return transparent.convert("RGB")
103
+ except Exception as e:
104
+ print(f"Error adding watermark: {e}")
105
+ return image
106
+
107
+ # Function to add text and watermark
108
+ def add_text_and_watermark(image, name, patient_id, label):
109
+ draw = ImageDraw.Draw(image)
110
+
111
+ # Load a larger font (adjust the size as needed)
112
+ font_size = 48 # Example font size
113
+ try:
114
+ font = ImageFont.truetype("font.ttf", size=font_size)
115
+ except IOError:
116
+ font = ImageFont.load_default()
117
+ print("Error: cannot open resource, using default font.")
118
+
119
+ text = f"Name: {name}, ID: {patient_id}, Result: {label}"
120
+
121
+ # Calculate text bounding box
122
+ text_bbox = draw.textbbox((0, 0), text, font=font)
123
+ text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
124
+ text_x = 20
125
+ text_y = 40
126
+ padding = 10
127
+
128
+ # Draw a filled rectangle for the background
129
+ draw.rectangle(
130
+ [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
131
+ fill="black"
132
+ )
133
+
134
+ # Draw text on top of the rectangle
135
+ draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
136
+
137
+ # Add watermark to the image
138
+ image_with_watermark = add_watermark(image)
139
+
140
+ return image_with_watermark
141
+
142
+ # Function to initialize the database
143
+ def init_db():
144
+ conn = sqlite3.connect('results.db')
145
+ c = conn.cursor()
146
+ c.execute('''CREATE TABLE IF NOT EXISTS results
147
+ (id INTEGER PRIMARY KEY, name TEXT, patient_id TEXT, input_image BLOB, predicted_image BLOB, result TEXT)''')
148
+ conn.commit()
149
+ conn.close()
150
+
151
+ # Function to submit result to the database
152
+ def submit_result(name, patient_id, input_image, predicted_image, result):
153
+ conn = sqlite3.connect('results.db')
154
+ c = conn.cursor()
155
+
156
+ input_image_np = np.array(input_image)
157
+ _, input_buffer = cv2.imencode('.png', cv2.cvtColor(input_image_np, cv2.COLOR_RGB2BGR))
158
+ input_image_bytes = input_buffer.tobytes()
159
+
160
+ predicted_image_np = np.array(predicted_image)
161
+ predicted_image_rgb = cv2.cvtColor(predicted_image_np, cv2.COLOR_RGB2BGR) # Ensure correct color conversion
162
+ _, predicted_buffer = cv2.imencode('.png', predicted_image_rgb)
163
+ predicted_image_bytes = predicted_buffer.tobytes()
164
+
165
+ c.execute("INSERT INTO results (name, patient_id, input_image, predicted_image, result) VALUES (?, ?, ?, ?, ?)",
166
+ (name, patient_id, input_image_bytes, predicted_image_bytes, result))
167
+ conn.commit()
168
+ conn.close()
169
+ return "Result submitted to database."
170
+
171
+ # Function to load and view database
172
+ def view_database():
173
+ conn = sqlite3.connect('results.db')
174
+ c = conn.cursor()
175
+ c.execute("SELECT * FROM results")
176
+ rows = c.fetchall()
177
+ conn.close()
178
+
179
+ # Convert to pandas DataFrame
180
+ df = pd.DataFrame(rows, columns=["ID", "Name", "Patient ID", "Input Image", "Predicted Image", "Result"])
181
+
182
+ return df
183
+
184
+ # Function to download database or image
185
+ def download_file(choice):
186
+ conn = sqlite3.connect('results.db')
187
+ c = conn.cursor()
188
+
189
+ if choice == "Database (.db)":
190
+ conn.close()
191
+ return 'results.db'
192
+ else:
193
+ c.execute("SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1")
194
+ row = c.fetchone()
195
+ conn.close()
196
+ if row:
197
+ image_bytes = row[0]
198
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
199
+ temp_file.write(image_bytes)
200
+ temp_file.flush() # Ensure all data is written before closing
201
+ return temp_file.name
202
+ else:
203
+ conn.close()
204
+ raise FileNotFoundError("No images found in the database.")
205
+
206
+ # Initialize the database
207
+ init_db()
208
 
209
  # Gradio Interface
210
  def interface(name, patient_id, input_image):
211
  if input_image is None:
212
  return "Please upload an image."
213
 
 
214
  output_image, raw_result = predict_image(input_image, name, patient_id)
215
+ submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)
216
+
217
+ return output_image, raw_result, submit_status
218
+
219
+ # View Database Function
220
+ def view_db_interface():
221
+ df = view_database()
222
+ return df
223
+
224
+ # Download Function
225
+ def download_interface(choice):
226
+ try:
227
+ file_path = download_file(choice)
228
+ with open(file_path, "rb") as file:
229
+ return file.read(), f"{choice}"
230
+ except Exception as e:
231
+ return str(e)
232
 
 
233
  with gr.Blocks() as demo:
234
+ with gr.Tabs():
235
+ with gr.Tab("Image Analyzer and Screener"):
236
+ gr.Markdown("## Cataract Detection System")
237
+ with gr.Row():
238
+ with gr.Column():
239
+ input_image = gr.Image(label="Upload Image")
240
+ name = gr.Textbox(label="Patient Name")
241
+ patient_id = gr.Textbox(label="Patient ID")
242
+ submit_btn = gr.Button("Submit")
243
+
244
+ with gr.Column():
245
+ output_image = gr.Image(label="Predicted Image")
246
+ raw_result = gr.Textbox(label="Raw Result")
247
+ submit_status = gr.Textbox(label="Submission Status")
248
+
249
+ submit_btn.click(fn=interface, inputs=[name, patient_id, input_image], outputs=[output_image, raw_result, submit_status])
250
+
251
+ with gr.Tab("Database Viewer"):
252
+ view_db_btn = gr.Button("View Database")
253
+ database_display = gr.Dataframe()
254
+
255
+ view_db_btn.click(fn=view_db_interface, outputs=database_display)
256
 
257
+ with gr.Tab("Download Results"):
258
+ download_choice = gr.Radio(["Database (.db)", "Predicted Image (.png)"], label="Download Option")
259
+ download_btn = gr.Button("Download")
260
 
261
+ download_output = gr.File()
 
262
 
263
+ download_btn.click(fn=download_interface, inputs=download_choice, outputs=download_output)
264
 
 
265
  demo.launch()