""" Image processing functions for the Image Tagger application. """ import os import traceback import glob def process_image(image_path, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1): """ Process a single image and return the tags. Args: image_path: Path to the image model: The image tagger model thresholds: Thresholds dictionary metadata: Metadata dictionary threshold_profile: Selected threshold profile active_threshold: Overall threshold value active_category_thresholds: Category-specific thresholds min_confidence: Minimum confidence to include in results Returns: Dictionary with tags, all probabilities, and other info """ try: # Run inference directly using the model's predict method if threshold_profile in ["Category-specific", "High Precision", "High Recall"]: results = model.predict( image_path=image_path, category_thresholds=active_category_thresholds ) else: results = model.predict( image_path=image_path, threshold=active_threshold ) # Extract and organize all probabilities all_probs = {} probs = results['refined_probabilities'][0] # Remove batch dimension for idx in range(len(probs)): prob_value = probs[idx].item() if prob_value >= min_confidence: tag, category = model.dataset.get_tag_info(idx) if category not in all_probs: all_probs[category] = [] all_probs[category].append((tag, prob_value)) # Sort tags by probability within each category for category in all_probs: all_probs[category] = sorted( all_probs[category], key=lambda x: x[1], reverse=True ) # Get the filtered tags based on the selected threshold tags = {} for category, cat_tags in all_probs.items(): threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold] # Create a flat list of all tags above threshold all_tags = [] for category, cat_tags in tags.items(): for tag, _ in cat_tags: all_tags.append(tag) return { 'tags': tags, 'all_probs': all_probs, 'all_tags': all_tags, 'success': True } except Exception as e: print(f"Error processing {image_path}: {str(e)}") traceback.print_exc() return { 'tags': {}, 'all_probs': {}, 'all_tags': [], 'success': False, 'error': str(e) } def apply_category_limits(result, category_limits): """ Apply category limits to a result dictionary. 


def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied
    """
    if not category_limits or not result['success']:
        return result

    # Get the filtered tags
    filtered_tags = result['tags']

    # Apply limits to each category
    for category, cat_tags in list(filtered_tags.items()):
        # Get limit for this category, default to -1 (no limit)
        limit = category_limits.get(category, -1)

        if limit == 0:
            # Exclude this category entirely
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            # Limit to top N tags for this category
            filtered_tags[category] = cat_tags[:limit]

    # Regenerate all_tags list after applying limits
    all_tags = []
    for category, cat_tags in filtered_tags.items():
        for tag, _ in cat_tags:
            all_tags.append(tag)

    # Update the result with limited tags
    result['tags'] = filtered_tags
    result['all_tags'] = all_tags

    return result
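

# Minimal, self-contained sketch of apply_category_limits semantics on a hand-built
# result dictionary (the tag names, categories, and probabilities below are made up):
# 0 excludes a category, -1 (or an absent key) keeps everything, N > 0 keeps the top N tags.
def _example_apply_category_limits():
    result = {
        'success': True,
        'tags': {
            'general': [('tree', 0.9), ('sky', 0.8), ('grass', 0.7)],
            'meta': [('highres', 0.95)],
        },
        'all_probs': {},
        'all_tags': ['tree', 'sky', 'grass', 'highres'],
    }
    limited = apply_category_limits(result, {'general': 2, 'meta': 0})
    # limited['tags'] -> {'general': [('tree', 0.9), ('sky', 0.8)]}
    # limited['all_tags'] -> ['tree', 'sky']
    return limited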


def batch_process_images(folder_path, model, thresholds, metadata, threshold_profile, active_threshold,
                         active_category_thresholds, save_dir=None, progress_callback=None,
                         min_confidence=0.1, batch_size=1, category_limits=None):
    """
    Process all images in a folder with optional batching for improved performance.

    Args:
        folder_path: Path to folder containing images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, uses default)
        progress_callback: Optional callback for progress updates
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once (default: 1)
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Dictionary with results for each image
    """
    from .file_utils import save_tags_to_file  # Import here to avoid circular imports
    import torch
    from PIL import Image
    import time

    print(f"Starting batch processing on {folder_path} with batch size {batch_size}")
    start_time = time.time()

    # Find all image files in the folder
    image_extensions = ['*.jpg', '*.jpeg', '*.png']
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
        image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))

    # Remove duplicate files (Windows filesystems are case-insensitive)
    if os.name == 'nt':  # Windows
        # Use lowercase paths for comparison on Windows
        unique_paths = set()
        unique_files = []
        for file_path in image_files:
            normalized_path = os.path.normpath(file_path).lower()
            if normalized_path not in unique_paths:
                unique_paths.add(normalized_path)
                unique_files.append(file_path)
        image_files = unique_files

    # Sort files for consistent processing order
    image_files.sort()

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    print(f"Found {len(image_files)} images to process")

    # Use the provided save directory or create a default one
    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")

    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Process images in batches
    results = {}
    total_images = len(image_files)
    processed = 0

    for i in range(0, total_images, batch_size):
        batch_start = time.time()

        # Get current batch of images
        batch_files = image_files[i:i + batch_size]
        batch_size_actual = len(batch_files)

        print(f"Processing batch {i // batch_size + 1}/{(total_images + batch_size - 1) // batch_size}: {batch_size_actual} images")

        if batch_size > 1:
            # True batch processing for multiple images at once
            try:
                batch_results = process_image_batch(
                    image_paths=batch_files,
                    model=model,
                    thresholds=thresholds,
                    metadata=metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                # Process and save results for each image in the batch
                for j, image_path in enumerate(batch_files):
                    # Update progress if callback provided
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    if j < len(batch_results):
                        result = batch_results[j]

                        # Apply category limits if specified
                        if category_limits and result['success']:
                            result = apply_category_limits(result, category_limits)
                            # Debug print
                            print(f"Applied limits for {os.path.basename(image_path)}, remaining tags: {len(result['all_tags'])}")

                        # Save the tags to a file
                        if result['success']:
                            output_path = save_tags_to_file(
                                image_path=image_path,
                                all_tags=result['all_tags'],
                                custom_dir=save_dir,
                                overwrite=True
                            )
                            result['output_path'] = str(output_path)

                        # Store the result
                        results[image_path] = result
                    else:
                        # Handle case where batch processing returned fewer results than expected
                        results[image_path] = {
                            'success': False,
                            'error': 'Batch processing error: missing result',
                            'all_tags': []
                        }

            except Exception as e:
                print(f"Batch processing error: {str(e)}")
                traceback.print_exc()

                # Fall back to processing images one by one in this batch
                for j, image_path in enumerate(batch_files):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    result = process_image(
                        image_path=image_path,
                        model=model,
                        thresholds=thresholds,
                        metadata=metadata,
                        threshold_profile=threshold_profile,
                        active_threshold=active_threshold,
                        active_category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    # Apply category limits if specified
                    if category_limits and result['success']:
                        result = apply_category_limits(result, category_limits)

                    if result['success']:
                        output_path = save_tags_to_file(
                            image_path=image_path,
                            all_tags=result['all_tags'],
                            custom_dir=save_dir,
                            overwrite=True
                        )
                        result['output_path'] = str(output_path)

                    results[image_path] = result
        else:
            # Process one by one if batch_size is 1
            for j, image_path in enumerate(batch_files):
                if progress_callback:
                    progress_callback(processed + j, total_images, image_path)

                result = process_image(
                    image_path=image_path,
                    model=model,
                    thresholds=thresholds,
                    metadata=metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                # Apply category limits if specified
                if category_limits and result['success']:
                    result = apply_category_limits(result, category_limits)

                if result['success']:
                    output_path = save_tags_to_file(
                        image_path=image_path,
                        all_tags=result['all_tags'],
                        custom_dir=save_dir,
                        overwrite=True
                    )
                    result['output_path'] = str(output_path)

                results[image_path] = result

        # Update processed count
        processed += batch_size_actual

        # Calculate batch timing
        batch_end = time.time()
        batch_time = batch_end - batch_start
        print(f"Batch processed in {batch_time:.2f} seconds ({batch_time / batch_size_actual:.2f} seconds per image)")

    # Final progress update
    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time / total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': total_time
    }
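

# Hedged usage sketch for batch_process_images (illustrative only). The folder path,
# profile name, and category limits are hypothetical; "model", "thresholds", and
# "metadata" are assumed to come from the application's model-loading code.
def _example_batch_process(model, thresholds, metadata):
    def on_progress(done, total, current_path):
        # current_path is None on the final progress update
        print(f"{done}/{total}: {current_path}")

    summary = batch_process_images(
        folder_path="./images",              # hypothetical folder
        model=model,
        thresholds=thresholds,
        metadata=metadata,
        threshold_profile="Overall",         # hypothetical profile name
        active_threshold=0.35,
        active_category_thresholds=None,
        save_dir=None,                       # None falls back to the default saved_tags directory
        progress_callback=on_progress,
        min_confidence=0.1,
        batch_size=4,
        category_limits={'general': 10},     # keep at most 10 'general' tags per image
    )
    # summary['results'] maps each image path to its per-image result dictionary
    return summary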


def process_image_batch(image_paths, model, thresholds, metadata, threshold_profile, active_threshold,
                        active_category_thresholds, min_confidence=0.1):
    """
    Process a batch of images at once.

    Args:
        image_paths: List of paths to the images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        List of dictionaries with tags, all probabilities, and other info for each image
    """
    try:
        import torch
        from PIL import Image
        import torchvision.transforms as transforms

        # Identify the model type we're using for better error handling
        model_type = model.__class__.__name__
        print(f"Running batch processing with model type: {model_type}")

        # Prepare the transformation for the images
        transform = transforms.Compose([
            transforms.Resize((512, 512)),  # Adjust based on your model's expected input
            transforms.ToTensor(),
        ])

        # Get model information
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        print(f"Model is using device: {device}, dtype: {dtype}")

        # Load and preprocess all images
        batch_tensor = []
        valid_images = []

        for img_path in image_paths:
            try:
                img = Image.open(img_path).convert('RGB')
                img_tensor = transform(img)
                img_tensor = img_tensor.to(device=device, dtype=dtype)
                batch_tensor.append(img_tensor)
                valid_images.append(img_path)
            except Exception as e:
                print(f"Error loading image {img_path}: {str(e)}")

        if not batch_tensor:
            return []

        # Stack all tensors into a single batch
        batch_input = torch.stack(batch_tensor)

        # Process entire batch at once
        with torch.no_grad():
            try:
                # Forward pass on the whole batch
                output = model(batch_input)

                # Handle tuple output format
                if isinstance(output, tuple):
                    probs_batch = torch.sigmoid(output[1])
                else:
                    probs_batch = torch.sigmoid(output)

                # Process each image's results
                results = []
                for i, img_path in enumerate(valid_images):
                    probs = probs_batch[i].unsqueeze(0)  # Add batch dimension back

                    # Extract and organize all probabilities
                    all_probs = {}
                    for idx in range(probs.size(1)):
                        prob_value = probs[0, idx].item()
                        if prob_value >= min_confidence:
                            tag, category = model.dataset.get_tag_info(idx)
                            if category not in all_probs:
                                all_probs[category] = []
                            all_probs[category].append((tag, prob_value))

                    # Sort tags by probability
                    for category in all_probs:
                        all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                    # Get filtered tags
                    tags = {}
                    for category, cat_tags in all_probs.items():
                        threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
                        tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

                    # Create a flat list of all tags above threshold
                    all_tags = []
                    for category, cat_tags in tags.items():
                        for tag, _ in cat_tags:
                            all_tags.append(tag)
                    results.append({
                        'tags': tags,
                        'all_probs': all_probs,
                        'all_tags': all_tags,
                        'success': True
                    })

                return results

            except RuntimeError as e:
                # If we encounter CUDA out of memory or another runtime error,
                # fall back to processing one by one
                print(f"Error in batch processing: {str(e)}")
                print("Falling back to one-by-one processing...")

                # Process one by one as fallback
                results = []
                for img_tensor, img_path in zip(batch_tensor, valid_images):
                    try:
                        input_tensor = img_tensor.unsqueeze(0)
                        output = model(input_tensor)

                        if isinstance(output, tuple):
                            probs = torch.sigmoid(output[1])
                        else:
                            probs = torch.sigmoid(output)

                        # Same post-processing as the batched path above
                        all_probs = {}
                        for idx in range(probs.size(1)):
                            prob_value = probs[0, idx].item()
                            if prob_value >= min_confidence:
                                tag, category = model.dataset.get_tag_info(idx)
                                if category not in all_probs:
                                    all_probs[category] = []
                                all_probs[category].append((tag, prob_value))

                        for category in all_probs:
                            all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                        tags = {}
                        for category, cat_tags in all_probs.items():
                            threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
                            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

                        all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

                        results.append({
                            'tags': tags,
                            'all_probs': all_probs,
                            'all_tags': all_tags,
                            'success': True
                        })
                    except Exception as e:
                        print(f"Error processing image {img_path}: {str(e)}")
                        results.append({
                            'tags': {},
                            'all_probs': {},
                            'all_tags': [],
                            'success': False,
                            'error': str(e)
                        })

                return results

    except Exception as e:
        print(f"Error in batch processing: {str(e)}")
        traceback.print_exc()
        # Return one failure entry per requested image so callers can index results safely
        return [{
            'tags': {},
            'all_probs': {},
            'all_tags': [],
            'success': False,
            'error': str(e)
        } for _ in image_paths]
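

# Hedged sketch of calling process_image_batch directly (normally it is invoked by
# batch_process_images). The image paths are hypothetical; "thresholds", "metadata",
# and "threshold_profile" are accepted but not used by the current implementation.
def _example_process_image_batch(model):
    batch_results = process_image_batch(
        image_paths=["a.jpg", "b.jpg"],    # hypothetical paths
        model=model,
        thresholds=None,
        metadata=None,
        threshold_profile="Overall",
        active_threshold=0.35,
        active_category_thresholds=None,
        min_confidence=0.1,
    )
    for res in batch_results:
        if res['success']:
            print(res['all_tags'])
    return batch_results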