"""
Model loading utilities for Image Tagger application.
"""
import os
import json
import torch
import platform
import traceback
import importlib.util


def is_windows():
    """Check if the system is Windows."""
    return platform.system() == "Windows"

class DummyDataset:
    """Minimal dataset class for inference."""

    def __init__(self, metadata):
        self.total_tags = metadata['total_tags']
        self.idx_to_tag = {int(k): v for k, v in metadata['idx_to_tag'].items()}
        self.tag_to_category = metadata['tag_to_category']

    def get_tag_info(self, idx):
        tag = self.idx_to_tag.get(idx, f"unknown_{idx}")
        category = self.tag_to_category.get(tag, "general")
        return tag, category

def load_model_code(model_dir):
    """
    Load the model code module from the model directory.

    Args:
        model_dir: Path to the model directory

    Returns:
        Imported model code module
    """
    model_code_path = os.path.join(model_dir, "model_code.py")
    if not os.path.exists(model_code_path):
        raise FileNotFoundError(f"model_code.py not found at {model_code_path}")

    # Import the model code dynamically
    spec = importlib.util.spec_from_file_location("model_code", model_code_path)
    model_code = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_code)

    # Check that the required classes exist
    if not hasattr(model_code, 'ImageTagger') or not hasattr(model_code, 'FlashAttention'):
        raise ImportError("Required classes not found in model_code.py")

    return model_code

def check_flash_attention():
    """
    Check if Flash Attention is properly installed.

    Returns:
        bool: True if Flash Attention is available and working
    """
    try:
        import flash_attn
        if hasattr(flash_attn, 'flash_attn_func'):
            # A fallback shim can expose the same symbol, so verify that the
            # function comes from the real flash_attn package.
            module_path = flash_attn.flash_attn_func.__module__
            return 'flash_attn_fallback' not in module_path
    except Exception:
        pass
    return False

def estimate_model_memory_usage(model, device):
    """Estimate the memory usage of a model in MB (parameters plus buffers)."""
    mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
    mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
    mem_total = mem_params + mem_bufs  # in bytes
    return mem_total / (1024 * 1024)  # convert to MB

def load_exported_model(model_dir, model_type="full"):
    """
    Load the exported model and metadata with the correct precision.

    Args:
        model_dir: Directory containing the model files
        model_type: "full" or "initial_only"

    Returns:
        model, thresholds, metadata
    """
    print(f"Loading {model_type} model from {model_dir}")

    # Make sure we have the absolute path to the model directory
    model_dir = os.path.abspath(model_dir)
    print(f"Absolute model path: {model_dir}")

    # Paths for the metadata and threshold files
    metadata_path = os.path.join(model_dir, "metadata.json")
    thresholds_path = os.path.join(model_dir, "thresholds.json")
    print(f"Looking for thresholds at: {thresholds_path}")

    # Check platform and Flash Attention status
    windows_system = is_windows()
    flash_attn_installed = check_flash_attention()
    # Warn Windows users trying to use the full model without Flash Attention
    if windows_system and model_type == "full" and not flash_attn_installed:
        print("Note: On Windows without Flash Attention, the full model falls back to")
        print("      standard attention, which may produce less accurate results.")
        print("      Consider using the 'initial_only' model for better performance on Windows.")
    # Determine file paths based on model type
    if model_type == "initial_only":
        # Try both naming conventions for the model file
        if os.path.exists(os.path.join(model_dir, "model_initial_only.pt")):
            model_path = os.path.join(model_dir, "model_initial_only.pt")
        else:
            model_path = os.path.join(model_dir, "model_initial.pt")

        # Try both naming conventions for the info file
        if os.path.exists(os.path.join(model_dir, "model_info_initial_only.json")):
            model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
        else:
            model_info_path = os.path.join(model_dir, "model_info_initial.json")
    else:
        # Try multiple naming conventions for the full model
        model_filenames = ["model_refined.pt", "model.pt", "model_full.pt"]
        model_path = None
        for filename in model_filenames:
            path = os.path.join(model_dir, filename)
            if os.path.exists(path):
                model_path = path
                break
        if model_path is None:
            raise FileNotFoundError(f"No model file found in {model_dir}. Looked for: {', '.join(model_filenames)}")
        model_info_path = os.path.join(model_dir, "model_info.json")
    # Check for required files (metadata_path and thresholds_path were set above)
    required_files = [metadata_path, thresholds_path, model_path]
    for file_path in required_files:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Required file {file_path} not found")
    # Load metadata
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    # Load model code
    model_code = load_model_code(model_dir)

    # Create a minimal dataset wrapper for tag lookups
    dummy_dataset = DummyDataset(metadata)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model info
    if os.path.exists(model_info_path):
        with open(model_info_path, 'r') as f:
            model_info = json.load(f)
        print("Loaded model info:", model_info)
        tag_context_size = model_info.get('tag_context_size', 256)
        num_heads = model_info.get('num_heads', 16)
    else:
        print("Model info not found, using defaults")
        tag_context_size = 256
        num_heads = 16
    try:
        # Check if the InitialOnlyImageTagger class exists
        has_initial_only_class = hasattr(model_code, 'InitialOnlyImageTagger')

        # Create the appropriate model type
        if model_type == "initial_only":
            # Create the lightweight model
            if has_initial_only_class:
                model = model_code.InitialOnlyImageTagger(
                    total_tags=metadata['total_tags'],
                    dataset=dummy_dataset,
                    pretrained=False
                )
            else:
                # Fall back to ImageTagger if the initial-only class isn't available
                print("InitialOnlyImageTagger class not found. Using ImageTagger as fallback.")
                model = model_code.ImageTagger(
                    total_tags=metadata['total_tags'],
                    dataset=dummy_dataset,
                    pretrained=False,
                    tag_context_size=tag_context_size,
                    num_heads=num_heads
                )
        else:
            # Create the full model
            model = model_code.ImageTagger(
                total_tags=metadata['total_tags'],
                dataset=dummy_dataset,
                pretrained=False,
                tag_context_size=tag_context_size,
                num_heads=num_heads
            )
        # Load the state dict
        state_dict = torch.load(model_path, map_location=device)

        # Try loading with strict=True first, then fall back to strict=False
        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model loaded with strict=True")
        except Exception as e:
            print(f"Warning: Strict loading failed: {str(e)}")
            print("Attempting to load with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✓ Model loaded with strict=False")

        # Keep the model in half precision to match training conditions
        model = model.to(device=device, dtype=torch.float16)
        model.eval()

        # Report the parameter dtype and estimated memory footprint
        param_dtype = next(model.parameters()).dtype
        print(f"Model loaded successfully on {device} with precision {param_dtype}")
        print(f"Model memory usage: {estimate_model_memory_usage(model, device):.2f} MB")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        raise
    # Load thresholds
    with open(thresholds_path, "r") as f:
        thresholds = json.load(f)

    return model, thresholds, metadata
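

# Minimal usage sketch. The directory path below is hypothetical; point it at a
# real exported-model directory containing metadata.json, thresholds.json, the
# model weights, and model_code.py.
if __name__ == "__main__":
    model_dir = "./exported_model"  # hypothetical path, substitute your own
    model, thresholds, metadata = load_exported_model(model_dir, model_type="initial_only")

    # Smoke test: look up the tag name and category for index 0
    dataset = DummyDataset(metadata)
    print(dataset.get_tag_info(0))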