"""
ONNX-based batch image processing for the Image Tagger application.
"""

import os
import json
import time
import traceback
import glob
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import onnxruntime as ort
from PIL import Image
import torchvision.transforms as transforms

def preprocess_image(image_path, image_size=512):
    """Load an image and preprocess it for inference.

    The image is resized to fit within an image_size x image_size square
    (preserving aspect ratio), centered on a black canvas, and returned as a
    float32 CHW array scaled to [0, 1].
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    try:
        with Image.open(image_path) as img:
            # Normalize any mode (RGBA, palette, grayscale, CMYK, ...) to RGB
            # so the model always receives three channels.
            if img.mode != 'RGB':
                img = img.convert('RGB')

            # Scale the longer side down to image_size, preserving aspect ratio.
            width, height = img.size
            aspect_ratio = width / height

            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Letterbox onto a black square canvas so every output has the
            # same dimensions.
            new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            img_tensor = transform(new_image)
            return img_tensor.numpy()
    except Exception as e:
        raise RuntimeError(f"Error processing {image_path}: {str(e)}") from e
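
# Example usage (illustrative only; "photos/cat.jpg" is a hypothetical path):
#
#     arr = preprocess_image("photos/cat.jpg")
#     print(arr.shape, arr.dtype)  # (3, 512, 512) float32, values in [0, 1]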


def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall",
                              active_threshold=0.35, active_category_thresholds=None,
                              min_confidence=0.1):
    """
    Process a single image using the ONNX model.

    Args:
        image_path: Path to the image file
        model_path: Path to the ONNX model file
        metadata: Model metadata dictionary (currently unused; the metadata
            JSON next to the model file is loaded instead)
        threshold_profile: The threshold profile being used
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags and probabilities
    """
    try:
        # Cache the tagger on the function object so repeated calls reuse the
        # ONNX session, but rebuild it when a different model is requested.
        tagger = getattr(process_single_image_onnx, 'tagger', None)
        if tagger is None or tagger.model_path != model_path:
            # Derive the metadata path from the model path, trying both
            # naming conventions.
            metadata_path = model_path.replace('.onnx', '_metadata.json')
            if not os.path.exists(metadata_path):
                metadata_path = model_path.replace('.onnx', '') + '_metadata.json'

            tagger = ONNXImageTagger(model_path, metadata_path)
            process_single_image_onnx.tagger = tagger

        start_time = time.time()
        img_array = preprocess_image(image_path)

        results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        inference_time = time.time() - start_time

        if results:
            result = results[0]
            result['inference_time'] = inference_time
            return result
        else:
            return {
                'success': False,
                'error': 'Failed to process image',
                'all_tags': [],
                'all_probs': {},
                'tags': {}
            }

    except Exception as e:
        print(f"Error in process_single_image_onnx: {str(e)}")
        traceback.print_exc()
        return {
            'success': False,
            'error': str(e),
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }
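
# Example usage (hypothetical paths; the metadata JSON is expected to sit
# next to the model file):
#
#     res = process_single_image_onnx("cat.jpg", "models/tagger.onnx", metadata=None)
#     if res['success']:
#         print(res['all_tags'], f"{res['inference_time']:.3f}s")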


def preprocess_images_parallel(image_paths, image_size=512, max_workers=8):
    """Preprocess multiple images in parallel.

    Returns the preprocessed arrays together with the paths that succeeded;
    the two lists stay index-aligned, and unreadable images are skipped.
    """
    processed_images = []
    valid_paths = []

    def process_single_image(path):
        try:
            return preprocess_image(path, image_size), path
        except Exception as e:
            print(f"Error processing {path}: {str(e)}")
            return None, path

    # PIL releases the GIL for most decode/resize work, so threads give a
    # real speedup here.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(process_single_image, image_paths))

    # Keep only the images that preprocessed successfully.
    for img_array, path in results:
        if img_array is not None:
            processed_images.append(img_array)
            valid_paths.append(path)

    return processed_images, valid_paths
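
# Example (hypothetical paths): the returned lists stay aligned, so they can
# be zipped back together after batch inference.
#
#     arrays, paths = preprocess_images_parallel(["a.jpg", "b.png"], max_workers=4)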


def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied (the dictionary is
        modified in place and also returned)
    """
    if not category_limits or not result['success']:
        return result

    # Note: this aliases result['tags'], so the result is mutated in place.
    filtered_tags = result['tags']

    # Iterate over a copy of the items so categories can be deleted safely.
    for category, cat_tags in list(filtered_tags.items()):
        limit = category_limits.get(category, -1)

        if limit == 0:
            # A limit of 0 removes the category entirely.
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            # Tags are already sorted by confidence, so keep the top N.
            filtered_tags[category] = cat_tags[:limit]

    # Rebuild the flat tag list from the surviving categories.
    all_tags = []
    for category, cat_tags in filtered_tags.items():
        for tag, _ in cat_tags:
            all_tags.append(tag)

    result['tags'] = filtered_tags
    result['all_tags'] = all_tags

    return result
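
# Example: keep at most 5 "general" tags, drop the "meta" category entirely,
# and leave "character" unlimited (hypothetical category names):
#
#     limits = {"general": 5, "meta": 0, "character": -1}
#     result = apply_category_limits(result, limits)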


class ONNXImageTagger:
    """ONNX-based image tagger for fast batch inference"""

    def __init__(self, model_path, metadata_path):
        self.model_path = model_path
        try:
            # Prefer CUDA when available; onnxruntime falls through the
            # provider list in order.
            self.session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")
        except Exception as e:
            print(f"CUDA not available, using CPU: {e}")
            self.session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")

        # Load the tag-name and category mappings.
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        self.input_name = self.session.get_inputs()[0].name
        print(f"Model loaded successfully. Input name: {self.input_name}")

    def predict_batch(self, image_arrays, threshold=0.325, category_thresholds=None, min_confidence=0.1):
        """Run batch inference on preprocessed image arrays"""
        # Stack the CHW arrays into a single NCHW batch.
        batch_input = np.stack(image_arrays)

        start_time = time.time()
        outputs = self.session.run(None, {self.input_name: batch_input})
        inference_time = time.time() - start_time
        print(f"Batch inference completed in {inference_time:.4f} seconds "
              f"({inference_time/len(image_arrays):.4f} s/image)")

        # Apply sigmoid to the raw logits. If the model has a refinement head
        # (a second output), use its probabilities; otherwise fall back to
        # the initial predictions.
        initial_probs = 1.0 / (1.0 + np.exp(-outputs[0]))
        refined_probs = 1.0 / (1.0 + np.exp(-outputs[1])) if len(outputs) > 1 else initial_probs

        batch_results = []

        for i in range(refined_probs.shape[0]):
            probs = refined_probs[i]

            # Collect every tag above min_confidence, grouped by category.
            all_probs = {}
            for idx in range(probs.shape[0]):
                prob_value = float(probs[idx])
                if prob_value >= min_confidence:
                    idx_str = str(idx)
                    tag_name = self.metadata['idx_to_tag'].get(idx_str, f"unknown-{idx}")
                    category = self.metadata['tag_to_category'].get(tag_name, "general")

                    if category not in all_probs:
                        all_probs[category] = []

                    all_probs[category].append((tag_name, prob_value))

            # Sort each category's tags by descending confidence.
            for category in all_probs:
                all_probs[category] = sorted(
                    all_probs[category],
                    key=lambda x: x[1],
                    reverse=True
                )

            # Keep only the tags that clear the applicable threshold, using a
            # per-category threshold when one is provided.
            tags = {}
            for category, cat_tags in all_probs.items():
                if category_thresholds and category in category_thresholds:
                    cat_threshold = category_thresholds[category]
                else:
                    cat_threshold = threshold

                tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]

            # Flatten the thresholded tags into a single list.
            all_tags = []
            for category, cat_tags in tags.items():
                for tag, _ in cat_tags:
                    all_tags.append(tag)

            batch_results.append({
                'tags': tags,
                'all_probs': all_probs,
                'all_tags': all_tags,
                'success': True
            })

        return batch_results
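
# Example (hypothetical file names): run a small batch through the tagger
# directly, without the folder-level helper below.
#
#     tagger = ONNXImageTagger("model.onnx", "model_metadata.json")
#     arrays = [preprocess_image(p) for p in ("a.jpg", "b.jpg")]
#     for res in tagger.predict_batch(arrays, threshold=0.35):
#         print(res['all_tags'])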


def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
                              active_threshold, active_category_thresholds, save_dir=None,
                              progress_callback=None, min_confidence=0.1, batch_size=16,
                              category_limits=None):
    """
    Process all images in a folder using the ONNX model.

    Args:
        folder_path: Path to folder containing images
        model_path: Path to the ONNX model file
        metadata_path: Path to the model metadata file
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, uses the default)
        progress_callback: Optional callback for progress updates
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once
        category_limits: Dictionary mapping categories to their tag limits
            (0 = exclude category, -1 = no limit)

    Returns:
        Dictionary with results for each image
    """
    from utils.file_utils import save_tags_to_file

    # Collect image files, matching both lower- and upper-case extensions.
    image_extensions = ['*.jpg', '*.jpeg', '*.png']
    image_files = []

    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
        image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))

    if os.name == 'nt':
        # Windows paths are case-insensitive, so the two glob passes above
        # can return duplicates; deduplicate on the normalized path.
        unique_paths = set()
        unique_files = []
        for file_path in image_files:
            normalized_path = os.path.normpath(file_path).lower()
            if normalized_path not in unique_paths:
                unique_paths.add(normalized_path)
                unique_files.append(file_path)
        image_files = unique_files

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    # Default to a "saved_tags" directory next to the application.
    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")

    os.makedirs(save_dir, exist_ok=True)

    tagger = ONNXImageTagger(model_path, metadata_path)

    results = {}
    total_images = len(image_files)
    processed = 0

    start_time = time.time()

    for i in range(0, total_images, batch_size):
        batch_start = time.time()

        batch_files = image_files[i:i+batch_size]
        batch_size_actual = len(batch_files)

        if progress_callback:
            progress_callback(processed, total_images, batch_files[0] if batch_files else None)

        print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: {batch_size_actual} images")

        try:
            # Preprocess the whole batch in parallel, then run one ONNX call.
            processed_images, valid_paths = preprocess_images_parallel(batch_files)

            if processed_images:
                batch_results = tagger.predict_batch(
                    processed_images,
                    threshold=active_threshold,
                    category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    print(f"Before limiting - Tags for {os.path.basename(image_path)}: {len(result['all_tags'])} tags")
                    print(f"Category limits applied: {category_limits}")

                    if category_limits and result['success']:
                        before_counts = {cat: len(tags) for cat, tags in result['tags'].items()}

                        result = apply_category_limits(result, category_limits)

                        after_counts = {cat: len(tags) for cat, tags in result['tags'].items()}

                        print(f"Before limits: {before_counts}")
                        print(f"After limits: {after_counts}")
                        print(f"After limiting - Tags for {os.path.basename(image_path)}: {len(result['all_tags'])} tags")

                    if result['success']:
                        output_path = save_tags_to_file(
                            image_path=image_path,
                            all_tags=result['all_tags'],
                            custom_dir=save_dir,
                            overwrite=True
                        )
                        result['output_path'] = str(output_path)

                    results[image_path] = result

            processed += batch_size_actual

            batch_end = time.time()
            batch_time = batch_end - batch_start
            print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")

        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            traceback.print_exc()

            # Fall back to processing this batch one image at a time, so a
            # single bad image cannot sink the rest of the batch.
            for k, image_path in enumerate(batch_files):
                try:
                    if progress_callback:
                        progress_callback(processed + k, total_images, image_path)

                    img_array = preprocess_image(image_path)

                    single_results = tagger.predict_batch(
                        [img_array],
                        threshold=active_threshold,
                        category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    if single_results:
                        result = single_results[0]

                        # Apply category limits after inference, mirroring
                        # the batch path above.
                        if category_limits and result['success']:
                            result = apply_category_limits(result, category_limits)

                        if result['success']:
                            output_path = save_tags_to_file(
                                image_path=image_path,
                                all_tags=result['all_tags'],
                                custom_dir=save_dir,
                                overwrite=True
                            )
                            result['output_path'] = str(output_path)

                        results[image_path] = result
                    else:
                        results[image_path] = {
                            'success': False,
                            'error': 'Failed to process image',
                            'all_tags': []
                        }

                except Exception as img_e:
                    print(f"Error processing single image {image_path}: {str(img_e)}")
                    results[image_path] = {
                        'success': False,
                        'error': str(img_e),
                        'all_tags': []
                    }

                processed += 1

    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': total_time
    }
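

# Minimal smoke test, assuming a model, its metadata JSON, and a folder of
# sample images exist at these hypothetical paths; adjust before running.
if __name__ == "__main__":
    summary = batch_process_images_onnx(
        folder_path="sample_images",
        model_path="models/tagger.onnx",
        metadata_path="models/tagger_metadata.json",
        threshold_profile="Overall",
        active_threshold=0.35,
        active_category_thresholds=None,
    )
    if summary['success']:
        print(f"Tagged {summary['processed']}/{summary['total']} images in "
              f"{summary['time_elapsed']:.2f}s; tags saved to {summary['save_dir']}")
    else:
        print(f"Batch run failed: {summary['error']}")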