"""
Image processing functions for the Image Tagger application.
"""
import os
import traceback
import glob
def process_image(image_path, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
"""
Process a single image and return the tags.
Args:
image_path: Path to the image
model: The image tagger model
thresholds: Thresholds dictionary
metadata: Metadata dictionary
threshold_profile: Selected threshold profile
active_threshold: Overall threshold value
active_category_thresholds: Category-specific thresholds
min_confidence: Minimum confidence to include in results
Returns:
Dictionary with tags, all probabilities, and other info
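    Example (illustrative usage; assumes a loaded model and threshold data, and the
    profile name and threshold value are placeholders):
        result = process_image("photo.jpg", model, thresholds, metadata,
                               "Overall", 0.35, None)
        if result['success']:
            print(result['all_tags'])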
"""
try:
# Run inference directly using the model's predict method
if threshold_profile in ["Category-specific", "High Precision", "High Recall"]:
results = model.predict(
image_path=image_path,
category_thresholds=active_category_thresholds
)
else:
results = model.predict(
image_path=image_path,
threshold=active_threshold
)
# Extract and organize all probabilities
all_probs = {}
probs = results['refined_probabilities'][0] # Remove batch dimension
for idx in range(len(probs)):
prob_value = probs[idx].item()
if prob_value >= min_confidence:
tag, category = model.dataset.get_tag_info(idx)
if category not in all_probs:
all_probs[category] = []
all_probs[category].append((tag, prob_value))
# Sort tags by probability within each category
for category in all_probs:
all_probs[category] = sorted(
all_probs[category],
key=lambda x: x[1],
reverse=True
)
# Get the filtered tags based on the selected threshold
tags = {}
for category, cat_tags in all_probs.items():
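            # Use the category-specific threshold when available, otherwise the overall one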
threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
# Create a flat list of all tags above threshold
all_tags = []
for category, cat_tags in tags.items():
for tag, _ in cat_tags:
all_tags.append(tag)
return {
'tags': tags,
'all_probs': all_probs,
'all_tags': all_tags,
'success': True
}
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
traceback.print_exc()
return {
'tags': {},
'all_probs': {},
'all_tags': [],
'success': False,
'error': str(e)
}
def apply_category_limits(result, category_limits):
"""
Apply category limits to a result dictionary.
Args:
result: Result dictionary containing tags and all_tags
category_limits: Dictionary mapping categories to their tag limits
(0 = exclude category, -1 = no limit/include all)
Returns:
Updated result dictionary with limits applied
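    Example (illustrative; the category names and limit values are placeholders):
        # Keep at most 5 character tags, drop meta tags, leave other categories untouched
        limited = apply_category_limits(result, {'character': 5, 'meta': 0})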
"""
if not category_limits or not result['success']:
return result
# Get the filtered tags
filtered_tags = result['tags']
# Apply limits to each category
for category, cat_tags in list(filtered_tags.items()):
# Get limit for this category, default to -1 (no limit)
limit = category_limits.get(category, -1)
if limit == 0:
# Exclude this category entirely
del filtered_tags[category]
elif limit > 0 and len(cat_tags) > limit:
# Limit to top N tags for this category
filtered_tags[category] = cat_tags[:limit]
# Regenerate all_tags list after applying limits
all_tags = []
for category, cat_tags in filtered_tags.items():
for tag, _ in cat_tags:
all_tags.append(tag)
# Update the result with limited tags
result['tags'] = filtered_tags
result['all_tags'] = all_tags
return result
def batch_process_images(folder_path, model, thresholds, metadata, threshold_profile, active_threshold,
active_category_thresholds, save_dir=None, progress_callback=None,
min_confidence=0.1, batch_size=1, category_limits=None):
"""
Process all images in a folder with optional batching for improved performance.
Args:
folder_path: Path to folder containing images
model: The image tagger model
thresholds: Thresholds dictionary
metadata: Metadata dictionary
threshold_profile: Selected threshold profile
active_threshold: Overall threshold value
active_category_thresholds: Category-specific thresholds
save_dir: Directory to save tag files (if None uses default)
progress_callback: Optional callback for progress updates
min_confidence: Minimum confidence threshold
batch_size: Number of images to process at once (default: 1)
        category_limits: Dictionary mapping categories to their tag limits (0 = exclude category, -1 = no limit)
Returns:
Dictionary with results for each image
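    Example (illustrative usage; the folder path, profile name, and threshold are placeholders):
        summary = batch_process_images("/path/to/images", model, thresholds, metadata,
                                       "Overall", 0.35, None, batch_size=4)
        if summary['success']:
            print(f"Tagged {summary['processed']}/{summary['total']} images")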
"""
    from .file_utils import save_tags_to_file  # Import here to avoid circular imports
    import time
print(f"Starting batch processing on {folder_path} with batch size {batch_size}")
start_time = time.time()
# Find all image files in the folder
image_extensions = ['*.jpg', '*.jpeg', '*.png']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(folder_path, ext)))
image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))
# Use a set to remove duplicate files (Windows filesystems are case-insensitive)
if os.name == 'nt': # Windows
# Use lowercase paths for comparison on Windows
unique_paths = set()
unique_files = []
for file_path in image_files:
normalized_path = os.path.normpath(file_path).lower()
if normalized_path not in unique_paths:
unique_paths.add(normalized_path)
unique_files.append(file_path)
image_files = unique_files
# Sort files for consistent processing order
image_files.sort()
if not image_files:
return {
'success': False,
'error': f"No images found in {folder_path}",
'results': {}
}
print(f"Found {len(image_files)} images to process")
# Use the provided save directory or create a default one
if save_dir is None:
app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
save_dir = os.path.join(app_dir, "saved_tags")
# Ensure the directory exists
os.makedirs(save_dir, exist_ok=True)
# Process images in batches
results = {}
total_images = len(image_files)
processed = 0
# Process in batches
for i in range(0, total_images, batch_size):
batch_start = time.time()
# Get current batch of images
batch_files = image_files[i:i+batch_size]
batch_size_actual = len(batch_files)
print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: {batch_size_actual} images")
if batch_size > 1:
            # True batch processing: run several images through the model in one forward pass
            try:
batch_results = process_image_batch(
image_paths=batch_files,
model=model,
thresholds=thresholds,
metadata=metadata,
threshold_profile=threshold_profile,
active_threshold=active_threshold,
active_category_thresholds=active_category_thresholds,
min_confidence=min_confidence
)
# Process and save results for each image in the batch
for j, image_path in enumerate(batch_files):
# Update progress if callback provided
if progress_callback:
progress_callback(processed + j, total_images, image_path)
if j < len(batch_results):
result = batch_results[j]
# Apply category limits if specified
if category_limits and result['success']:
                            result = apply_category_limits(result, category_limits)
                            # Debug output: how many tags remain after limiting
                            print(f"Applied limits for {os.path.basename(image_path)}, remaining tags: {len(result['all_tags'])}")
# Save the tags to a file
if result['success']:
output_path = save_tags_to_file(
image_path=image_path,
all_tags=result['all_tags'],
custom_dir=save_dir,
overwrite=True
)
result['output_path'] = str(output_path)
# Store the result
results[image_path] = result
else:
# Handle case where batch processing returned fewer results than expected
                        results[image_path] = {
                            'success': False,
                            'error': 'Batch processing error: missing result',
                            'tags': {},
                            'all_probs': {},
                            'all_tags': []
                        }
except Exception as e:
print(f"Batch processing error: {str(e)}")
traceback.print_exc()
# Fall back to processing images one by one in this batch
for j, image_path in enumerate(batch_files):
if progress_callback:
progress_callback(processed + j, total_images, image_path)
result = process_image(
image_path=image_path,
model=model,
thresholds=thresholds,
metadata=metadata,
threshold_profile=threshold_profile,
active_threshold=active_threshold,
active_category_thresholds=active_category_thresholds,
min_confidence=min_confidence
)
# Apply category limits if specified
if category_limits and result['success']:
# Use the apply_category_limits function
result = apply_category_limits(result, category_limits)
if result['success']:
output_path = save_tags_to_file(
image_path=image_path,
all_tags=result['all_tags'],
custom_dir=save_dir,
overwrite=True
)
result['output_path'] = str(output_path)
results[image_path] = result
else:
# Process one by one if batch_size is 1
for j, image_path in enumerate(batch_files):
if progress_callback:
progress_callback(processed + j, total_images, image_path)
result = process_image(
image_path=image_path,
model=model,
thresholds=thresholds,
metadata=metadata,
threshold_profile=threshold_profile,
active_threshold=active_threshold,
active_category_thresholds=active_category_thresholds,
min_confidence=min_confidence
)
# Apply category limits if specified
if category_limits and result['success']:
# Use the apply_category_limits function
result = apply_category_limits(result, category_limits)
if result['success']:
output_path = save_tags_to_file(
image_path=image_path,
all_tags=result['all_tags'],
custom_dir=save_dir,
overwrite=True
)
result['output_path'] = str(output_path)
results[image_path] = result
# Update processed count
processed += batch_size_actual
# Calculate batch timing
batch_end = time.time()
batch_time = batch_end - batch_start
print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")
# Final progress update
if progress_callback:
progress_callback(total_images, total_images, None)
end_time = time.time()
total_time = end_time - start_time
print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")
return {
'success': True,
'total': total_images,
'processed': len(results),
'results': results,
'save_dir': save_dir,
        'time_elapsed': total_time
}
def process_image_batch(image_paths, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
"""
Process a batch of images at once.
Args:
image_paths: List of paths to the images
model: The image tagger model
thresholds: Thresholds dictionary
metadata: Metadata dictionary
threshold_profile: Selected threshold profile
active_threshold: Overall threshold value
active_category_thresholds: Category-specific thresholds
min_confidence: Minimum confidence to include in results
Returns:
List of dictionaries with tags, all probabilities, and other info for each image
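    Example (illustrative usage; normally called internally by batch_process_images,
    and image_files stands in for a list of image paths):
        batch_results = process_image_batch(image_files[:4], model, thresholds, metadata,
                                            "Overall", 0.35, None)
        successful = [r for r in batch_results if r['success']]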
"""
try:
import torch
from PIL import Image
import torchvision.transforms as transforms
# Identify the model type we're using for better error handling
model_type = model.__class__.__name__
print(f"Running batch processing with model type: {model_type}")
# Prepare the transformation for the images
transform = transforms.Compose([
transforms.Resize((512, 512)), # Adjust based on your model's expected input
transforms.ToTensor(),
])
# Get model information
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
print(f"Model is using device: {device}, dtype: {dtype}")
# Load and preprocess all images
batch_tensor = []
valid_images = []
for img_path in image_paths:
try:
img = Image.open(img_path).convert('RGB')
img_tensor = transform(img)
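                # Match the model's device and dtype (e.g. float16 on GPU) before stacking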
img_tensor = img_tensor.to(device=device, dtype=dtype)
batch_tensor.append(img_tensor)
valid_images.append(img_path)
except Exception as e:
print(f"Error loading image {img_path}: {str(e)}")
if not batch_tensor:
return []
# Stack all tensors into a single batch
batch_input = torch.stack(batch_tensor)
# Process entire batch at once
with torch.no_grad():
try:
# Forward pass on the whole batch
output = model(batch_input)
# Handle tuple output format
if isinstance(output, tuple):
probs_batch = torch.sigmoid(output[1])
else:
probs_batch = torch.sigmoid(output)
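                # Sigmoid converts the logits to independent per-tag probabilities (multi-label)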
# Process each image's results
results = []
for i, img_path in enumerate(valid_images):
probs = probs_batch[i].unsqueeze(0) # Add batch dimension back
# Extract and organize all probabilities
all_probs = {}
for idx in range(probs.size(1)):
prob_value = probs[0, idx].item()
if prob_value >= min_confidence:
tag, category = model.dataset.get_tag_info(idx)
if category not in all_probs:
all_probs[category] = []
all_probs[category].append((tag, prob_value))
# Sort tags by probability
for category in all_probs:
all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)
# Get filtered tags
tags = {}
for category, cat_tags in all_probs.items():
threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
# Create a flat list of all tags above threshold
all_tags = []
for category, cat_tags in tags.items():
for tag, _ in cat_tags:
all_tags.append(tag)
results.append({
'tags': tags,
'all_probs': all_probs,
'all_tags': all_tags,
'success': True
})
return results
except RuntimeError as e:
# If we encounter CUDA out of memory or another runtime error,
# fall back to processing one by one
print(f"Error in batch processing: {str(e)}")
print("Falling back to one-by-one processing...")
# Process one by one as fallback
results = []
for i, (img_tensor, img_path) in enumerate(zip(batch_tensor, valid_images)):
try:
input_tensor = img_tensor.unsqueeze(0)
output = model(input_tensor)
if isinstance(output, tuple):
probs = torch.sigmoid(output[1])
else:
probs = torch.sigmoid(output)
                        # Same per-image post-processing as the batched path above
                        all_probs = {}
                        for idx in range(probs.size(1)):
                            prob_value = probs[0, idx].item()
                            if prob_value >= min_confidence:
                                tag, category = model.dataset.get_tag_info(idx)
                                if category not in all_probs:
                                    all_probs[category] = []
                                all_probs[category].append((tag, prob_value))
                        # Sort tags by probability within each category
                        for category in all_probs:
                            all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)
                        # Apply the active thresholds to get the final tag lists
                        tags = {}
                        for category, cat_tags in all_probs.items():
                            threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
                            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
                        all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]
                        results.append({
                            'tags': tags,
                            'all_probs': all_probs,
                            'all_tags': all_tags,
                            'success': True
                        })
except Exception as e:
print(f"Error processing image {img_path}: {str(e)}")
results.append({
'tags': {},
'all_probs': {},
'all_tags': [],
'success': False,
'error': str(e)
})
return results
    except Exception as e:
        print(f"Error in batch processing: {str(e)}")
        traceback.print_exc()
        # Re-raise so the caller can fall back to one-by-one processing
        raise