juba7 committed
Commit 31220fd · 1 Parent(s): b487f2a

Add application file

Files changed (1)
  1. app.py +391 -0
app.py ADDED
@@ -0,0 +1,391 @@
import gradio as gr
import huggingface_hub
import onnxruntime as rt
import numpy as np
import cv2
import os
import csv
import datetime
import time

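# Assumed dependencies (the commit does not pin any; these are the usual PyPI
# package names for the imports above):
#   pip install gradio huggingface_hub onnxruntime numpy opencv-python
# For the CUDAExecutionProvider, onnxruntime-gpu is needed instead of onnxruntime.
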
# --- Constants ---
LOG_FILE = "processing_log.csv"
LOG_HEADER = [
    "Timestamp", "Repository", "Model Filename", "Model Size (MB)",
    "Image Resolution (WxH)", "Execution Provider", "Processing Time (s)"
]
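
# Illustrative log row matching LOG_HEADER above (all values hypothetical):
#   2025-01-01 12:00:00, skytnt/anime-seg, isnetis.onnx, 168.62,
#   1024x768, CUDAExecutionProvider, 0.8421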

# Global variables for model and providers
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_repo_default = "skytnt/anime-seg"

# --- Logging Functions ---

def initialize_log_file():
    """Creates the log file and writes the header if it doesn't exist."""
    if not os.path.exists(LOG_FILE):
        try:
            with open(LOG_FILE, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(LOG_HEADER)
            print(f"Log file initialized: {LOG_FILE}")
        except IOError as e:
            print(f"Error initializing log file {LOG_FILE}: {e}")

def log_processing_event(timestamp, repo, model_filename, model_size_mb,
                         resolution, provider, processing_time):
    """Appends a processing event to the CSV log file."""
    try:
        with open(LOG_FILE, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([
                timestamp, repo, model_filename, f"{model_size_mb:.2f}",
                resolution, provider, f"{processing_time:.4f}"
            ])
    except IOError as e:
        print(f"Error writing to log file {LOG_FILE}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during logging: {e}")

def read_log_file():
    """Reads the entire log file content."""
    try:
        if not os.path.exists(LOG_FILE):
            return "Log file not found."
        with open(LOG_FILE, 'r', encoding='utf-8') as f:
            return f.read()
        # Alternatively, for cleaner display of the CSV in a textbox:
        # reader = csv.reader(f)
        # rows = list(reader)
        # header, data_rows = rows[0], rows[1:]
        # formatted_rows = [", ".join(header)]
        # for row in data_rows:
        #     formatted_rows.append(", ".join(row))
        # return "\n".join(formatted_rows)
    except IOError as e:
        print(f"Error reading log file {LOG_FILE}: {e}")
        return f"Error reading log file: {e}"
    except Exception as e:
        print(f"An unexpected error occurred reading log file: {e}")
        return f"Error reading log file: {e}"

# --- Helper Functions ---

def get_model_details_from_choice(choice_string: str) -> tuple[str, float | None]:
    """
    Extracts the filename and size (MB) from the dropdown choice string.
    Returns (filename, size_mb), or (filename, None) if the size is not parseable.
    """
    if not choice_string:
        return "", None
    parts = choice_string.split(" (")
    filename = parts[0]
    size_mb = None
    if len(parts) > 1 and parts[1].endswith(" MB)"):
        try:
            size_str = parts[1].replace(" MB)", "")
            size_mb = float(size_str)
        except ValueError:
            pass  # Size couldn't be parsed
    return filename, size_mb
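
# Example usage of the helper (hypothetical choice strings, in the format
# produced by update_onnx_files below):
#   get_model_details_from_choice("isnetis.onnx (168.62 MB)")  -> ("isnetis.onnx", 168.62)
#   get_model_details_from_choice("isnetis.onnx (Size N/A)")   -> ("isnetis.onnx", None)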

# --- Model Loading and UI Functions ---

def update_onnx_files(repo: str):
    """
    Lists .onnx files in the Hugging Face repository and updates the
    Dropdown with file sizes.
    """
    onnx_files_with_size = []
    try:
        api = huggingface_hub.HfApi()
        repo_info = api.model_info(repo_id=repo, files_metadata=True)

        for file_info in repo_info.siblings:
            if file_info.rfilename.endswith('.onnx'):
                try:
                    # file_info.size is in bytes
                    size_mb = file_info.size / (1024 * 1024) if file_info.size else 0
                    onnx_files_with_size.append(f"{file_info.rfilename} ({size_mb:.2f} MB)")
                except Exception:
                    onnx_files_with_size.append(f"{file_info.rfilename} (Size N/A)")

        if onnx_files_with_size:
            onnx_files_with_size.sort()
            return gr.update(choices=onnx_files_with_size, value=onnx_files_with_size[0])
        else:
            # gr.update() takes no warning/error kwargs, so report via stdout
            print(f"No .onnx files found in repository '{repo}'")
            return gr.update(choices=[], value="")

    except huggingface_hub.utils.RepositoryNotFoundError:
        print(f"Repository '{repo}' not found or access denied.")
        return gr.update(choices=[], value="")
    except Exception as e:
        print(f"Error fetching repo files for {repo}: {e}")
        return gr.update(choices=[], value="")

# Get default choices and filename
default_onnx_files_with_size = []
default_model_filename = ""
try:
    initial_update = update_onnx_files(model_repo_default)
    # gr.update(...) returns a plain dict in recent Gradio versions
    if isinstance(initial_update, dict) and initial_update.get("choices"):
        default_onnx_files_with_size = initial_update["choices"]
        default_model_filename, _ = get_model_details_from_choice(default_onnx_files_with_size[0])
    else:
        default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
        default_model_filename = "isnetis.onnx"
        print(f"Warning: Could not fetch initial ONNX files from {model_repo_default}. Using fallback '{default_model_filename}'.")
except Exception as e:
    default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
    default_model_filename = "isnetis.onnx"
    print(f"Error during initial model fetch: {e}. Using fallback '{default_model_filename}'.")

# Global variables for current model state
current_model_repo = model_repo_default
current_model_filename = default_model_filename

# Initial download and model load
model_path = None
rmbg_model = None
try:
    print(f"Attempting initial download: {current_model_repo}/{current_model_filename}")
    if current_model_filename:  # Only download if we have a filename
        model_path = huggingface_hub.hf_hub_download(current_model_repo, current_model_filename)
        rmbg_model = rt.InferenceSession(model_path, providers=providers)
        print(f"Initial model loaded successfully: {model_path}")
        print(f"Available Execution Providers: {rt.get_available_providers()}")
        print(f"Using Provider(s): {rmbg_model.get_providers()}")
    else:
        print("FATAL: No default model filename determined. Cannot load initial model.")
except Exception as e:
    print(f"FATAL: Could not download or load initial model '{current_model_repo}/{current_model_filename}'. Error: {e}")

# --- Inference Functions ---
def get_mask(img, s=1024):
    if rmbg_model is None:
        raise gr.Error("Model is not loaded. Please check model selection and update status.")
    img_normalized = (img / 255.0).astype(np.float32)
    h0, w0 = img.shape[:2]
    # Scale the longer side to s, keeping the aspect ratio
    if h0 >= w0:
        h, w = s, int(s * w0 / h0)
    else:
        h, w = int(s * h0 / w0), s
    # Letterbox: paste the resized image centered on an s x s canvas
    ph, pw = s - h, s - w
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    resized_img = cv2.resize(img_normalized, (w, h), interpolation=cv2.INTER_AREA)
    img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized_img
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]  # HWC -> NCHW
    input_name = rmbg_model.get_inputs()[0].name
    mask_output = rmbg_model.run(None, {input_name: img_input})[0][0]
    mask_processed = np.transpose(mask_output, (1, 2, 0))  # CHW -> HWC
    # Crop the padding away, then resize back to the original resolution
    mask_processed = mask_processed[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    mask_resized = cv2.resize(mask_processed, (w0, h0), interpolation=cv2.INTER_LINEAR)
    if mask_resized.ndim == 2:
        mask_resized = mask_resized[:, :, np.newaxis]
    mask_final = np.clip(mask_resized, 0, 1)
    return mask_final
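
# Worked example of the letterbox math above (illustrative): for a 1536x2048
# input (w0=1536, h0=2048), h0 >= w0 gives h=1024 and w=768, so ph=0, pw=256,
# and the resized image is pasted centered at columns 128..896 of the
# 1024x1024 canvas.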

def rmbg_fn(img):
    if img is None:
        raise gr.Error("Please provide an input image.")
    mask = get_mask(img)
    # Normalize the input to uint8 if needed
    if img.dtype != np.uint8:
        if img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    alpha_channel = (mask * 255).astype(np.uint8)
    if img.shape[2] == 3:
        img_out_rgba = np.concatenate([img, alpha_channel], axis=2)
    else:  # Input already has an alpha channel; overwrite it
        img_out_rgba = img.copy()
        img_out_rgba[:, :, 3] = alpha_channel[:, :, 0]
    mask_img_display = (mask * 255).astype(np.uint8).repeat(3, axis=2)
    return mask_img_display, img_out_rgba
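
# Minimal standalone usage sketch (assumes the model session above loaded
# successfully; "input.png" is a hypothetical path):
#   img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
#   mask_vis, rgba = rmbg_fn(img)
#   cv2.imwrite("output.png", cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))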

# --- Model Update Function ---
def update_model(model_repo, model_filename_with_size):
    global rmbg_model, current_model_repo, current_model_filename
    model_filename, _ = get_model_details_from_choice(model_filename_with_size)
    if not model_filename:
        return "Error: No model filename selected or extracted."
    if model_repo == current_model_repo and model_filename == current_model_filename:
        # Even if it's the same model, report the provider being used
        current_provider = rmbg_model.get_providers()[0] if rmbg_model else "N/A"
        return f"Model already loaded: {current_model_repo}/{current_model_filename}\nUsing Provider: {current_provider}"

    try:
        print(f"Updating model to: {model_repo}/{model_filename}")
        model_path = huggingface_hub.hf_hub_download(model_repo, model_filename)
        new_rmbg_model = rt.InferenceSession(model_path, providers=providers)
        rmbg_model = new_rmbg_model
        current_model_repo = model_repo
        current_model_filename = model_filename
        active_provider = rmbg_model.get_providers()[0]  # The provider actually in use
        print(f"Model updated successfully: {model_path}")
        print(f"Using Provider: {active_provider}")
        return f"Model updated: {current_model_repo}/{current_model_filename}\nUsing Provider: {active_provider}"
    except huggingface_hub.utils.HfHubHTTPError as e:
        print(f"Error downloading model: {e}")
        return f"Error downloading model: {model_repo}/{model_filename}. ({e.response.status_code})"
    except Exception as e:
        # onnxruntime does not expose a stable top-level exception class, so
        # load failures land here; inspect the message for provider issues.
        print(f"Error updating model: {e}")
        if "CUDAExecutionProvider" in str(e):
            return f"Error loading ONNX model '{model_filename}'. CUDA unavailable or setup issue? Falling back may require a restart or a different onnxruntime build. Error: {e}"
        return f"Error updating model: {str(e)}"

# --- Main Processing Function (includes logging) ---
def process_and_update(img, model_repo, model_filename_with_size, history):
    global current_model_repo, current_model_filename, rmbg_model

    # --- Pre-checks ---
    if img is None:
        return None, [], history, "generated", "Please upload an image first.", read_log_file()
    if rmbg_model is None:
        return None, [], history, "generated", "ERROR: Model not loaded. Update model first.", read_log_file()

    selected_model_filename, selected_model_size_mb = get_model_details_from_choice(model_filename_with_size)
    status_message = ""

    # --- Model Update Check ---
    if model_repo != current_model_repo or selected_model_filename != current_model_filename:
        status_message = update_model(model_repo, model_filename_with_size)
        if "Error" in status_message:
            return None, [], history, "generated", f"Model Update Failed:\n{status_message}", read_log_file()
        if rmbg_model is None:
            return None, [], history, "generated", "ERROR: Model failed to load after update.", read_log_file()

    # --- Processing & Logging ---
    try:
        start_time = time.time()
        mask_img, generated_img_rgba = rmbg_fn(img)  # Run inference
        end_time = time.time()

        processing_time = end_time - start_time

        # --- Gather Log Information ---
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        h, w = img.shape[:2]
        resolution = f"{w}x{h}"
        # Get the *actually used* provider from the loaded session
        active_provider = rmbg_model.get_providers()[0]

        # Log the event
        log_processing_event(
            timestamp=timestamp,
            repo=current_model_repo,  # The confirmed current repo
            model_filename=current_model_filename,  # The confirmed current filename
            model_size_mb=selected_model_size_mb if selected_model_size_mb is not None else 0.0,
            resolution=resolution,
            provider=active_provider,
            processing_time=processing_time
        )

        # --- Prepare Outputs ---
        new_history = history + [generated_img_rgba]
        output_pair = [mask_img, generated_img_rgba]
        current_log_content = read_log_file()  # Read the updated log

        status_message = f"{status_message}\nProcessing complete ({processing_time:.2f}s)".strip()

        return generated_img_rgba, output_pair, new_history, "generated", status_message, current_log_content

    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()
        # Still return the log content even if processing fails
        return None, [], history, "generated", f"Error during processing: {str(e)}", read_log_file()


# --- UI Interaction Functions ---
def toggle_view(view_state, output_pair):
    if not output_pair or len(output_pair) != 2:
        return None, view_state, "View Mask" if view_state == "generated" else "View Generated"
    if view_state == "generated":
        return output_pair[0], "mask", "View Generated"
    else:
        return output_pair[1], "generated", "View Mask"

def clear_all():
    """Resets inputs, outputs, states, and status, but keeps the log view."""
    # The log viewer keeps its content, as history shouldn't be wiped by clearing inputs
    initial_log_content = read_log_file()
    return None, None, [], [], "generated", "Interface cleared.", "View Mask", [], initial_log_content

# --- Gradio UI Definition ---
if __name__ == "__main__":
    initialize_log_file()  # Ensure the log file exists before launching the app

    app = gr.Blocks(css=".gradio-container { max-width: 95% !important; }")  # Wider layout
    with app:
        gr.Markdown("# Image Background Removal (Segmentation) with Logging")
        gr.Markdown("Test ONNX models and view performance logs.")

        with gr.Row():
            # Left Column: Controls and Input
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("### Model Selection")
                    model_repo_input = gr.Textbox(value=model_repo_default, label="Hugging Face Repository")
                    model_filename_dropdown = gr.Dropdown(
                        choices=default_onnx_files_with_size,
                        value=default_onnx_files_with_size[0] if default_onnx_files_with_size else "",
                        label="ONNX Model File (.onnx)"
                    )
                    update_btn = gr.Button("🔄 Update/Load Model")
                    model_status_textbox = gr.Textbox(
                        label="Status",
                        value="Initial model loaded." if rmbg_model else "ERROR: Initial model failed to load.",
                        interactive=False, lines=2
                    )

                gr.Markdown("#### Source Image")
                input_img = gr.Image(label="Upload Image", type="numpy")

                with gr.Row():
                    run_btn = gr.Button("▶️ Run Background Removal", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear Inputs/Outputs")

            # Right Column: Output and Logs
            with gr.Column(scale=3):
                gr.Markdown("#### Output Image")
                output_img = gr.Image(label="Output", image_mode="RGBA", format="png", type="numpy")
                toggle_btn = gr.Button("View Mask")

                gr.Markdown("---")
                gr.Markdown("### Processing History")
                history_gallery = gr.Gallery(label="Generated Image History", show_label=False, columns=8, object_fit="contain", height="auto")

                gr.Markdown("---")
                gr.Markdown("### Processing Log (`processing_log.csv`)")
                # Use gr.Code for better viewing of CSV/text data
                log_display = gr.Code(
                    value=read_log_file(),  # Initial content
                    label="Log Viewer",
                    lines=10,
                    interactive=False
                )
                # Optional: a manual refresh button if auto-update isn't sufficient
                # refresh_log_btn = gr.Button("🔄 Refresh Log View")

        # Hidden states
        output_pair_state = gr.State([])
        view_state = gr.State("generated")
        history_state = gr.State([])

        # --- Event Listeners ---
        model_repo_input.submit(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        model_repo_input.blur(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        update_btn.click(fn=update_model, inputs=[model_repo_input, model_filename_dropdown], outputs=model_status_textbox)

        # Running inference also refreshes the log display
        run_btn.click(
            fn=process_and_update,
            inputs=[input_img, model_repo_input, model_filename_dropdown, history_state],
            outputs=[output_img, output_pair_state, history_state, view_state, model_status_textbox, log_display]
        )

        toggle_btn.click(fn=toggle_view, inputs=[view_state, output_pair_state], outputs=[output_img, view_state, toggle_btn])

        # Clear resets inputs/outputs/status, but re-reads the log for display
        clear_btn.click(
            fn=clear_all,
            outputs=[input_img, output_img, output_pair_state, history_state, view_state, model_status_textbox, toggle_btn, history_gallery, log_display]
        )

        # Manual log refresh (optional, as run/clear already update it)
        # refresh_log_btn.click(fn=read_log_file, inputs=None, outputs=log_display)

        history_state.change(fn=lambda history: history, inputs=history_state, outputs=history_gallery)

    app.launch(debug=True)
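
# To try the app locally (assuming the dependencies listed near the top are
# installed):
#   python app.py
# Gradio serves the UI on http://localhost:7860 by default.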