|
|
|
""" |
|
Image Tagger Application |
|
A Streamlit web app for tagging images using an AI model. |
|
""" |
|
|
|
import streamlit as st |
|
import os |
|
import sys |
|
import traceback |
|
import tempfile |
|
import time |
|
import platform |
|
import subprocess |
|
import webbrowser |
|
import glob |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import io |
|
import base64 |
|
from matplotlib.colors import LinearSegmentedColormap |
|
from PIL import Image |
|
from pathlib import Path |
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
from utils.model_loader import load_exported_model, is_windows, check_flash_attention |
|
from utils.image_processing import process_image, batch_process_images |
|
from utils.file_utils import save_tags_to_file, get_default_save_locations |
|
from utils.ui_components import display_progress_bar, show_example_images, display_batch_results |
|
from utils.onnx_processing import batch_process_images_onnx |
|
|
|
|
|
# The exported model files live in a "model" directory inside the parent of
# the directory that contains this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODEL_DIR = os.path.join(_PROJECT_ROOT, "model")
print(f"Using model directory: {MODEL_DIR}")
|
|
|
|
|
# One-line summaries of each threshold profile. The keys are the same profile
# names used by threshold_profile_explanations and get_profile_metrics.
threshold_profile_descriptions = {
    "Micro Optimized": (
        "Maximizes micro-averaged F1 score (best for dominant classes). "
        "Optimal for overall prediction quality."
    ),
    "Macro Optimized": (
        "Maximizes macro-averaged F1 score (equal weight to all classes). "
        "Better for balanced performance across all tags."
    ),
    "Balanced": (
        "Provides a trade-off between precision and recall with moderate thresholds. "
        "Good general-purpose setting."
    ),
    "High Precision": (
        "Uses higher thresholds to prioritize accuracy over recall. "
        "Fewer but more confident predictions."
    ),
    "High Recall": (
        "Uses lower thresholds to capture more potential tags at the expense of accuracy. "
        "More comprehensive tagging."
    ),
    "Overall": (
        "Uses a single threshold value across all categories. "
        "Simplest approach for consistent behavior."
    ),
    "Weighted": (
        "Uses thresholds weighted by category importance. "
        "Better balance for tags that matter most."
    ),
    "Category-specific": (
        "Uses different optimal thresholds for each category. "
        "Best for fine-tuning results."
    ),
}
|
|
|
# Long-form markdown help text for each threshold profile, keyed by the same
# profile names as threshold_profile_descriptions.
# NOTE: the threshold values and performance figures below are static text;
# they are not read from the loaded thresholds data, so they can drift from
# the thresholds file actually in use. The leading/trailing blank lines and
# indentation inside each literal are part of the string value.
threshold_profile_explanations = {
    "Micro Optimized": """

    ### Micro Optimized Profile



    **Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.



    **When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.



    **Effects**:

    - Optimizes performance for the most frequent tags

    - Gives more weight to categories with many examples (like 'character' and 'general')

    - Provides higher precision in most common use cases



    **Threshold value**: Approximately 0.33 (optimized on validation data)



    **Performance metrics**:

    - Micro F1: ~0.62

    - Macro F1: ~0.35

    - Precision: ~0.63

    - Recall: ~0.60

    """,

    "Macro Optimized": """

    ### Macro Optimized Profile



    **Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.



    **When to use**: When balanced performance across all categories is important, including rare tags.



    **Effects**:

    - More balanced performance across all tag categories

    - Better at detecting rare or unusual tags

    - Generally has lower thresholds than micro-optimized



    **Threshold value**: Approximately 0.19-0.21 (optimized on validation data)



    **Performance metrics**:

    - Micro F1: ~0.49

    - Macro F1: ~0.41

    - Precision: ~0.37

    - Recall: ~0.53

    """,

    "Balanced": """

    ### Balanced Profile



    **Technical definition**: Provides a compromise between precision and recall with moderate thresholds.



    **When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.



    **Effects**:

    - Good middle ground between precision and recall

    - Works well for most common use cases

    - Default choice for most users



    **Threshold value**: Approximately 0.26 (optimized on validation data)



    **Performance metrics**:

    - Micro F1: ~0.59

    - Macro F1: ~0.39

    - Precision: ~0.51

    - Recall: ~0.70

    """,

    "High Precision": """

    ### High Precision Profile



    **Technical definition**: Uses higher thresholds to prioritize precision (correctness) over recall (coverage).



    **When to use**: When you need high confidence in the tags that are returned and prefer to miss tags rather than include incorrect ones.



    **Effects**:

    - Much higher precision (84-97% of returned tags are correct)

    - Lower recall (only captures 35-60% of relevant tags)

    - Returns fewer tags overall, but with higher confidence



    **Threshold value**: 0.50 (optimized for precision on validation data)



    **Performance metrics**:

    - Micro F1: ~0.50

    - Macro F1: ~0.22

    - Precision: ~0.84

    - Recall: ~0.35

    """,

    "High Recall": """

    ### High Recall Profile



    **Technical definition**: Uses lower thresholds to prioritize recall (coverage) over precision (correctness).



    **When to use**: When you want to capture as many potential tags as possible and don't mind some incorrect suggestions.



    **Effects**:

    - Very high recall (captures 90%+ of relevant tags)

    - Much lower precision (only 18-49% of returned tags may be correct)

    - Returns many more tags, including less confident ones



    **Threshold value**: 0.10-0.12 (optimized for recall on validation data)



    **Performance metrics**:

    - Micro F1: ~0.30

    - Macro F1: ~0.35

    - Precision: ~0.18

    - Recall: ~0.90

    """,

    "Overall": """

    ### Overall Profile



    **Technical definition**: Uses a single threshold value across all categories.



    **When to use**: When you want consistent behavior across all categories and a simple approach.



    **Effects**:

    - Consistent tagging threshold for all categories

    - Simpler to understand than category-specific thresholds

    - User-adjustable with a single slider



    **Default threshold value**: 0.35 (uses "balanced" threshold by default)



    **Note**: The threshold value is user-adjustable with the slider below.

    """,

    "Weighted": """

    ### Weighted Profile



    **Technical definition**: Uses thresholds weighted by category importance.



    **When to use**: When you want different sensitivity for different categories based on their importance.



    **Effects**:

    - More important categories (like character and copyright) get optimized thresholds

    - Less important categories get adjusted thresholds based on their contribution

    - Better balance for the tags that matter most



    **Default threshold values**: Varies by category (based on importance weighting)



    **Note**: This uses pre-calculated optimal thresholds that can't be adjusted directly.

    """,

    "Category-specific": """

    ### Category-specific Profile



    **Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.



    **When to use**: When you want to customize tagging sensitivity for different categories.



    **Effects**:

    - Each category has its own independent threshold

    - Full control over category sensitivity

    - Best for fine-tuning results when some categories need different treatment



    **Default threshold values**: Starts with balanced thresholds for each category



    **Note**: Use the category sliders below to adjust thresholds for individual categories.

    """

}
|
|
|
def get_profile_metrics(thresholds, profile_name, model_type="refined"):
    """
    Look up the stored metrics entry for a threshold profile.

    UI profile names ("Micro Optimized", "Balanced", ...) are mapped to the keys
    used in the thresholds data ("micro_opt", "balanced", ...). The adjustable
    profiles ("Overall", "Weighted", "Category-specific") have no entry of their
    own and report the "balanced" entry instead.

    Two layouts of ``thresholds`` are supported:

    * nested: ``{"initial": {...}, "refined": {...}}``, each section holding an
      "overall" mapping;
    * flat: a single mapping with an "overall" key at the top level.

    Args:
        thresholds: The thresholds dictionary. May be None (e.g. before a model
            has been loaded), in which case None is returned.
        profile_name: UI display name of the profile (e.g. "Micro Optimized").
        model_type: Section to read in the nested layout ('initial' or
            'refined'). Any other value falls back to 'refined'.

    Returns:
        The ``...["overall"][profile_key]`` entry (as stored in ``thresholds``),
        or None if the profile name is unknown or the entry is missing.
    """
    profile_keys = {
        "Micro Optimized": "micro_opt",
        "Macro Optimized": "macro_opt",
        "Balanced": "balanced",
        "High Precision": "high_precision",
        "High Recall": "high_recall",
        # Adjustable profiles have no stored metrics of their own.
        "Overall": "balanced",
        "Weighted": "balanced",
        "Category-specific": "balanced",
    }
    profile_key = profile_keys.get(profile_name)
    if profile_key is None or not thresholds:
        return None

    if "initial" in thresholds and "refined" in thresholds:
        # Both sections exist here, so an unknown model_type resolves to "refined".
        scope = thresholds[model_type if model_type in thresholds else "refined"]
    else:
        scope = thresholds

    overall = (scope or {}).get("overall") or {}
    return overall.get(profile_key)
|
|
|
def _overall_profile_threshold(scope, profile_key):
    """Return ``scope["overall"][profile_key]["threshold"]``, or None when that entry is absent."""
    overall = scope.get("overall") or {}
    if profile_key in overall:
        return overall[profile_key]["threshold"]
    return None


def _category_profile_threshold(scope, category, profile_key, fallback):
    """Return ``scope["categories"][category][profile_key]["threshold"]``, or *fallback* when absent."""
    category_entry = (scope.get("categories") or {}).get(category) or {}
    if profile_key in category_entry:
        return category_entry[profile_key]["threshold"]
    return fallback


def on_threshold_profile_change():
    """
    Sync the active thresholds in ``st.session_state.settings`` with the newly
    selected threshold profile.

    Reads the selected profile from ``st.session_state.threshold_profile`` and
    the loaded thresholds from ``st.session_state.thresholds``, then updates:

    * ``settings['active_threshold']``: the overall cutoff for the profile.
      "Overall", "Weighted" and "Category-specific" use the "balanced" overall
      value. The setting is left unchanged if the thresholds data lacks the
      entry.
    * ``settings['active_category_thresholds']``: per-category cutoffs.
      - Fixed profiles and "Category-specific" fill every category in
        ``st.session_state.categories``. Categories without a stored value fall
        back to the active threshold.
      - "Weighted" prefers ``thresholds["weighted"][category]`` when a
        "weighted" section exists, else the balanced per-category value. With
        no "weighted" section, per-category values are left untouched.
      - "Overall" clears the mapping so the single threshold applies everywhere.

    When the thresholds contain both "initial" and "refined" sections, the
    "refined" section is used for model type "full" and "initial" otherwise.
    Does nothing if no thresholds are loaded yet or settings are missing.
    """
    new_profile = st.session_state.threshold_profile

    thresholds = getattr(st.session_state, 'thresholds', None)
    if thresholds is None or not hasattr(st.session_state, 'settings'):
        # thresholds stays None until a model has been loaded; nothing to sync.
        return

    settings = st.session_state.settings
    if settings['active_category_thresholds'] is None:
        settings['active_category_thresholds'] = {}
    # Updated in place below (the "Overall" profile replaces it instead).
    current_thresholds = settings['active_category_thresholds']

    # Thresholds may be split per model head; pick the section for the active model.
    if "initial" in thresholds and "refined" in thresholds:
        model_type_key = "refined" if getattr(st.session_state, 'model_type', None) == "full" else "initial"
        scope = thresholds[model_type_key]
    else:
        scope = thresholds

    profile_key = {
        "Micro Optimized": "micro_opt",
        "Macro Optimized": "macro_opt",
        "Balanced": "balanced",
        "High Precision": "high_precision",
        "High Recall": "high_recall",
    }.get(new_profile)

    if profile_key is not None:
        # Fixed profile: use the stored overall and per-category values for it.
        overall = _overall_profile_threshold(scope, profile_key)
        if overall is not None:
            settings['active_threshold'] = overall
        for category in st.session_state.categories:
            # Categories without a stored value fall back to the overall cutoff
            # instead of keeping a stale value from the previous profile.
            current_thresholds[category] = _category_profile_threshold(
                scope, category, profile_key, settings['active_threshold'])
        return

    if new_profile not in ("Overall", "Weighted", "Category-specific"):
        return

    # The adjustable profiles all start from the "balanced" overall threshold.
    # NOTE(review): at load time image_tagger_app seeds "Weighted" from
    # thresholds["weighted"]["f1"] instead; confirm which is intended.
    balanced = _overall_profile_threshold(scope, "balanced")
    if balanced is not None:
        settings['active_threshold'] = balanced
    fallback = settings['active_threshold']

    if new_profile == "Overall":
        settings['active_category_thresholds'] = {}

    elif new_profile == "Weighted":
        if "weighted" in scope:
            weighted_thresholds = scope["weighted"] or {}
            for category in st.session_state.categories:
                if category in weighted_thresholds:
                    value = weighted_thresholds[category]
                    # NOTE(review): entries seen elsewhere under "weighted" are
                    # {"threshold": x} dicts (e.g. "f1"); unwrap such values so
                    # the stored cutoff is a number.
                    if isinstance(value, dict):
                        value = value.get("threshold", fallback)
                    current_thresholds[category] = value
                else:
                    current_thresholds[category] = _category_profile_threshold(
                        scope, category, "balanced", fallback)

    else:  # "Category-specific"
        for category in st.session_state.categories:
            current_thresholds[category] = _category_profile_threshold(
                scope, category, "balanced", fallback)
|
|
|
def create_micro_macro_comparison():
    """
    Return the CSS for styling a micro- vs macro-optimization comparison.

    Despite the name, the returned string contains only a ``<style>`` block.
    It defines rules for ``.optimization-container``, ``.optimization-row``,
    ``.optimization-col``, ``.micro-col``, ``.macro-col``, ``.tag-example``,
    ``.tag-common``, ``.tag-rare`` and ``.comparison-table``. It contains no
    visible markup of its own.
    NOTE(review): the caller presumably supplies the HTML elements that use
    these classes; confirm at the call site.

    Returns:
        str: An HTML fragment consisting of a ``<style>...</style>`` element
        surrounded by whitespace.
    """
    # The literal's whitespace is part of the returned value; keep it as-is.
    html = """

    <style>

    .optimization-container {

        font-family: sans-serif;

        margin: 20px 0;

    }

    .optimization-row {

        display: flex;

        margin-bottom: 15px;

    }

    .optimization-col {

        flex: 1;

        padding: 15px;

        border-radius: 8px;

        margin: 0 5px;

    }

    .optimization-col h3 {

        margin-top: 0;

        font-size: 18px;

    }

    .optimization-col p {

        font-size: 14px;

        line-height: 1.5;

    }

    .micro-col {

        background-color: #e6f3ff;

        border: 1px solid #99ccff;

    }

    .macro-col {

        background-color: #fff0e6;

        border: 1px solid #ffcc99;

    }

    .tag-example {

        display: inline-block;

        padding: 3px 8px;

        margin: 3px;

        border-radius: 12px;

        font-size: 12px;

    }

    .tag-common {

        background-color: #4CAF50;

        color: white;

    }

    .tag-rare {

        background-color: #9C27B0;

        color: white;

    }

    .comparison-table {

        width: 100%;

        border-collapse: collapse;

        margin-top: 15px;

    }

    .comparison-table th, .comparison-table td {

        border: 1px solid #ddd;

        padding: 8px;

        text-align: left;

    }

    .comparison-table th {

        background-color: #f2f2f2;

    }

    </style>

    """

    return html
|
|
|
def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
    """
    Filter per-category tag probabilities against the active thresholds.

    Each category uses its entry in ``active_category_thresholds`` when one is
    present, otherwise ``active_threshold``. A tag is kept when its probability
    is greater than or equal to that cutoff.

    Args:
        all_probs: Mapping of category -> iterable of (tag, probability) pairs.
        threshold_profile: Name of the current profile. Not used by this function.
        active_threshold: Cutoff for categories without their own entry.
        active_category_thresholds: Optional mapping of category -> cutoff.
            None or empty means ``active_threshold`` applies everywhere.
        min_confidence: Not used by this function; filtering uses only the
            thresholds above.
        selected_categories: Mapping of category -> bool. Categories mapped to
            a falsy value are left out of ``all_tags``; missing categories are
            treated as selected.

    Returns:
        tuple: ``(tags, all_tags)``.
        - ``tags`` maps every category in ``all_probs`` to its kept
          (tag, probability) pairs, in input order.
        - ``all_tags`` is the flat list of kept tag names from selected
          categories.
    """
    per_category_cutoffs = active_category_thresholds or {}

    tags = {}
    all_tags = []
    for category, cat_probs in all_probs.items():
        cutoff = per_category_cutoffs.get(category, active_threshold)
        kept = [(tag, prob) for tag, prob in cat_probs if prob >= cutoff]
        tags[category] = kept

        # Deselected categories still appear in `tags`, just not in the flat list.
        if selected_categories.get(category, True):
            all_tags.extend(tag for tag, _ in kept)

    return tags, all_tags
|
|
|
def image_tagger_app(): |
|
"""Main Streamlit application for image tagging.""" |
|
st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️") |
|
|
|
st.title("Image Tagging Interface") |
|
st.markdown("---") |
|
|
|
|
|
windows_system = is_windows() |
|
flash_attn_installed = check_flash_attention() |
|
|
|
if 'settings' not in st.session_state: |
|
st.session_state.settings = { |
|
'show_all_tags': False, |
|
'compact_view': True, |
|
'min_confidence': 0.01, |
|
'threshold_profile': "Balanced", |
|
'active_threshold': 0.35, |
|
'active_category_thresholds': None, |
|
'selected_categories': {}, |
|
'replace_underscores': False |
|
} |
|
|
|
st.session_state.show_profile_help = False |
|
|
|
|
|
default_threshold_values = { |
|
'overall': 0.35, |
|
'weighted': 0.35, |
|
'category_thresholds': {}, |
|
'high_precision_thresholds': {}, |
|
'high_recall_thresholds': {} |
|
} |
|
|
|
|
|
if 'model_loaded' not in st.session_state: |
|
st.session_state.model_loaded = False |
|
st.session_state.model = None |
|
st.session_state.thresholds = None |
|
st.session_state.metadata = None |
|
|
|
|
|
onnx_model_path = os.path.join(os.path.dirname(MODEL_DIR), "model_initial.onnx") |
|
onnx_metadata_path = os.path.join(os.path.dirname(MODEL_DIR), "model_initial_metadata.json") |
|
onnx_available = os.path.exists(onnx_model_path) and os.path.exists(onnx_metadata_path) |
|
|
|
|
|
if onnx_available: |
|
st.session_state.model_type = "onnx" |
|
else: |
|
st.session_state.model_type = "initial_only" if windows_system else "full" |
|
|
|
|
|
with st.sidebar: |
|
st.header("Model Selection") |
|
|
|
|
|
onnx_model_path = os.path.join(os.path.dirname(MODEL_DIR), "model_initial.onnx") |
|
onnx_metadata_path = os.path.join(os.path.dirname(MODEL_DIR), "model_initial_metadata.json") |
|
onnx_available = os.path.exists(onnx_model_path) and os.path.exists(onnx_metadata_path) |
|
|
|
|
|
model_options = [ |
|
"Refined (Tag Embeddings)", |
|
"Initial (Base Model)" |
|
] |
|
|
|
|
|
if onnx_available: |
|
model_options.append("ONNX Accelerated (Fastest)") |
|
|
|
|
|
if st.session_state.model_type == "onnx" and onnx_available: |
|
default_index = 2 |
|
elif windows_system or st.session_state.model_type == "initial_only": |
|
default_index = 1 |
|
else: |
|
default_index = 0 |
|
|
|
|
|
model_type = st.radio( |
|
"Select Model Type:", |
|
model_options, |
|
index=min(default_index, len(model_options)-1), |
|
help=""" |
|
Full Model: Uses both initial and refined predictions for highest accuracy (requires more VRAM) |
|
Initial Only: Uses only the initial classifier, reducing VRAM usage at slight quality cost |
|
ONNX Accelerated: Optimized for inference speed, best for batch processing (if available) |
|
""" |
|
) |
|
|
|
|
|
if model_type == "Full Model (Best Quality)": |
|
selected_model_type = "full" |
|
elif model_type == "ONNX Accelerated (Fastest)": |
|
selected_model_type = "onnx" |
|
|
|
st.session_state.onnx_model_path = onnx_model_path |
|
st.session_state.onnx_metadata_path = onnx_metadata_path |
|
else: |
|
selected_model_type = "initial_only" |
|
|
|
|
|
if selected_model_type != st.session_state.model_type: |
|
st.session_state.model_loaded = False |
|
st.session_state.model_type = selected_model_type |
|
|
|
|
|
if windows_system and selected_model_type=="full": |
|
st.warning(""" |
|
### Windows Compatibility Note |
|
|
|
The refined model requires Flash Attention which is difficult to install on Windows. |
|
|
|
For Windows users, I recommend using the "Initial Only" or ONNX Accelerated model which: |
|
- Does not require Flash Attention |
|
- Uses less memory |
|
- Provides very close to full prediction quality (check performance notes on HF) |
|
""") |
|
|
|
|
|
if st.button("Reload Model") and st.session_state.model_loaded: |
|
st.session_state.model_loaded = False |
|
st.info("Reloading model...") |
|
|
|
|
|
if not st.session_state.model_loaded: |
|
try: |
|
with st.spinner(f"Loading {st.session_state.model_type} model..."): |
|
if st.session_state.model_type == "onnx": |
|
|
|
import json |
|
import onnxruntime as ort |
|
|
|
try: |
|
|
|
providers = ort.get_available_providers() |
|
gpu_available = any('GPU' in provider for provider in providers) |
|
|
|
|
|
st.session_state.onnx_providers = providers |
|
st.session_state.onnx_gpu_available = gpu_available |
|
|
|
|
|
with open(st.session_state.onnx_metadata_path, 'r') as f: |
|
metadata = json.load(f) |
|
|
|
|
|
thresholds_path = os.path.join(MODEL_DIR, "thresholds.json") |
|
if os.path.exists(thresholds_path): |
|
with open(thresholds_path, 'r') as f: |
|
thresholds = json.load(f) |
|
else: |
|
|
|
if 'thresholds' in metadata: |
|
thresholds = metadata['thresholds'] |
|
else: |
|
|
|
thresholds = { |
|
'overall': {'balanced': {'threshold': 0.35}}, |
|
'weighted': {'f1': {'threshold': 0.35}}, |
|
'categories': {} |
|
} |
|
|
|
|
|
if 'tag_to_category' in metadata: |
|
categories = set(metadata['tag_to_category'].values()) |
|
thresholds['categories'] = { |
|
cat: { |
|
'balanced': {'threshold': 0.35}, |
|
'high_precision': {'threshold': 0.45}, |
|
'high_recall': {'threshold': 0.25} |
|
} for cat in categories |
|
} |
|
|
|
|
|
device = "ONNX Runtime" + (" (GPU)" if gpu_available else " (CPU)") |
|
param_dtype = "float32" |
|
|
|
|
|
st.session_state.model = None |
|
st.session_state.device = device |
|
st.session_state.param_dtype = param_dtype |
|
st.session_state.thresholds = thresholds |
|
st.session_state.metadata = metadata |
|
st.session_state.model_loaded = True |
|
|
|
|
|
categories = list(set(metadata['tag_to_category'].values())) |
|
st.session_state.categories = categories |
|
|
|
|
|
if not st.session_state.settings['selected_categories']: |
|
st.session_state.settings['selected_categories'] = {cat: True for cat in categories} |
|
|
|
except Exception as e: |
|
st.error(f"Error loading ONNX model: {str(e)}") |
|
st.info(f"Make sure the ONNX model and metadata files exist at: {st.session_state.onnx_model_path} and {st.session_state.onnx_metadata_path}") |
|
st.code(traceback.format_exc()) |
|
st.stop() |
|
else: |
|
|
|
model, thresholds, metadata = load_exported_model( |
|
MODEL_DIR, |
|
model_type=st.session_state.model_type |
|
) |
|
|
|
|
|
device = next(model.parameters()).device |
|
param_dtype = next(model.parameters()).dtype |
|
|
|
|
|
st.session_state.model = model |
|
|
|
|
|
|
|
categories = list(set(metadata['tag_to_category'].values())) |
|
|
|
|
|
if not st.session_state.settings['selected_categories']: |
|
st.session_state.settings['selected_categories'] = {cat: True for cat in categories} |
|
|
|
|
|
st.session_state.device = device |
|
st.session_state.param_dtype = param_dtype |
|
st.session_state.thresholds = thresholds |
|
st.session_state.metadata = metadata |
|
st.session_state.model_loaded = True |
|
st.session_state.categories = categories |
|
|
|
|
|
print("Loaded thresholds:", thresholds) |
|
|
|
if "initial" in thresholds and "refined" in thresholds: |
|
|
|
model_type_key = "refined" if st.session_state.model_type == "full" else "initial" |
|
|
|
|
|
if "overall" in thresholds[model_type_key] and "balanced" in thresholds[model_type_key]["overall"]: |
|
default_threshold_values['overall'] = thresholds[model_type_key]["overall"]["balanced"]["threshold"] |
|
|
|
|
|
if "weighted" in thresholds[model_type_key] and "f1" in thresholds[model_type_key]["weighted"]: |
|
default_threshold_values['weighted'] = thresholds[model_type_key]["weighted"]["f1"]["threshold"] |
|
|
|
|
|
if "categories" in thresholds[model_type_key]: |
|
default_threshold_values['category_thresholds'] = { |
|
cat: opt['balanced']['threshold'] |
|
for cat, opt in thresholds[model_type_key]["categories"].items() |
|
} |
|
|
|
|
|
default_threshold_values['high_precision_thresholds'] = { |
|
cat: opt['high_precision']['threshold'] |
|
for cat, opt in thresholds[model_type_key]["categories"].items() |
|
} |
|
|
|
default_threshold_values['high_recall_thresholds'] = { |
|
cat: opt['high_recall']['threshold'] |
|
for cat, opt in thresholds[model_type_key]["categories"].items() |
|
} |
|
else: |
|
|
|
if "overall" in thresholds and "balanced" in thresholds["overall"]: |
|
default_threshold_values['overall'] = thresholds["overall"]["balanced"]["threshold"] |
|
|
|
|
|
if "weighted" in thresholds and "f1" in thresholds["weighted"]: |
|
default_threshold_values['weighted'] = thresholds["weighted"]["f1"]["threshold"] |
|
|
|
|
|
if "categories" in thresholds: |
|
default_threshold_values['category_thresholds'] = { |
|
cat: opt['balanced']['threshold'] |
|
for cat, opt in thresholds["categories"].items() |
|
} |
|
|
|
|
|
default_threshold_values['high_precision_thresholds'] = { |
|
cat: opt['high_precision']['threshold'] |
|
for cat, opt in thresholds["categories"].items() |
|
} |
|
|
|
default_threshold_values['high_recall_thresholds'] = { |
|
cat: opt['high_recall']['threshold'] |
|
for cat, opt in thresholds["categories"].items() |
|
} |
|
|
|
|
|
|
|
st.session_state.default_threshold_values = default_threshold_values |
|
|
|
|
|
if st.session_state.settings['threshold_profile'] == "Overall": |
|
st.session_state.settings['active_threshold'] = default_threshold_values['overall'] |
|
elif st.session_state.settings['threshold_profile'] == "Weighted": |
|
st.session_state.settings['active_threshold'] = default_threshold_values['weighted'] |
|
|
|
except Exception as e: |
|
st.error(f"Error loading model: {str(e)}") |
|
st.info(f"Looking for model in: {os.path.abspath(MODEL_DIR)}") |
|
|
|
|
|
if st.session_state.model_type == "initial_only": |
|
expected_model_paths = [ |
|
os.path.join(MODEL_DIR, "model_initial_only.pt"), |
|
os.path.join(MODEL_DIR, "model_initial.pt") |
|
] |
|
if not any(os.path.exists(p) for p in expected_model_paths): |
|
st.error(f"Initial-only model file not found. Checked: {', '.join(expected_model_paths)}") |
|
st.info("Make sure you've exported both model types.") |
|
else: |
|
expected_model_paths = [ |
|
os.path.join(MODEL_DIR, "model_refined.pt"), |
|
os.path.join(MODEL_DIR, "model.pt"), |
|
os.path.join(MODEL_DIR, "model_full.pt") |
|
] |
|
if not any(os.path.exists(p) for p in expected_model_paths): |
|
st.error(f"Full model file not found. Checked: {', '.join(expected_model_paths)}") |
|
|
|
st.code(traceback.format_exc()) |
|
st.stop() |
|
|
|
with st.sidebar: |
|
st.header("Model Information") |
|
if st.session_state.model_loaded: |
|
|
|
if st.session_state.model_type == "onnx": |
|
st.success("Using ONNX Accelerated Model") |
|
|
|
if hasattr(st.session_state, 'onnx_gpu_available') and st.session_state.onnx_gpu_available: |
|
st.write("Acceleration: GPU available") |
|
else: |
|
st.write("Acceleration: CPU only") |
|
elif st.session_state.model_type == "full": |
|
st.success("Using Full Model (Best Quality)") |
|
|
|
if not flash_attn_installed and is_windows(): |
|
st.warning("Note: Flash Attention not available on Windows") |
|
else: |
|
st.success("Using Initial-Only Model (Lower VRAM)") |
|
|
|
|
|
st.write(f"Device: {st.session_state.device}") |
|
st.write(f"Precision: {st.session_state.param_dtype}") |
|
st.write(f"Total tags: {st.session_state.metadata['total_tags']}") |
|
|
|
|
|
with st.expander("Available Categories"): |
|
for category in sorted(st.session_state.categories): |
|
st.write(f"- {category.capitalize()}") |
|
|
|
|
|
with st.expander("About this app"): |
|
st.write(""" |
|
This app uses a trained image tagging model to analyze and tag images. |
|
|
|
**Model Options**: |
|
- **ONNX Accelerated (Fastest)**: Optimized for inference speed with minimal VRAM usage, ideal for batch processing |
|
- **Refined Model (Tag Embeddings)**: Higher quality predictions using both initial and refined layers (uses more VRAM) |
|
- **Initial Model (Base model)**: Reduced VRAM usage with slightly lower accuracy (good for systems with limited resources) |
|
|
|
**Platform Notes**: |
|
- **Windows Users**: ONNX Accelerated model is recommended for best performance |
|
- **CUDA Support**: GPU acceleration is available for ONNX models if CUDA 12.x and cuDNN are installed |
|
- **Linux Users**: The Refined Model with Flash Attention provides the best quality results |
|
|
|
**Features**: |
|
- Upload or select an image |
|
- Process multiple images in batch mode with customizable batch size |
|
- Choose from different threshold profiles |
|
- Adjust category-specific thresholds |
|
- View predictions organized by category |
|
- Limit results to top N tags within each category |
|
- Save tags to text files in various locations |
|
- Export tags with consistent formatting for external use |
|
- Fast batch processing |
|
|
|
**Threshold profiles**: |
|
- **Micro Optimized**: Optimizes micro-averaged F1 score (best for common tags) |
|
- **Macro Optimized**: Optimizes macro-averaged F1 score (better for rare tags) |
|
- **Balanced**: Provides a good balance of precision and recall |
|
- **High Precision**: Prioritizes accuracy over recall |
|
- **High Recall**: Captures more potential tags but may be less accurate |
|
""") |
|
|
|
with st.sidebar:
    # Sidebar footer: project notes, a "Buy Me A Coffee" link and support details.
    notes_markdown = """
    This tagger was trained on a subset of the available data and for limited epochs due to hardware limitations.

    A more comprehensive model trained on the full 3+ million image dataset and many more epochs would provide:
    - More recent characters and tags.
    - Improved accuracy.

    If you find this tool useful and would like to support future development:
    """

    # Raw HTML/CSS: a glowing, slightly zooming donation button (needs unsafe_allow_html).
    coffee_button_html = """
    <style>
    @keyframes coffee-button-glow {
        0% { box-shadow: 0 0 5px #FFD700; }
        50% { box-shadow: 0 0 15px #FFD700; }
        100% { box-shadow: 0 0 5px #FFD700; }
    }

    .coffee-button {
        display: inline-block;
        animation: coffee-button-glow 2s infinite;
        border-radius: 5px;
        transition: transform 0.3s ease;
    }

    .coffee-button:hover {
        transform: scale(1.05);
    }
    </style>

    <a href="https://buymeacoffee.com/camais" target="_blank" class="coffee-button">
        <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" 
            alt="Buy Me A Coffee" 
            style="height: 45px; width: 162px; border-radius: 5px;" />
    </a>
    """

    support_markdown = """
    Your support helps with:
    - GPU costs for training
    - Storage for larger datasets
    - Development of new features
    - Future projects

    Thank you! 🙏

    Full Details: https://huggingface.co/Camais03/camie-tagger
    """

    st.markdown("---")
    st.subheader("💡 Notes")
    st.markdown(notes_markdown)
    st.markdown(coffee_button_html, unsafe_allow_html=True)
    st.markdown(support_markdown)
|
|
|
with st.sidebar:
    # Sidebar entry point for the companion "Tag Collector" Streamlit app.
    st.markdown("---")
    st.subheader("Try the Tag Collector Game!")
    st.write("Test your tagging skills in our gamified version of the image tagger!")

    if st.button("🎮 Launch Tag Collector Game", type="primary"):
        # Run the game on the other default port so it does not collide with this app.
        # st.get_option reflects the port this server actually uses (CLI flag, config
        # file or env var); the STREAMLIT_SERVER_PORT env var alone is usually unset.
        current_port = str(st.get_option("server.port"))
        game_port = "8502" if current_port == "8501" else "8501"

        game_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tag_collector_game.py")
        if os.path.exists(game_path):
            try:
                # Start Streamlit through the interpreter running this app: it is
                # guaranteed to have streamlit installed, needs no CWD-relative venv
                # lookup, and avoids going through a shell (shell=True was not needed).
                command = [sys.executable, "-m", "streamlit", "run", game_path, "--server.port", game_port]
                if sys.platform == "win32":
                    subprocess.Popen(command, creationflags=subprocess.CREATE_NEW_CONSOLE)
                else:
                    subprocess.Popen(command)

                # NOTE(review): this opens a browser on the machine running the server,
                # which only makes sense for local use.
                game_url = f"http://localhost:{game_port}"
                webbrowser.open(game_url)
                st.success("Launching Tag Collector Game!")
            except Exception as e:
                st.error(f"Failed to launch game: {str(e)}")
        else:
            st.error(f"Game file not found: {game_path}")
            st.info("Make sure tag_collector_game.py is in the same directory as this app.")
|
|
|
|
|
# Page layout: image input on the left, tagging controls on the right.
col1, col2 = st.columns([1, 1.5])

with col1:
    st.header("Image")

    upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])

    # Path of the single uploaded image on disk; stays None when nothing is uploaded.
    image_path = None

    with upload_tab:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_file:
            import hashlib

            # Streamlit reruns this whole script on every interaction. Writing a new
            # NamedTemporaryFile(delete=False) on each rerun leaked one file per rerun,
            # so keep one temp copy per upload (keyed by content hash) and delete the
            # previous copy when a different image is uploaded.
            upload_bytes = uploaded_file.getvalue()
            upload_digest = hashlib.sha256(upload_bytes).hexdigest()
            cached_upload = st.session_state.get("_uploaded_image_tmp")

            if cached_upload and cached_upload[0] == upload_digest and os.path.exists(cached_upload[1]):
                image_path = cached_upload[1]
            else:
                if cached_upload:
                    try:
                        os.remove(cached_upload[1])
                    except OSError:
                        pass  # best effort: the stale copy may already be gone or locked
                # Keep the upload's real extension so the suffix matches the content
                # (previously every upload, including PNGs, was saved as ".jpg").
                suffix = os.path.splitext(uploaded_file.name)[1].lower() or ".jpg"
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                    tmp_file.write(upload_bytes)
                    image_path = tmp_file.name
                st.session_state["_uploaded_image_tmp"] = (upload_digest, image_path)

            st.session_state.original_filename = uploaded_file.name

            image = Image.open(uploaded_file)
            st.image(image, use_container_width=True)

    with batch_tab:
        st.subheader("Batch Process Images")
        st.write("Process multiple images from a folder and save tags to text files.")

        batch_folder = st.text_input("Enter folder path containing images:", "")
        if st.button("Browse Folder..."):
            # No folder picker is implemented; direct the user to the text box.
            st.info("Please type the folder path manually in the text input above.")

        save_options = st.radio(
            "Where to save tag files:",
            ["Same folder as images", "Custom location", "Default save folder"],
            index=0,
        )

        st.subheader("Performance Options")
        batch_size_text = st.text_input(
            "Batch size (images processed at once)",
            value="4",
            help="Higher values may improve processing speed but use more memory. Recommended: 4-16",
        )

        # Parse leniently: invalid input falls back to a usable batch size with a warning.
        try:
            batch_size = int(batch_size_text)
            if batch_size < 1:
                st.warning("Batch size must be at least 1. Using batch size of 1.")
                batch_size = 1
        except ValueError:
            st.warning("Please enter a valid number for batch size. Using default batch size of 4.")
            batch_size = 4

        if batch_size > 8:
            st.info(f"Using larger batch size ({batch_size}). If you encounter memory issues, try reducing this value.")

        st.write("Set tag limits per category for batch processing:")
        enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)

        # Per-category limits: -1 = unlimited, 0 = exclude the category, N > 0 = top N tags.
        if 'category_limits' not in st.session_state:
            st.session_state.category_limits = {}

        if enable_category_limits:
            limit_cols = st.columns(2)

            st.markdown("""
            **Limit Values:**
            * **-1** = No limit (include all tags)
            * **0** = Exclude category entirely
            * **N** (positive number) = Include only top N tags
            """)

            if hasattr(st.session_state, 'categories'):
                for i, category in enumerate(sorted(st.session_state.categories)):
                    with limit_cols[i % 2]:
                        current_limit = st.session_state.category_limits.get(category, -1)
                        limit_text = st.text_input(
                            f"{category.capitalize()} (top N):",
                            value=str(current_limit),
                            key=f"limit_{category}",
                            help="-1 = no limit, 0 = exclude, N = top N tags",
                        )

                        try:
                            limit = int(limit_text)
                            if limit < -1:
                                st.warning(f"Limit for {category} must be -1 or greater. Using -1 (unlimited).")
                                limit = -1
                        except ValueError:
                            st.warning(f"Invalid limit for {category}. Using -1 (unlimited).")
                            limit = -1

                        if limit == -1:
                            st.caption(f"✅ Including all {category} tags")
                        elif limit == 0:
                            st.caption(f"❌ Excluding all {category} tags")
                        else:
                            st.caption(f"⚙️ Including top {limit} {category} tags")

                        st.session_state.category_limits[category] = limit
            else:
                st.info("Categories will be available after loading a model.")
        else:
            # Limits disabled: discard any previously entered values.
            st.session_state.category_limits = {}

        custom_save_dir = None
        if save_options == "Custom location":
            if 'custom_folders' not in st.session_state:
                st.session_state.custom_folders = get_default_save_locations()

            custom_save_dir = st.selectbox(
                "Select save location:",
                options=st.session_state.custom_folders,
                format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x,
            )

            new_folder = st.text_input("Or enter a new folder path:", key="batch_new_folder")
            if st.button("Add Folder", key="batch_add_folder") and new_folder:
                if os.path.isdir(new_folder):
                    if new_folder not in st.session_state.custom_folders:
                        st.session_state.custom_folders.append(new_folder)
                        st.success(f"Added folder: {new_folder}")
                        st.rerun()
                    else:
                        st.info("This folder is already in the list.")
                else:
                    # Keep the try narrow: st.rerun() must not sit inside it.
                    try:
                        os.makedirs(new_folder, exist_ok=True)
                    except (OSError, ValueError) as e:
                        st.error(f"Could not create folder: {str(e)}")
                    else:
                        st.session_state.custom_folders.append(new_folder)
                        st.success(f"Created and added folder: {new_folder}")
                        st.rerun()

        if batch_folder and os.path.isdir(batch_folder):
            image_extensions = ['*.jpg', '*.jpeg', '*.png']
            image_files = []
            for ext in image_extensions:
                image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
                image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))

            # Windows pattern matching is case-insensitive, so "*.jpg" and "*.JPG"
            # return the same files; drop the duplicates while keeping order.
            if os.name == 'nt':
                unique_paths = set()
                unique_files = []
                for file_path in image_files:
                    normalized_path = os.path.normpath(file_path).lower()
                    if normalized_path not in unique_paths:
                        unique_paths.add(normalized_path)
                        unique_files.append(file_path)
                image_files = unique_files

            total_images = len(image_files)
            st.write(f"Found {total_images} image files in the folder.")

            if image_files:
                st.write("Sample images:")
                num_preview = min(8, len(image_files))
                thumbnail_cols = st.columns(4)
                for i, img_path in enumerate(image_files[:num_preview]):
                    with thumbnail_cols[i % 4]:
                        try:
                            # Close each preview promptly; open handles keep files locked on Windows.
                            with Image.open(img_path) as img:
                                st.image(img, width=80, caption=os.path.basename(img_path))
                        except Exception:
                            # Not a bare except, which would also swallow
                            # BaseException subclasses such as KeyboardInterrupt.
                            st.write(f"Error loading {os.path.basename(img_path)}")

                if save_options == "Same folder as images":
                    save_dir = batch_folder
                elif save_options == "Custom location":
                    save_dir = custom_save_dir
                else:
                    app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                    save_dir = os.path.join(app_dir, "saved_tags")
                    os.makedirs(save_dir, exist_ok=True)

                st.markdown("---")
                process_col1, process_col2 = st.columns([3, 1])
                with process_col1:
                    st.write(f"Ready to process {total_images} images")
                    st.write(f"Tags will be saved to: **{save_dir}**")

                with process_col2:
                    process_button_disabled = not st.session_state.model_loaded

                    if st.button("🔄 Process All Images",
                                 key="process_batch_btn",
                                 use_container_width=True,
                                 disabled=process_button_disabled,
                                 type="primary"):
                        if not st.session_state.model_loaded:
                            st.error("Model not loaded. Please check the model settings.")
                        else:
                            with st.spinner("Processing images..."):
                                progress_bar = st.progress(0)
                                status_text = st.empty()

                                def update_progress(current, total, image_path):
                                    """Reflect batch progress in the progress bar and status line."""
                                    if total > 0:
                                        progress = min(current / total, 1.0)
                                        progress_bar.progress(progress)
                                    if image_path:
                                        status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path)}")
                                    else:
                                        status_text.text(f"Completed processing {current}/{total} images")

                                # Snapshot the current threshold settings for this run.
                                curr_threshold_profile = st.session_state.settings['threshold_profile']
                                curr_active_threshold = st.session_state.settings['active_threshold']
                                curr_active_category_thresholds = st.session_state.settings['active_category_thresholds']
                                curr_min_confidence = st.session_state.settings['min_confidence']

                                curr_category_limits = None
                                if 'category_limits' in st.session_state and enable_category_limits:
                                    curr_category_limits = st.session_state.category_limits

                                if curr_category_limits:
                                    st.write("Category limit settings:")

                                    excluded = []
                                    limited = []
                                    unlimited = []
                                    for cat, limit in sorted(curr_category_limits.items()):
                                        if limit == 0:
                                            excluded.append(cat)
                                        elif limit > 0:
                                            limited.append(f"{cat}: top {limit}")
                                        else:
                                            unlimited.append(cat)

                                    if excluded:
                                        st.write("❌ Excluded categories: " + ", ".join(excluded))
                                    if limited:
                                        st.write("⚙️ Limited categories: " + ", ".join(limited))
                                    if unlimited:
                                        st.write("✅ Unlimited categories: " + ", ".join(unlimited))
                                    if not excluded and not limited:
                                        st.write("No limits set (all categories included)")

                                # Both back ends now receive the same category-limit value.
                                if st.session_state.model_type == "onnx":
                                    batch_results = batch_process_images_onnx(
                                        folder_path=batch_folder,
                                        model_path=st.session_state.onnx_model_path,
                                        metadata_path=st.session_state.onnx_metadata_path,
                                        threshold_profile=curr_threshold_profile,
                                        active_threshold=curr_active_threshold,
                                        active_category_thresholds=curr_active_category_thresholds,
                                        save_dir=save_dir,
                                        progress_callback=update_progress,
                                        min_confidence=curr_min_confidence,
                                        batch_size=batch_size,
                                        category_limits=curr_category_limits,
                                    )
                                else:
                                    batch_results = batch_process_images(
                                        folder_path=batch_folder,
                                        model=st.session_state.model,
                                        thresholds=st.session_state.thresholds,
                                        metadata=st.session_state.metadata,
                                        threshold_profile=curr_threshold_profile,
                                        active_threshold=curr_active_threshold,
                                        active_category_thresholds=curr_active_category_thresholds,
                                        save_dir=save_dir,
                                        progress_callback=update_progress,
                                        min_confidence=curr_min_confidence,
                                        batch_size=batch_size,
                                        category_limits=curr_category_limits,
                                    )

                                display_batch_results(batch_results)

                if not st.session_state.model_loaded:
                    st.warning("Please load a model before processing images.")
            else:
                st.warning("No image files found in the selected folder.")
        elif batch_folder:
            st.error(f"Folder not found: {batch_folder}")
|
|
|
|
|
with col2: |
|
st.header("Tagging Controls")

# Every selectable threshold profile, in display order.
all_profiles = [
    "Micro Optimized",
    "Macro Optimized",
    "Balanced",
    "High Precision",
    "High Recall",
    "Overall",
    "Weighted",
    "Category-specific",
]

# Preselect the profile saved in settings when it is one of the known profiles;
# otherwise default to "Balanced" (index 2). A membership test covers every
# known name, so no per-name fallbacks are needed.
default_index = 2
if "threshold_profile" in st.session_state.settings:
    existing_profile = st.session_state.settings['threshold_profile']
    if existing_profile in all_profiles:
        default_index = all_profiles.index(existing_profile)

profile_col1, profile_col2 = st.columns([3, 1])

with profile_col1:
    threshold_profile = st.selectbox(
        "Select threshold profile",
        options=all_profiles,
        index=default_index,
        key="threshold_profile",
        on_change=on_threshold_profile_change,
    )

with profile_col2:
    # Toggle between the short description and the long explanation of the profile.
    if st.button("ℹ️ Help", key="profile_help"):
        st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)

if st.session_state.get('show_profile_help', False):
    st.markdown(threshold_profile_explanations[threshold_profile])
else:
    st.info(threshold_profile_descriptions[threshold_profile])
|
|
|
|
|
# Show the stored evaluation metrics of the selected profile for the loaded model head.
if st.session_state.model_loaded:
    model_type = "refined" if st.session_state.model_type == "full" else "initial"
    metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile, model_type)

    if metrics:
        threshold_column, micro_column, macro_column = st.columns(3)

        with threshold_column:
            shown_threshold = metrics.get("threshold", 0.35)
            st.metric("Threshold", f"{shown_threshold:.3f}")

        with micro_column:
            # Older metric files only carry micro_precision; use it as the fallback.
            micro_score = metrics.get("micro_f1", metrics.get("micro_precision", 0))
            st.metric("Micro F1", f"{micro_score:.3f}" if micro_score else "N/A")

            precision_score = metrics.get("precision", metrics.get("micro_precision", 0))
            if precision_score:
                st.metric("Precision", f"{precision_score:.3f}")

        with macro_column:
            macro_score = metrics.get("macro_f1", 0)
            st.metric("Macro F1", f"{macro_score:.3f}" if macro_score else "N/A")

            recall_score = metrics.get("recall", metrics.get("micro_recall", 0))
            if recall_score:
                st.metric("Recall", f"{recall_score:.3f}")
|
|
|
|
|
# Resolve the thresholds for the selected profile. Results land in `active_threshold`
# (overall) and `active_category_thresholds` (per category) and are persisted to
# st.session_state.settings at the end.
active_threshold = None
active_category_thresholds = {}

# Used only when the loaded thresholds lack the overall entry a profile needs;
# that case previously crashed with float(None) when building the slider.
fallback_overall_threshold = 0.35


def _lookup_threshold(node, path, default):
    """Return node[path[0]]...[path[-1]]["threshold"], or *default* if any key on *path* is missing."""
    for part in path:
        if part not in node:
            return default
        node = node[part]
    return node["threshold"]


def _overall_threshold(source, key):
    """Return source["overall"][key]["threshold"], warning and using the fallback when absent."""
    value = _lookup_threshold(source, ("overall", key), None)
    if value is None:
        st.warning(
            f"No overall '{key}' threshold found for the loaded model; "
            f"using {fallback_overall_threshold:.2f}."
        )
        value = fallback_overall_threshold
    return value


# The thresholds dict is either split per model head ("initial"/"refined") or flat.
# threshold_source stays None while no model is loaded; then no profile is resolved.
# Previously the pre-optimized profiles ran even without a model and hit an unbound
# model_type_key (NameError).
threshold_source = None
if st.session_state.model_loaded:
    loaded_thresholds = st.session_state.thresholds
    if "initial" in loaded_thresholds and "refined" in loaded_thresholds:
        model_type_key = "refined" if st.session_state.model_type == "full" else "initial"
        threshold_source = loaded_thresholds[model_type_key]
    else:
        model_type_key = None
        threshold_source = loaded_thresholds

# Pre-optimized profiles map to keys stored in the thresholds data; None otherwise.
profile_key = {
    "Micro Optimized": "micro_opt",
    "Macro Optimized": "macro_opt",
    "Balanced": "balanced",
    "High Precision": "high_precision",
    "High Recall": "high_recall",
}.get(threshold_profile)

if threshold_source is not None and profile_key:
    # Pre-optimized profile: overall and per-category thresholds come from the data;
    # categories without an entry inherit the overall threshold.
    active_threshold = _overall_threshold(threshold_source, profile_key)
    for category in st.session_state.categories:
        active_category_thresholds[category] = _lookup_threshold(
            threshold_source, ("categories", category, profile_key), active_threshold
        )

    st.info(f"The '{threshold_profile}' profile uses pre-optimized thresholds.")

    st.slider(
        "Overall threshold (reference)",
        min_value=0.01,
        max_value=1.0,
        value=float(active_threshold),
        step=0.01,
        disabled=True,
    )

elif threshold_source is not None and threshold_profile == "Overall":
    # Single user-adjustable threshold, seeded from the balanced overall value.
    active_threshold = st.slider(
        "Overall threshold",
        min_value=0.01,
        max_value=1.0,
        value=float(_overall_threshold(threshold_source, "balanced")),
        step=0.01,
    )

elif threshold_source is not None and threshold_profile == "Weighted":
    active_threshold = _overall_threshold(threshold_source, "balanced")

    st.slider(
        "Overall threshold (reference)",
        min_value=0.01,
        max_value=1.0,
        value=float(active_threshold),
        step=0.01,
        disabled=True,
    )

    st.info("The 'Weighted' profile uses different optimized thresholds for each category.")

    # Weighted values are stored as plain numbers per category; categories missing
    # from them fall back to their balanced threshold, then to the overall one.
    if "weighted" in threshold_source:
        weighted_thresholds = threshold_source["weighted"]
        for category in st.session_state.categories:
            if category in weighted_thresholds:
                active_category_thresholds[category] = weighted_thresholds[category]
            else:
                active_category_thresholds[category] = _lookup_threshold(
                    threshold_source, ("categories", category, "balanced"), active_threshold
                )

elif threshold_source is not None and threshold_profile == "Category-specific":
    active_threshold = _overall_threshold(threshold_source, "balanced")

    st.slider(
        "Overall threshold (reference)",
        min_value=0.01,
        max_value=1.0,
        value=float(active_threshold),
        step=0.01,
        disabled=True,
    )

    st.write("Adjust thresholds for individual categories:")

    # One slider per category, laid out over two columns and seeded from the
    # balanced per-category threshold (or the overall one when absent).
    slider_cols = st.columns(2)
    for i, category in enumerate(sorted(st.session_state.categories)):
        category_threshold = _lookup_threshold(
            threshold_source, ("categories", category, "balanced"), active_threshold
        )
        with slider_cols[i % 2]:
            active_category_thresholds[category] = st.slider(
                f"{category.capitalize()}",
                min_value=0.01,
                max_value=1.0,
                value=float(category_threshold),
                step=0.01,
                key=f"slider_{category}",
            )

# Persist whatever was resolved so batch processing uses the same thresholds.
if active_threshold is not None:
    st.session_state.settings['active_threshold'] = active_threshold
if active_category_thresholds:
    st.session_state.settings['active_category_thresholds'] = active_category_thresholds
|
|
|
|
|
# Collapsible explanation of the evaluation metrics behind the threshold profiles.
with st.expander("Threshold Profile Details"):
    if not st.session_state.model_loaded:
        st.info("Load a model to see detailed threshold information.")
    else:
        (about_metrics_tab,) = st.tabs(["About Metrics"])

        with about_metrics_tab:
            f1_explanation = """
            ### Understanding Performance Metrics

            **F1 Score** is the harmonic mean of precision and recall: `2 * (precision * recall) / (precision + recall)`

            **Micro F1** calculates metrics globally by considering each example/prediction pair. This gives more weight to categories with more examples.

            **Macro F1** calculates F1 separately for each category and then takes the average. This treats all categories equally regardless of their size.
            """

            tradeoff_explanation = """
            ### Other Metrics

            **Precision** measures how many of the predicted tags are correct: `true_positives / (true_positives + false_positives)`

            **Recall** measures how many of the relevant tags are captured: `true_positives / (true_positives + false_negatives)`

            ### The Precision-Recall Tradeoff

            There's an inherent tradeoff between precision and recall:
            - Higher threshold → Higher precision, Lower recall
            - Lower threshold → Lower precision, Higher recall

            The best threshold depends on your specific use case:
            - **Prefer Precision**: When false positives are costly (e.g., you want only accurate tags)
            - **Prefer Recall**: When false negatives are costly (e.g., you don't want to miss any potentially relevant tags)
            - **Balanced**: When both types of errors are equally important
            """

            st.markdown(f1_explanation)
            st.markdown(create_micro_macro_comparison(), unsafe_allow_html=True)
            st.markdown(tradeoff_explanation)
|
|
|
|
|
# Display preferences, persisted into st.session_state.settings on every rerun.
display_options = st.expander("Display Options", expanded=False)
with display_options:
    # Dedicated names so the page-layout columns (col1/col2) are not shadowed.
    options_left, options_right = st.columns(2)
    with options_left:
        show_all_tags = st.checkbox("Show all tags (including below threshold)",
                                    value=st.session_state.settings['show_all_tags'])
        compact_view = st.checkbox("Compact view (hide progress bars)",
                                   value=st.session_state.settings['compact_view'])
        replace_underscores = st.checkbox("Replace underscores with spaces",
                                          value=st.session_state.settings.get('replace_underscores', False))

    with options_right:
        min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
                                   st.session_state.settings['min_confidence'], 0.01)

    st.session_state.settings['show_all_tags'] = show_all_tags
    st.session_state.settings['compact_view'] = compact_view
    st.session_state.settings['min_confidence'] = min_confidence
    st.session_state.settings['replace_underscores'] = replace_underscores

    st.write("Categories to include in 'All Tags' section:")

    category_cols = st.columns(3)
    selected_categories = {}

    # Only overwrite the saved selection when categories are known; otherwise an
    # empty dict would wipe the user's choices until a model is loaded again.
    if hasattr(st.session_state, 'categories'):
        for i, category in enumerate(sorted(st.session_state.categories)):
            with category_cols[i % 3]:
                default_val = st.session_state.settings['selected_categories'].get(category, True)
                selected_categories[category] = st.checkbox(
                    f"{category.capitalize()}",
                    value=default_val,
                    key=f"cat_select_{category}",
                )

        st.session_state.settings['selected_categories'] = selected_categories
|
|
|
# Tell the user which model variant is active; anything other than the ONNX or
# full model is reported as the initial-only model.
if st.session_state.model_loaded:
    model_type_display = {
        "onnx": "ONNX Accelerated Model",
        "full": "Full Model",
    }.get(st.session_state.model_type, "Initial-Only Model (Lower VRAM)")

    st.info(f"Using: {model_type_display}")
|
|
|
|
|
# Single-image inference, triggered by the "Run Tagging" button (only shown once an
# image is uploaded). Successful results are cached in session state for display.
run_requested = bool(image_path) and st.button("Run Tagging")

if run_requested and not st.session_state.model_loaded:
    st.error("Model not loaded. Please check the model settings.")
elif run_requested:
    with st.spinner("Analyzing image..."):
        try:
            inference_start = time.time()

            if st.session_state.model_type != "onnx":
                result = process_image(
                    image_path=image_path,
                    model=st.session_state.model,
                    thresholds=st.session_state.thresholds,
                    metadata=st.session_state.metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence,
                )
            else:
                # Imported lazily so the ONNX helper is only needed when that back end is used.
                from utils.onnx_processing import process_single_image_onnx

                result = process_single_image_onnx(
                    image_path=image_path,
                    model_path=st.session_state.onnx_model_path,
                    metadata=st.session_state.metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence,
                )

            inference_time = time.time() - inference_start

            if not result['success']:
                st.error(f"Inference failed: {result.get('error', 'Unknown error')}")
            else:
                st.session_state.all_probs = result['all_probs']
                st.session_state.tags = result['tags']
                st.session_state.all_tags = result['all_tags']
                st.success(f"Analysis completed in {inference_time:.2f} seconds")

        except Exception as e:
            st.error(f"Inference error: {str(e)}")
            st.code(traceback.format_exc())
|
|
|
|
|
if image_path and hasattr(st.session_state, 'all_probs'): |
|
st.header("Predictions") |
|
|
|
|
|
filtered_tags, current_all_tags = apply_thresholds( |
|
st.session_state.all_probs, |
|
threshold_profile, |
|
active_threshold, |
|
active_category_thresholds, |
|
min_confidence, |
|
st.session_state.settings['selected_categories'] |
|
) |
|
|
|
|
|
st.session_state.tags = filtered_tags |
|
st.session_state.all_tags = current_all_tags |
|
|
|
|
|
|
|
all_tags = [] |
|
|
|
for category in sorted(st.session_state.all_probs.keys()): |
|
|
|
all_tags_in_category = st.session_state.all_probs.get(category, []) |
|
filtered_tags_in_category = filtered_tags.get(category, []) |
|
|
|
|
|
if all_tags_in_category: |
|
|
|
if threshold_profile in ["Overall", "Weighted"]: |
|
threshold = active_threshold |
|
else: |
|
threshold = active_category_thresholds.get(category, active_threshold) |
|
|
|
|
|
expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)" |
|
|
|
with st.expander(expander_label, expanded=True): |
|
|
|
threshold_row = st.columns([1, 2]) |
|
with threshold_row[0]: |
|
|
|
pass |
|
|
|
with threshold_row[1]: |
|
|
|
cat_slider_key = f"cat_threshold_{category}" |
|
|
|
|
|
if threshold_profile in ["Overall", "Weighted"]: |
|
|
|
current_cat_threshold = active_threshold |
|
else: |
|
|
|
current_cat_threshold = active_category_thresholds.get(category, active_threshold) |
|
|
|
|
|
new_threshold = st.slider( |
|
f"Adjust {category.capitalize()} threshold:", |
|
min_value=0.01, |
|
max_value=1.0, |
|
value=float(current_cat_threshold), |
|
step=0.01, |
|
key=cat_slider_key, |
|
disabled=(threshold_profile in ["Overall", "Weighted"]) |
|
) |
|
|
|
|
|
if threshold_profile in ["Overall", "Weighted"]: |
|
st.info(f"Using global {threshold_profile.lower()} threshold. Switch to Category-specific mode to adjust individual categories.") |
|
else: |
|
|
|
active_category_thresholds[category] = new_threshold |
|
threshold = new_threshold |
|
|
|
|
|
active_category_thresholds[category] = new_threshold |
|
|
|
|
|
|
|
threshold = new_threshold |
|
|
|
|
|
if show_all_tags: |
|
tags_to_display = all_tags_in_category |
|
else: |
|
|
|
|
|
tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold] |
|
filtered_tags[category] = tags_to_display |
|
|
|
|
|
limit_col1, limit_col2 = st.columns([1, 2]) |
|
with limit_col1: |
|
|
|
limit_key = f"limit_{category}_tags" |
|
limit_tags_for_category = st.checkbox("Limit tags", value=False, key=limit_key) |
|
|
|
with limit_col2: |
|
|
|
slider_key = f"top_n_{category}_tags" |
|
|
|
|
|
tag_count = len(tags_to_display) |
|
|
|
|
|
if tag_count > 0: |
|
|
|
min_value = 0 |
|
|
|
max_value = max(1, min(999, tag_count)) |
|
|
|
default_value = min(max_value, 5) |
|
|
|
top_n_tags_for_category = st.slider( |
|
"Show top", |
|
min_value=min_value, |
|
max_value=max_value, |
|
value=default_value, |
|
step=1, |
|
disabled=not limit_tags_for_category, |
|
key=slider_key |
|
) |
|
else: |
|
|
|
top_n_tags_for_category = 5 |
|
st.write("No tags to display") |
|
|
|
st.markdown("---") |
|
|
|
if not tags_to_display: |
|
st.info(f"No tags above {min_confidence:.2f} confidence threshold") |
|
continue |
|
|
|
|
|
original_count = len(tags_to_display) |
|
if limit_tags_for_category and tags_to_display: |
|
limited_tags_to_display = tags_to_display[:top_n_tags_for_category] |
|
display_count = len(limited_tags_to_display) |
|
else: |
|
limited_tags_to_display = tags_to_display |
|
display_count = original_count |
|
|
|
|
|
if compact_view: |
|
|
|
tag_list = [] |
|
|
|
replace_underscores = st.session_state.settings.get('replace_underscores', False) |
|
for tag, prob in limited_tags_to_display: |
|
|
|
percentage = int(prob * 100) |
|
|
|
|
|
display_tag = tag.replace('_', ' ') if replace_underscores else tag |
|
tag_list.append(f"{display_tag} ({percentage}%)") |
|
|
|
|
|
|
|
if prob >= threshold and selected_categories.get(category, True): |
|
all_tags.append(tag) |
|
|
|
|
|
st.markdown(", ".join(tag_list)) |
|
else: |
|
|
|
for tag, prob in limited_tags_to_display: |
|
|
|
replace_underscores = st.session_state.settings.get('replace_underscores', False) |
|
|
|
|
|
display_tag = tag.replace('_', ' ') if replace_underscores else tag |
|
|
|
|
|
if prob >= threshold and selected_categories.get(category, True): |
|
all_tags.append(tag) |
|
tag_display = f"**{display_tag}**" |
|
else: |
|
tag_display = display_tag |
|
|
|
|
|
st.write(tag_display) |
|
st.markdown(display_progress_bar(prob), unsafe_allow_html=True) |
|
|
|
|
|
if limit_tags_for_category and original_count > display_count: |
|
st.caption(f"Showing top {display_count} of {original_count} qualifying tags.") |
|
|
|
|
|
st.markdown("---") |
|
st.subheader(f"All Tags ({len(all_tags)} total)") |
|
if all_tags: |
|
|
|
replace_underscores = st.session_state.settings.get('replace_underscores', False) |
|
|
|
if replace_underscores: |
|
|
|
display_tags = [tag.replace('_', ' ') for tag in all_tags] |
|
st.write(", ".join(display_tags)) |
|
else: |
|
|
|
st.write(", ".join(all_tags)) |
|
else: |
|
st.info("No tags detected above the threshold.") |
|
|
|
|
|
st.markdown("---") |
|
st.subheader("Save Tags") |
|
|
|
|
|
save_col = st.columns(1)[0] |
|
|
|
with save_col: |
|
|
|
if 'custom_folders' not in st.session_state: |
|
|
|
st.session_state.custom_folders = get_default_save_locations() |
|
|
|
|
|
selected_folder = st.selectbox( |
|
"Select save location:", |
|
options=st.session_state.custom_folders, |
|
format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x |
|
) |
|
|
|
|
|
new_folder = st.text_input("Or enter a new folder path:") |
|
|
|
if st.button("Add Folder", key="add_folder") and new_folder: |
|
if os.path.isdir(new_folder): |
|
if new_folder not in st.session_state.custom_folders: |
|
st.session_state.custom_folders.append(new_folder) |
|
st.success(f"Added folder: {new_folder}") |
|
st.rerun() |
|
else: |
|
st.info("This folder is already in the list.") |
|
else: |
|
try: |
|
|
|
os.makedirs(new_folder, exist_ok=True) |
|
st.session_state.custom_folders.append(new_folder) |
|
st.success(f"Created and added folder: {new_folder}") |
|
st.rerun() |
|
except Exception as e: |
|
st.error(f"Could not create folder: {str(e)}") |
|
|
|
|
|
if st.button("💾 Save to Selected Location"): |
|
try: |
|
|
|
original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None |
|
|
|
|
|
saved_path = save_tags_to_file( |
|
image_path=image_path, |
|
all_tags=all_tags, |
|
original_filename=original_filename, |
|
custom_dir=selected_folder, |
|
overwrite=True |
|
) |
|
|
|
st.success(f"Tags saved to: {os.path.basename(saved_path)}") |
|
st.info(f"Full path: {saved_path}") |
|
|
|
|
|
with st.expander("File Contents", expanded=True): |
|
with open(saved_path, 'r', encoding='utf-8') as f: |
|
content = f.read() |
|
st.code(content, language='text') |
|
|
|
except Exception as e: |
|
st.error(f"Error saving tags: {str(e)}") |
|
st.code(traceback.format_exc()) |
|
|
|
# Script entry point: build the tagger UI only when this file is executed
# directly; importing the module does not trigger the call below.
if __name__ == "__main__":
    image_tagger_app()