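"""Gradio app: image background removal (segmentation) with ONNX Runtime.

Lets the user pick an .onnx model from a Hugging Face repository, runs it on an
uploaded image, and logs each run (model, resolution, execution provider,
timing) to a CSV file that is displayed in the UI."""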
import gradio as gr
import huggingface_hub
import onnxruntime as rt
import numpy as np
import cv2
import os
import csv
import datetime
import time

# --- Constants ---
LOG_FILE = "processing_log.csv"
LOG_HEADER = [
    "Timestamp", "Repository", "Model Filename", "Model Size (MB)", 
    "Image Resolution (WxH)", "Execution Provider", "Processing Time (s)"
]
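# A logged row looks like, for example (hypothetical values):
#   2024-05-01 12:00:00,skytnt/anime-seg,isnetis.onnx,168.70,1024x768,CPUExecutionProvider,1.2345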

# Global variables for model and providers
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
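# ONNX Runtime tries these providers in order and falls back to the next entry
# (ultimately CPU) if a provider is unavailable in the installed build.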
model_repo_default = "skytnt/anime-seg"

# --- Logging Functions ---

def initialize_log_file():
    """Creates the log file and writes the header if it doesn't exist."""
    if not os.path.exists(LOG_FILE):
        try:
            with open(LOG_FILE, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(LOG_HEADER)
            print(f"Log file initialized: {LOG_FILE}")
        except IOError as e:
            print(f"Error initializing log file {LOG_FILE}: {e}")

def log_processing_event(timestamp, repo, model_filename, model_size_mb, 
                         resolution, provider, processing_time):
    """Appends a processing event to the CSV log file."""
    try:
        with open(LOG_FILE, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([
                timestamp, repo, model_filename, f"{model_size_mb:.2f}", 
                resolution, provider, f"{processing_time:.4f}"
            ])
    except IOError as e:
        print(f"Error writing to log file {LOG_FILE}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during logging: {e}")

def read_log_file():
    """Reads the entire log file content."""
    try:
        if not os.path.exists(LOG_FILE):
            return "Log file not found."
        with open(LOG_FILE, 'r', encoding='utf-8') as f:
            # Return the raw CSV text for display
            return f.read()
            # Alternatively, for cleaner display of CSV in a textbox:
            # reader = csv.reader(f)
            # rows = list(reader)
            # # Format header and rows nicely
            # header = rows[0]
            # data_rows = rows[1:]
            # formatted_rows = [", ".join(header)] # Join header elements
            # for row in data_rows:
            #     formatted_rows.append(", ".join(row)) # Join data elements
            # return "\n".join(formatted_rows)
    except IOError as e:
        print(f"Error reading log file {LOG_FILE}: {e}")
        return f"Error reading log file: {e}"
    except Exception as e:
        print(f"An unexpected error occurred reading log file: {e}")
        return f"Error reading log file: {e}"

# --- Helper Functions ---

def get_model_details_from_choice(choice_string: str) -> tuple[str, float | None]:
    """
    Extracts filename and size (MB) from the dropdown choice string.
    Returns (filename, size_mb) or (filename, None) if size is not parseable.
    """
    if not choice_string:
        return "", None
    parts = choice_string.split(" (")
    filename = parts[0]
    size_mb = None
    if len(parts) > 1 and parts[1].endswith(" MB)"):
        try:
            size_str = parts[1].replace(" MB)", "")
            size_mb = float(size_str)
        except ValueError:
            pass # Size couldn't be parsed
    return filename, size_mb
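# Example: get_model_details_from_choice("isnetis.onnx (168.70 MB)") returns
# ("isnetis.onnx", 168.7); an unparseable size yields (filename, None).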
    
# --- Model Loading and UI Functions (Mostly unchanged, modifications marked) ---

def update_onnx_files(repo: str):
    """
    Lists .onnx files in the Hugging Face repository and updates the Dropdown with file sizes.
    """
    onnx_files_with_size = []
    try:
        api = huggingface_hub.HfApi()
        repo_info = api.model_info(repo_id=repo, files_metadata=True)
        
        for file_info in repo_info.siblings:
            if file_info.rfilename.endswith('.onnx'):
                try:
                    # Use file_info.size which is in bytes
                    size_mb = file_info.size / (1024 * 1024) if file_info.size else 0
                    onnx_files_with_size.append(f"{file_info.rfilename} ({size_mb:.2f} MB)")
                except Exception:
                    onnx_files_with_size.append(f"{file_info.rfilename} (Size N/A)")
        
        if onnx_files_with_size:
            onnx_files_with_size.sort()
            return gr.update(choices=onnx_files_with_size, value=onnx_files_with_size[0])
        else:
            # gr.update() has no warning/error kwargs; report via stdout instead.
            print(f"No .onnx files found in repository '{repo}'")
            return gr.update(choices=[], value="")

    except huggingface_hub.utils.RepositoryNotFoundError:
        print(f"Repository '{repo}' not found or access denied.")
        return gr.update(choices=[], value="")
    except Exception as e:
        print(f"Error fetching repo files for {repo}: {e}")
        return gr.update(choices=[], value="")

# Get default choices and filename
default_onnx_files_with_size = []
default_model_filename = ""
try:
    initial_update = update_onnx_files(model_repo_default)
    # gr.update(...) returns a plain dict in recent Gradio versions, so read the
    # "choices" key rather than isinstance-checking against gr.update (which is
    # a function, not a type).
    initial_choices = initial_update.get("choices") if isinstance(initial_update, dict) else None
    if initial_choices:
        default_onnx_files_with_size = list(initial_choices)
        default_model_filename, _ = get_model_details_from_choice(default_onnx_files_with_size[0]) # Use helper
    else:
        default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
        default_model_filename = "isnetis.onnx"
        print(f"Warning: Could not fetch initial ONNX files from {model_repo_default}. Using fallback '{default_model_filename}'.")
except Exception as e:
    default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
    default_model_filename = "isnetis.onnx"
    print(f"Error during initial model fetch: {e}. Using fallback '{default_model_filename}'.")

# Global variables for current model state
current_model_repo = model_repo_default
current_model_filename = default_model_filename

# Initial download and model load
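# Note: this runs at import time, before the UI launches; a failure here leaves
# rmbg_model as None, which get_mask() and the Status textbox both surface.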
model_path = None
rmbg_model = None
try:
    print(f"Attempting initial download: {current_model_repo}/{current_model_filename}")
    if current_model_filename: # Only download if we have a filename
        model_path = huggingface_hub.hf_hub_download(current_model_repo, current_model_filename)
        rmbg_model = rt.InferenceSession(model_path, providers=providers)
        print(f"Initial model loaded successfully: {model_path}")
        print(f"Available Execution Providers: {rt.get_available_providers()}")
        print(f"Using Provider(s): {rmbg_model.get_providers()}")
    else:
        print("FATAL: No default model filename determined. Cannot load initial model.")
except Exception as e:
    print(f"FATAL: Could not download or load initial model '{current_model_repo}/{current_model_filename}'. Error: {e}")

# --- Inference Functions (Unchanged get_mask, rmbg_fn) ---
def get_mask(img, s=1024):
    """Letterboxes the image into an s x s square, runs the segmentation model,
    and returns a float mask in [0, 1] at the original resolution."""
    if rmbg_model is None:
        raise gr.Error("Model is not loaded. Please check model selection and update status.")
    img_normalized = (img / 255.0).astype(np.float32)
    h0, w0 = img.shape[:2]
    # Scale the longer side to s, preserving aspect ratio.
    if h0 >= w0:
        h, w = s, int(s * w0 / h0)
    else:
        h, w = int(s * h0 / w0), s
    # Pad the shorter side symmetrically to reach s x s.
    ph, pw = s - h, s - w
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    resized_img = cv2.resize(img_normalized, (w, h), interpolation=cv2.INTER_AREA)
    img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized_img
    # HWC -> NCHW with a batch dimension for ONNX Runtime.
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]
    input_name = rmbg_model.get_inputs()[0].name
    mask_output = rmbg_model.run(None, {input_name: img_input})[0][0]
    # Crop the padding away and resize the mask back to the source size.
    mask_processed = np.transpose(mask_output, (1, 2, 0))
    mask_processed = mask_processed[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    mask_resized = cv2.resize(mask_processed, (w0, h0), interpolation=cv2.INTER_LINEAR)
    if mask_resized.ndim == 2:
        mask_resized = mask_resized[:, :, np.newaxis]
    mask_final = np.clip(mask_resized, 0, 1)
    return mask_final

def rmbg_fn(img):
    """Runs segmentation and returns (mask visualization, RGBA cutout)."""
    if img is None:
        raise gr.Error("Please provide an input image.")
    mask = get_mask(img)
    # Normalize to uint8 in case a float image slipped through.
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8) if img.max() <= 1.0 else img.clip(0, 255).astype(np.uint8)
    alpha_channel = (mask * 255).astype(np.uint8)
    if img.shape[2] == 3:
        img_out_rgba = np.concatenate([img, alpha_channel], axis=2)
    else:
        img_out_rgba = img.copy()
        img_out_rgba[:, :, 3] = alpha_channel[:, :, 0]
    # Repeat the single-channel mask into RGB for display.
    mask_img_display = (mask * 255).astype(np.uint8).repeat(3, axis=2)
    return mask_img_display, img_out_rgba
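# Standalone usage sketch (hypothetical paths; assumes the model loaded above):
#   img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
#   mask_vis, rgba = rmbg_fn(img)
#   cv2.imwrite("output.png", cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))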

# --- Model Update Function ---
def update_model(model_repo, model_filename_with_size):
    global rmbg_model, current_model_repo, current_model_filename
    model_filename, _ = get_model_details_from_choice(model_filename_with_size) # Use helper
    if not model_filename: return "Error: No model filename selected or extracted."
    if model_repo == current_model_repo and model_filename == current_model_filename:
        # Even if it's the same, report the provider being used
        current_provider = rmbg_model.get_providers()[0] if rmbg_model else "N/A"
        return f"Model already loaded: {current_model_repo}/{current_model_filename}\nUsing Provider: {current_provider}"
        
    try:
        print(f"Updating model to: {model_repo}/{model_filename}")
        model_path = huggingface_hub.hf_hub_download(model_repo, model_filename)
        new_rmbg_model = rt.InferenceSession(model_path, providers=providers)
        rmbg_model = new_rmbg_model
        current_model_repo = model_repo
        current_model_filename = model_filename
        active_provider = rmbg_model.get_providers()[0] # Get the provider actually used
        print(f"Model updated successfully: {model_path}")
        print(f"Using Provider: {active_provider}")
        return f"Model updated: {current_model_repo}/{current_model_filename}\nUsing Provider: {active_provider}"
    except huggingface_hub.utils.HfHubHTTPError as e:
        print(f"Error downloading model: {e}")
        status = e.response.status_code if e.response is not None else "unknown"
        return f"Error downloading model: {model_repo}/{model_filename}. (HTTP {status})"
    except Exception as e:
        # onnxruntime exposes no top-level ONNXRuntimeException; its load errors
        # come from the pybind layer, so catch broadly and inspect the message
        # to give more specific feedback for provider issues.
        print(f"Error updating model: {e}")
        if "CUDAExecutionProvider" in str(e):
            return (f"Error loading ONNX model '{model_filename}'. CUDA unavailable or a setup "
                    f"issue? Falling back may require a restart or a different onnxruntime build. "
                    f"Error: {e}")
        return f"Error loading model '{model_filename}'. Incompatible or corrupted file? Error: {e}"

# --- Main Processing Function (MODIFIED FOR LOGGING) ---
def process_and_update(img, model_repo, model_filename_with_size, history):
    global current_model_repo, current_model_filename, rmbg_model
    
    # --- Pre-checks ---
    if img is None:
        return None, [], history, "generated", "Please upload an image first.", read_log_file() # Return current log
    if rmbg_model is None:
        return None, [], history, "generated", "ERROR: Model not loaded. Update model first.", read_log_file() # Return current log

    selected_model_filename, selected_model_size_mb = get_model_details_from_choice(model_filename_with_size) # Use helper
    status_message = ""

    # --- Model Update Check ---
    if model_repo != current_model_repo or selected_model_filename != current_model_filename:
        status_message = update_model(model_repo, model_filename_with_size)
        if "Error" in status_message:
             return None, [], history, "generated", f"Model Update Failed:\n{status_message}", read_log_file() # Return current log
        if rmbg_model is None:
             return None, [], history, "generated", "ERROR: Model failed to load after update.", read_log_file() # Return current log

    # --- Processing & Logging ---
    try:
        start_time = time.time() # Start timer
        mask_img, generated_img_rgba = rmbg_fn(img) # Run inference
        end_time = time.time() # End timer
        
        processing_time = end_time - start_time # Calculate duration
        
        # --- Gather Log Information ---
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        h, w = img.shape[:2]
        resolution = f"{w}x{h}"
        # Get the *actually used* provider from the loaded session
        active_provider = rmbg_model.get_providers()[0] 
        
        # Log the event
        log_processing_event(
            timestamp=timestamp,
            repo=current_model_repo, # Use the confirmed current repo
            model_filename=current_model_filename, # Use the confirmed current filename
            model_size_mb=selected_model_size_mb if selected_model_size_mb is not None else 0.0, # Use extracted size
            resolution=resolution,
            provider=active_provider,
            processing_time=processing_time
        )
        
        # --- Prepare Outputs ---
        new_history = history + [generated_img_rgba]
        output_pair = [mask_img, generated_img_rgba]
        current_log_content = read_log_file() # Read updated log
        
        status_message = f"{status_message}\nProcessing complete ({processing_time:.2f}s)".strip()
        
        return generated_img_rgba, output_pair, new_history, "generated", status_message, current_log_content

    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()
        # Still return the log content even if processing fails
        return None, [], history, "generated", f"Error during processing: {str(e)}", read_log_file()


# --- UI Interaction Functions (toggle_view unchanged; clear_all slightly modified) ---
def toggle_view(view_state, output_pair):
    if not output_pair or len(output_pair) != 2:
        return None, view_state, "View Mask" if view_state == "generated" else "View Generated"
    if view_state == "generated":
        return output_pair[0], "mask", "View Generated"
    else:
        return output_pair[1], "generated", "View Mask"

def clear_all():
    """ Resets inputs, outputs, states, status, but keeps log view """
    # Keeps the log viewer content, as history shouldn't be wiped by clearing inputs
    initial_log_content = read_log_file() # Read log to display upon clearing
    return None, None, [], [], "generated", "Interface cleared.", "View Mask", [], initial_log_content

# --- Gradio UI Definition ---
if __name__ == "__main__":
    initialize_log_file() # Ensure log file exists before launching app

    app = gr.Blocks(css=".gradio-container { max-width: 95% !important; }") # Wider
    with app:
        gr.Markdown("# Image Background Removal (Segmentation) with Logging")
        gr.Markdown("Test ONNX models, view performance logs.")

        with gr.Row():
            # Left Column: Controls and Input
            with gr.Column(scale=2):
                 with gr.Group():
                    gr.Markdown("### Model Selection")
                    model_repo_input = gr.Textbox(value=model_repo_default, label="Hugging Face Repository")
                    model_filename_dropdown = gr.Dropdown(
                        choices=default_onnx_files_with_size,
                        value=default_onnx_files_with_size[0] if default_onnx_files_with_size else "",
                        label="ONNX Model File (.onnx)"
                    )
                    update_btn = gr.Button("πŸ”„ Update/Load Model")
                    model_status_textbox = gr.Textbox(label="Status", value="Initial model loaded." if rmbg_model else "ERROR: Initial model failed to load.", interactive=False, lines=2)

                 gr.Markdown("#### Source Image")
                 input_img = gr.Image(label="Upload Image", type="numpy")

                 with gr.Row():
                     run_btn = gr.Button("▢️ Run Background Removal", variant="primary")
                     clear_btn = gr.Button("πŸ—‘οΈ Clear Inputs/Outputs")

            # Right Column: Output and Logs
            with gr.Column(scale=3):
                 gr.Markdown("#### Output Image")
                 output_img = gr.Image(label="Output", image_mode="RGBA", format="png", type="numpy")
                 toggle_btn = gr.Button("View Mask")

                 gr.Markdown("---")
                 gr.Markdown("### Processing History")
                 history_gallery = gr.Gallery(label="Generated Image History", show_label=False, columns=8, object_fit="contain", height="auto")

                 gr.Markdown("---")
                 gr.Markdown("### Processing Log (`processing_log.csv`)")
                 # Use gr.Code for better viewing of CSV/text data
                 log_display = gr.Code(
                     value=read_log_file(), # Initial content
                     label="Log Viewer",
                     lines=10,
                     interactive=False
                 )
                 # Optional: Add a manual refresh button if auto-update isn't sufficient
                 # refresh_log_btn = gr.Button("πŸ”„ Refresh Log View")

        # Hidden states
        output_pair_state = gr.State([])
        view_state = gr.State("generated")
        history_state = gr.State([])
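        # output_pair_state holds [mask, generated] for the view toggle;
        # view_state tracks which of the two is currently displayed.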

        # --- Event Listeners ---
        model_repo_input.submit(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        model_repo_input.blur(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        update_btn.click(fn=update_model, inputs=[model_repo_input, model_filename_dropdown], outputs=model_status_textbox)

        # Run includes updating the log display
        run_btn.click(
            fn=process_and_update,
            inputs=[input_img, model_repo_input, model_filename_dropdown, history_state],
            outputs=[output_img, output_pair_state, history_state, view_state, model_status_textbox, log_display] # ADD log_display here
        )

        toggle_btn.click(fn=toggle_view, inputs=[view_state, output_pair_state], outputs=[output_img, view_state, toggle_btn])

        # Clear resets inputs/outputs/status, but re-reads log for display
        clear_btn.click(
            fn=clear_all,
            outputs=[input_img, output_img, output_pair_state, history_state, view_state, model_status_textbox, toggle_btn, history_gallery, log_display] # ADD log_display here
        )

        # Manual log refresh button (optional, as run/clear update it)
        # refresh_log_btn.click(fn=read_log_file, inputs=None, outputs=log_display)

        history_state.change(fn=lambda history: history, inputs=history_state, outputs=history_gallery)

    app.launch(debug=True)