"""
Image processing functions for the Image Tagger application.
"""

import os
import traceback
import glob


def process_image(image_path, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
    """
    Process a single image and return the tags.

    Args:
        image_path: Path to the image
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags, all probabilities, and other info
    """
    try:
        # Category-aware profiles pass per-category thresholds; otherwise a single overall threshold is used
        if threshold_profile in ["Category-specific", "High Precision", "High Recall"]:
            results = model.predict(
                image_path=image_path,
                category_thresholds=active_category_thresholds
            )
        else:
            results = model.predict(
                image_path=image_path,
                threshold=active_threshold
            )

        # Collect every tag whose probability clears min_confidence, grouped by category
        all_probs = {}
        probs = results['refined_probabilities'][0]

        for idx in range(len(probs)):
            prob_value = probs[idx].item()
            if prob_value >= min_confidence:
                tag, category = model.dataset.get_tag_info(idx)

                if category not in all_probs:
                    all_probs[category] = []

                all_probs[category].append((tag, prob_value))

        # Sort each category's tags by descending probability
        for category in all_probs:
            all_probs[category] = sorted(
                all_probs[category],
                key=lambda x: x[1],
                reverse=True
            )

        # Keep only the tags that clear the active (or category-specific) threshold
        tags = {}
        for category, cat_tags in all_probs.items():
            threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

        # Flatten the per-category tags into a single list
        all_tags = []
        for category, cat_tags in tags.items():
            for tag, _ in cat_tags:
                all_tags.append(tag)

        return {
            'tags': tags,
            'all_probs': all_probs,
            'all_tags': all_tags,
            'success': True
        }

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        traceback.print_exc()
        return {
            'tags': {},
            'all_probs': {},
            'all_tags': [],
            'success': False,
            'error': str(e)
        }
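
# Illustrative usage sketch (not part of the original module). The loader helper and
# the threshold value below are hypothetical placeholders for however the host
# application constructs its model and settings:
#
#     model, thresholds, metadata = load_tagger()   # hypothetical helper
#     result = process_image(
#         image_path="photo.jpg",
#         model=model,
#         thresholds=thresholds,
#         metadata=metadata,
#         threshold_profile="Overall",
#         active_threshold=0.35,
#         active_category_thresholds=None,
#     )
#     if result['success']:
#         print(result['all_tags'])

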
def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied
    """
    if not category_limits or not result['success']:
        return result

    filtered_tags = result['tags']

    for category, cat_tags in list(filtered_tags.items()):
        # Categories without an explicit limit default to -1 (no limit)
        limit = category_limits.get(category, -1)

        if limit == 0:
            # A limit of 0 removes the category entirely
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            # Keep only the top N tags (cat_tags are already sorted by probability)
            filtered_tags[category] = cat_tags[:limit]

    # Rebuild the flat tag list from the filtered categories
    all_tags = []
    for category, cat_tags in filtered_tags.items():
        for tag, _ in cat_tags:
            all_tags.append(tag)

    result['tags'] = filtered_tags
    result['all_tags'] = all_tags

    return result
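
# Illustrative sketch of the limit semantics (the category names and values here are
# arbitrary examples, not categories defined by this project):
#
#     limits = {"character": 2,    # keep at most the top 2 character tags
#               "meta": 0,         # drop the meta category entirely
#               "general": -1}     # keep every general tag
#     result = apply_category_limits(result, limits)

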
def batch_process_images(folder_path, model, thresholds, metadata, threshold_profile, active_threshold,
                         active_category_thresholds, save_dir=None, progress_callback=None,
                         min_confidence=0.1, batch_size=1, category_limits=None):
    """
    Process all images in a folder with optional batching for improved performance.

    Args:
        folder_path: Path to folder containing images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, uses the default save directory)
        progress_callback: Optional callback for progress updates
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once (default: 1)
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Dictionary with results for each image
    """
    from .file_utils import save_tags_to_file
    import torch
    from PIL import Image
    import time

    print(f"Starting batch processing on {folder_path} with batch size {batch_size}")
    start_time = time.time()

    # Collect image files, matching both lowercase and uppercase extensions
    image_extensions = ['*.jpg', '*.jpeg', '*.png']
    image_files = []

    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
        image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))

    # On Windows the filesystem is case-insensitive, so the same file can be matched
    # twice; deduplicate by normalized, lowercased path
    if os.name == 'nt':
        unique_paths = set()
        unique_files = []
        for file_path in image_files:
            normalized_path = os.path.normpath(file_path).lower()
            if normalized_path not in unique_paths:
                unique_paths.add(normalized_path)
                unique_files.append(file_path)
        image_files = unique_files

    # Process in a stable, sorted order
    image_files.sort()

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

print(f"Found {len(image_files)} images to process")
|
|
|
|
|
|
if save_dir is None:
|
|
app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
save_dir = os.path.join(app_dir, "saved_tags")
|
|
|
|
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
|
|
results = {}
|
|
total_images = len(image_files)
|
|
processed = 0
|
|
|
|
|
|
for i in range(0, total_images, batch_size):
|
|
batch_start = time.time()
|
|
|
|
batch_files = image_files[i:i+batch_size]
|
|
batch_size_actual = len(batch_files)
|
|
|
|
print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: {batch_size_actual} images")
|
|
|
|
if batch_size > 1:
|
|
|
|
try:
|
|
|
|
batch_results = process_image_batch(
|
|
image_paths=batch_files,
|
|
model=model,
|
|
thresholds=thresholds,
|
|
metadata=metadata,
|
|
threshold_profile=threshold_profile,
|
|
active_threshold=active_threshold,
|
|
active_category_thresholds=active_category_thresholds,
|
|
min_confidence=min_confidence
|
|
)
|
|
|
|
|
|
for j, image_path in enumerate(batch_files):
|
|
|
|
if progress_callback:
|
|
progress_callback(processed + j, total_images, image_path)
|
|
|
|
if j < len(batch_results):
|
|
result = batch_results[j]
|
|
|
|
|
|
if category_limits and result['success']:
|
|
|
|
result = apply_category_limits(result, category_limits)
|
|
|
|
|
|
print(f"Applied limits for {os.path.basename(image_path)}, remaining tags: {len(result['all_tags'])}")
|
|
|
|
|
|
if result['success']:
|
|
output_path = save_tags_to_file(
|
|
image_path=image_path,
|
|
all_tags=result['all_tags'],
|
|
custom_dir=save_dir,
|
|
overwrite=True
|
|
)
|
|
result['output_path'] = str(output_path)
|
|
|
|
|
|
results[image_path] = result
|
|
else:
|
|
|
|
results[image_path] = {
|
|
'success': False,
|
|
'error': 'Batch processing error: missing result',
|
|
'all_tags': []
|
|
}
|
|
|
|
except Exception as e:
|
|
print(f"Batch processing error: {str(e)}")
|
|
traceback.print_exc()
|
|
|
|
|
|
for j, image_path in enumerate(batch_files):
|
|
if progress_callback:
|
|
progress_callback(processed + j, total_images, image_path)
|
|
|
|
result = process_image(
|
|
image_path=image_path,
|
|
model=model,
|
|
thresholds=thresholds,
|
|
metadata=metadata,
|
|
threshold_profile=threshold_profile,
|
|
active_threshold=active_threshold,
|
|
active_category_thresholds=active_category_thresholds,
|
|
min_confidence=min_confidence
|
|
)
|
|
|
|
|
|
if category_limits and result['success']:
|
|
|
|
result = apply_category_limits(result, category_limits)
|
|
|
|
if result['success']:
|
|
output_path = save_tags_to_file(
|
|
image_path=image_path,
|
|
all_tags=result['all_tags'],
|
|
custom_dir=save_dir,
|
|
overwrite=True
|
|
)
|
|
result['output_path'] = str(output_path)
|
|
|
|
results[image_path] = result
|
|
else:
|
|
|
|
for j, image_path in enumerate(batch_files):
|
|
if progress_callback:
|
|
progress_callback(processed + j, total_images, image_path)
|
|
|
|
result = process_image(
|
|
image_path=image_path,
|
|
model=model,
|
|
thresholds=thresholds,
|
|
metadata=metadata,
|
|
threshold_profile=threshold_profile,
|
|
active_threshold=active_threshold,
|
|
active_category_thresholds=active_category_thresholds,
|
|
min_confidence=min_confidence
|
|
)
|
|
|
|
|
|
if category_limits and result['success']:
|
|
|
|
result = apply_category_limits(result, category_limits)
|
|
|
|
if result['success']:
|
|
output_path = save_tags_to_file(
|
|
image_path=image_path,
|
|
all_tags=result['all_tags'],
|
|
custom_dir=save_dir,
|
|
overwrite=True
|
|
)
|
|
result['output_path'] = str(output_path)
|
|
|
|
results[image_path] = result
|
|
|
|
|
|
processed += batch_size_actual
|
|
|
|
|
|
batch_end = time.time()
|
|
batch_time = batch_end - batch_start
|
|
print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")
|
|
|
|
|
|
if progress_callback:
|
|
progress_callback(total_images, total_images, None)
|
|
|
|
end_time = time.time()
|
|
total_time = end_time - start_time
|
|
print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")
|
|
|
|
return {
|
|
'success': True,
|
|
'total': total_images,
|
|
'processed': len(results),
|
|
'results': results,
|
|
'save_dir': save_dir,
|
|
'time_elapsed': end_time - start_time
|
|
}
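
# Illustrative usage sketch (not part of the original module). The folder path,
# threshold value, and progress handler below are hypothetical examples:
#
#     def on_progress(current, total, image_path):
#         # image_path is None on the final "all done" call
#         print(f"{current}/{total}", image_path)
#
#     summary = batch_process_images(
#         folder_path="/path/to/images",
#         model=model,
#         thresholds=thresholds,
#         metadata=metadata,
#         threshold_profile="Overall",
#         active_threshold=0.35,
#         active_category_thresholds=None,
#         progress_callback=on_progress,
#         batch_size=4,
#     )
#     print(summary['processed'], "images processed, tags saved to", summary['save_dir'])

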
def process_image_batch(image_paths, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
    """
    Process a batch of images at once.

    Args:
        image_paths: List of paths to the images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        List of dictionaries with tags, all probabilities, and other info for each image
    """
    try:
        import torch
        from PIL import Image
        import torchvision.transforms as transforms

        model_type = model.__class__.__name__
        print(f"Running batch processing with model type: {model_type}")

        # Preprocessing used for direct batched inference
        transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ])

        # Match the model's device and dtype so the stacked batch can be fed to it directly
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        print(f"Model is using device: {device}, dtype: {dtype}")

        # Load and preprocess each image, skipping any that fail to open
        batch_tensor = []
        valid_images = []

        for img_path in image_paths:
            try:
                img = Image.open(img_path).convert('RGB')
                img_tensor = transform(img)
                img_tensor = img_tensor.to(device=device, dtype=dtype)
                batch_tensor.append(img_tensor)
                valid_images.append(img_path)
            except Exception as e:
                print(f"Error loading image {img_path}: {str(e)}")

        if not batch_tensor:
            return []

        batch_input = torch.stack(batch_tensor)
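        # Note: torch.stack above requires every tensor to share the same shape, which
        # the fixed Resize((512, 512)) guarantees, and the tensors were already moved
        # to the model's device/dtype so the batch can be passed straight to the model.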

        with torch.no_grad():
            try:
                # Single forward pass over the whole batch
                output = model(batch_input)

                # Use the second element when the model returns a tuple of outputs
                if isinstance(output, tuple):
                    probs_batch = torch.sigmoid(output[1])
                else:
                    probs_batch = torch.sigmoid(output)

                # Convert the batched probabilities into one result dict per image
                results = []
                for i, img_path in enumerate(valid_images):
                    probs = probs_batch[i].unsqueeze(0)

                    # Collect tags above min_confidence, grouped by category
                    all_probs = {}
                    for idx in range(probs.size(1)):
                        prob_value = probs[0, idx].item()
                        if prob_value >= min_confidence:
                            tag, category = model.dataset.get_tag_info(idx)

                            if category not in all_probs:
                                all_probs[category] = []

                            all_probs[category].append((tag, prob_value))

                    # Sort each category's tags by descending probability
                    for category in all_probs:
                        all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                    # Apply the active (or category-specific) threshold
                    tags = {}
                    for category, cat_tags in all_probs.items():
                        threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
                        tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

                    # Flatten into a single tag list
                    all_tags = []
                    for category, cat_tags in tags.items():
                        for tag, _ in cat_tags:
                            all_tags.append(tag)

                    results.append({
                        'tags': tags,
                        'all_probs': all_probs,
                        'all_tags': all_tags,
                        'success': True
                    })

                return results

            except RuntimeError as e:
                # Batched inference failed (e.g., an out-of-memory error);
                # fall back to running the images one at a time
                print(f"Error in batch processing: {str(e)}")
                print("Falling back to one-by-one processing...")

                results = []
                for i, (img_tensor, img_path) in enumerate(zip(batch_tensor, valid_images)):
                    try:
                        input_tensor = img_tensor.unsqueeze(0)
                        output = model(input_tensor)

                        if isinstance(output, tuple):
                            probs = torch.sigmoid(output[1])
                        else:
                            probs = torch.sigmoid(output)

                        # Build the per-image result the same way as the batched path above
                        all_probs = {}
                        for idx in range(probs.size(1)):
                            prob_value = probs[0, idx].item()
                            if prob_value >= min_confidence:
                                tag, category = model.dataset.get_tag_info(idx)
                                all_probs.setdefault(category, []).append((tag, prob_value))

                        for category in all_probs:
                            all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                        tags = {}
                        for category, cat_tags in all_probs.items():
                            threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
                            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

                        all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

                        results.append({
                            'tags': tags,
                            'all_probs': all_probs,
                            'all_tags': all_tags,
                            'success': True
                        })

                    except Exception as e:
                        print(f"Error processing image {img_path}: {str(e)}")
                        results.append({
                            'tags': {},
                            'all_probs': {},
                            'all_tags': [],
                            'success': False,
                            'error': str(e)
                        })

                return results

    except Exception as e:
        print(f"Error in batch processing: {str(e)}")
        traceback.print_exc()
        # Return an empty list so callers can still iterate over the result safely
        return []
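
# Illustrative sketch (not part of the original module): calling process_image_batch
# directly. The paths and threshold value are placeholders; batch_process_images above
# is the usual entry point and also handles saving tags and progress reporting.
#
#     batch_results = process_image_batch(
#         image_paths=["a.jpg", "b.jpg"],
#         model=model,
#         thresholds=thresholds,
#         metadata=metadata,
#         threshold_profile="Overall",
#         active_threshold=0.35,
#         active_category_thresholds=None,
#     )
#     for r in batch_results:
#         print(r['success'], r['all_tags'])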