|
"""
|
|
Model loading utilities for Image Tagger application.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import torch
|
|
import platform
|
|
import traceback
|
|
import importlib.util
|
|
|
|
|
|
def is_windows():
    """Return True when ``platform.system()`` reports "Windows"."""
    system_name = platform.system()
    return system_name == "Windows"
|
|
|
|
|
|
class DummyDataset:
    """Lightweight stand-in for a dataset, holding only tag lookup tables.

    Built from a metadata mapping that provides ``total_tags``,
    ``idx_to_tag`` and ``tag_to_category``.
    """

    def __init__(self, metadata):
        """Copy the tag tables out of ``metadata``.

        Args:
            metadata: Mapping with the keys ``total_tags``, ``idx_to_tag``
                and ``tag_to_category``.
        """
        self.total_tags = metadata['total_tags']
        # Keys are passed through int() so lookups can use integer indices.
        raw_index = metadata['idx_to_tag']
        self.idx_to_tag = {int(key): name for key, name in raw_index.items()}
        self.tag_to_category = metadata['tag_to_category']

    def get_tag_info(self, idx):
        """Return ``(tag, category)`` for the tag index ``idx``.

        An index with no entry yields the placeholder tag ``unknown_<idx>``;
        a tag with no category entry is reported as "general".
        """
        if idx in self.idx_to_tag:
            tag = self.idx_to_tag[idx]
        else:
            tag = f"unknown_{idx}"
        return tag, self.tag_to_category.get(tag, "general")
|
|
|
|
|
|
def load_model_code(model_dir):
    """
    Load the model code module from the model directory.

    The file ``model_code.py`` inside ``model_dir`` is executed as a module
    named ``model_code``. The module is returned to the caller; this function
    does not add it to ``sys.modules``.

    Args:
        model_dir: Path to the model directory

    Returns:
        Imported model code module

    Raises:
        FileNotFoundError: If ``model_code.py`` does not exist in ``model_dir``.
        ImportError: If no import spec/loader can be created for the file, or
            if the module lacks the ``ImageTagger`` or ``FlashAttention``
            attribute.
    """
    model_code_path = os.path.join(model_dir, "model_code.py")

    if not os.path.exists(model_code_path):
        raise FileNotFoundError(f"model_code.py not found at {model_code_path}")

    spec = importlib.util.spec_from_file_location("model_code", model_code_path)
    # spec_from_file_location may return None, and a spec may lack a loader;
    # fail with a clear ImportError instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {model_code_path}")

    model_code = importlib.util.module_from_spec(spec)
    # NOTE(security): this executes whatever Python code is in model_code.py;
    # only point model_dir at directories you trust.
    spec.loader.exec_module(model_code)

    # Name the missing classes so the error message is actionable.
    required = ("ImageTagger", "FlashAttention")
    missing = [name for name in required if not hasattr(model_code, name)]
    if missing:
        raise ImportError(
            f"Required classes not found in model_code.py: {', '.join(missing)}"
        )

    return model_code
|
|
|
|
|
|
def check_flash_attention():
    """
    Check if Flash Attention is properly installed.

    The check imports ``flash_attn`` and inspects which module defines
    ``flash_attn.flash_attn_func``. A defining module whose name contains
    ``flash_attn_fallback`` is treated as "not available".

    Returns:
        bool: True if Flash Attention is available and working
    """
    try:
        import flash_attn

        if hasattr(flash_attn, 'flash_attn_func'):
            module_path = flash_attn.flash_attn_func.__module__
            return 'flash_attn_fallback' not in module_path
    except Exception:
        # Deliberate best-effort probe: any failure while importing or
        # inspecting flash_attn just means "not available". Catching
        # Exception (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        pass
    return False
|
|
|
|
|
|
def estimate_model_memory_usage(model, device):
    """
    Estimate the memory usage of a model in MB.

    Adds up ``nelement() * element_size()`` over everything yielded by
    ``model.parameters()`` and ``model.buffers()`` and converts the byte
    total to mebibytes.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables.
        device: Accepted for interface compatibility; not used here.

    Returns:
        The byte total divided by 1024 * 1024.
    """
    total_bytes = 0
    for tensor in model.parameters():
        total_bytes += tensor.nelement() * tensor.element_size()
    for tensor in model.buffers():
        total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / (1024 * 1024)
|
|
|
|
|
|
def load_exported_model(model_dir, model_type="full"):
    """
    Load the exported model and metadata with correct precision.

    The model classes come from ``model_code.py`` inside ``model_dir``. The
    weights are loaded from a ``.pt`` file, the model is moved to CUDA when
    available (CPU otherwise), cast to ``torch.float16`` and switched to
    eval mode.

    Args:
        model_dir: Directory containing the model files
        model_type: "full" or "initial_only". Only the exact value
            "initial_only" selects the initial-only files; any other value
            is handled as the full model.

    Returns:
        model, thresholds, metadata

    Raises:
        FileNotFoundError: If no full-model weights file is found, or if the
            metadata, thresholds or selected weights file is missing.
        Exception: Any error raised while building the model or loading its
            weights is printed with a traceback and re-raised unchanged.
    """
    print(f"Loading {model_type} model from {model_dir}")

    model_dir = os.path.abspath(model_dir)
    print(f"Absolute model path: {model_dir}")

    metadata_path = os.path.join(model_dir, "metadata.json")
    thresholds_path = os.path.join(model_dir, "thresholds.json")
    print(f"Looking for thresholds at: {thresholds_path}")

    windows_system = is_windows()
    flash_attn_installed = check_flash_attention()

    if windows_system and model_type == "full" and not flash_attn_installed:
        print("Note: On Windows without Flash Attention, the full model will not work")
        print(" which may produce less accurate results.")
        print(" Consider using the 'initial_only' model for better performance on Windows.")

    model_path, model_info_path = _resolve_model_files(model_dir, model_type)

    for file_path in (metadata_path, thresholds_path, model_path):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Required file {file_path} not found")

    metadata = _read_json(metadata_path)

    model_code = load_model_code(model_dir)

    dummy_dataset = DummyDataset(metadata)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The model info file is optional; fall back to defaults when absent.
    if os.path.exists(model_info_path):
        model_info = _read_json(model_info_path)
        print("Loaded model info:", model_info)
        tag_context_size = model_info.get('tag_context_size', 256)
        num_heads = model_info.get('num_heads', 16)
    else:
        print("Model info not found, using defaults")
        tag_context_size = 256
        num_heads = 16

    try:
        model = _build_model(
            model_code,
            model_type,
            metadata,
            dummy_dataset,
            tag_context_size,
            num_heads,
        )

        # NOTE(security): torch.load unpickles the file; only load weights
        # from trusted sources. The call is intentionally left unchanged --
        # consider passing weights_only=True if the installed torch version
        # supports it and the file holds a plain state dict.
        state_dict = torch.load(model_path, map_location=device)

        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception as e:
            print(f"Warning: Strict loading failed: {str(e)}")
            print("Attempting to load with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✓ Model loaded with strict=False")
        else:
            # Reported from ``else`` so that a failure while printing is not
            # mistaken for a strict-loading failure and retried.
            print("✓ Model loaded with strict=True")

        # NOTE(review): float16 is applied on CPU as well as CUDA -- confirm
        # the model's ops support half precision on CPU.
        model = model.to(device=device, dtype=torch.float16)
        model.eval()

        param_dtype = next(model.parameters()).dtype
        print(f"Model loaded successfully on {device} with precision {param_dtype}")
        print(f"Model memory usage: {estimate_model_memory_usage(model, device):.2f} MB")

    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        raise

    thresholds = _read_json(thresholds_path)

    return model, thresholds, metadata


def _read_json(path):
    """Parse the JSON file at ``path``, decoding it explicitly as UTF-8."""
    # An explicit encoding keeps non-ASCII content readable on platforms whose
    # default text encoding is not UTF-8 (e.g. many Windows locales).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_model_files(model_dir, model_type):
    """Return ``(model_path, model_info_path)`` for the requested model type."""
    if model_type == "initial_only":
        # Prefer the "*_initial_only" names; otherwise use the "*_initial"
        # names (existence of the weights file is verified by the caller).
        model_path = os.path.join(model_dir, "model_initial_only.pt")
        if not os.path.exists(model_path):
            model_path = os.path.join(model_dir, "model_initial.pt")

        model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
        if not os.path.exists(model_info_path):
            model_info_path = os.path.join(model_dir, "model_info_initial.json")
        return model_path, model_info_path

    # Full model: take the first candidate filename that exists.
    model_filenames = ["model_refined.pt", "model.pt", "model_full.pt"]
    for filename in model_filenames:
        candidate = os.path.join(model_dir, filename)
        if os.path.exists(candidate):
            return candidate, os.path.join(model_dir, "model_info.json")

    raise FileNotFoundError(
        f"No model file found in {model_dir}. Looked for: {', '.join(model_filenames)}"
    )


def _build_model(model_code, model_type, metadata, dataset, tag_context_size, num_heads):
    """Instantiate the model class from ``model_code`` matching ``model_type``."""
    if model_type == "initial_only":
        if hasattr(model_code, 'InitialOnlyImageTagger'):
            return model_code.InitialOnlyImageTagger(
                total_tags=metadata['total_tags'],
                dataset=dataset,
                pretrained=False,
            )
        print("InitialOnlyImageTagger class not found. Using ImageTagger as fallback.")

    return model_code.ImageTagger(
        total_tags=metadata['total_tags'],
        dataset=dataset,
        pretrained=False,
        tag_context_size=tag_context_size,
        num_heads=num_heads,
    )