File size: 9,112 Bytes
29b445b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""

Model loading utilities for Image Tagger application.

"""

import os
import json
import torch
import platform
import traceback
import importlib.util


def is_windows():
    """Return True when ``platform.system()`` reports "Windows"."""
    os_name = platform.system()
    return os_name == "Windows"


class DummyDataset:
    """Lightweight stand-in for a dataset, holding only tag lookup tables.

    Built from a metadata mapping that provides ``total_tags``,
    ``idx_to_tag`` (index keys are converted to ``int``) and
    ``tag_to_category``.
    """

    def __init__(self, metadata):
        self.total_tags = metadata['total_tags']
        # Normalise the index keys to integers so lookups by int index work.
        index_table = {}
        for raw_index, tag_name in metadata['idx_to_tag'].items():
            index_table[int(raw_index)] = tag_name
        self.idx_to_tag = index_table
        self.tag_to_category = metadata['tag_to_category']

    def get_tag_info(self, idx):
        """Return ``(tag, category)`` for a tag index.

        Unknown indices yield the placeholder tag ``unknown_<idx>``; tags
        without a category entry are reported as ``"general"``.
        """
        if idx in self.idx_to_tag:
            tag_name = self.idx_to_tag[idx]
        else:
            tag_name = f"unknown_{idx}"
        return tag_name, self.tag_to_category.get(tag_name, "general")


def load_model_code(model_dir):
    """
    Import ``model_code.py`` from a model directory as a module.

    Args:
        model_dir: Path to the directory expected to contain model_code.py

    Returns:
        The freshly executed module object (named "model_code")

    Raises:
        FileNotFoundError: model_code.py does not exist in model_dir
        ImportError: the module lacks ImageTagger or FlashAttention
    """
    code_file = os.path.join(model_dir, "model_code.py")
    if not os.path.exists(code_file):
        raise FileNotFoundError(f"model_code.py not found at {code_file}")

    # Build the module straight from the file path and execute it.
    module_spec = importlib.util.spec_from_file_location("model_code", code_file)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    # Both classes must be defined by the loaded module.
    required_names = ('ImageTagger', 'FlashAttention')
    if all(hasattr(module, name) for name in required_names):
        return module
    raise ImportError("Required classes not found in model_code.py")


def check_flash_attention():
    """
    Check if Flash Attention is properly installed.

    The check imports ``flash_attn`` and inspects the module name that
    ``flash_attn.flash_attn_func`` was defined in; an implementation whose
    module path contains ``flash_attn_fallback`` is not counted as a real
    install.

    Returns:
        bool: True if ``flash_attn`` imports, exposes ``flash_attn_func``,
        and that function does not come from a ``flash_attn_fallback``
        module; False otherwise.
    """
    try:
        import flash_attn
    except Exception:
        # A missing package raises ImportError, and a broken install may
        # raise other errors at import time; both mean "unavailable".
        # KeyboardInterrupt / SystemExit are deliberately not swallowed.
        return False

    flash_fn = getattr(flash_attn, 'flash_attn_func', None)
    if flash_fn is None:
        return False

    module_path = getattr(flash_fn, '__module__', None)
    if not isinstance(module_path, str):
        return False
    return 'flash_attn_fallback' not in module_path


def estimate_model_memory_usage(model, device):
    """
    Estimate the memory occupied by a model's tensors, in MB.

    Adds up ``nelement() * element_size()`` over everything yielded by
    ``model.parameters()`` and ``model.buffers()``. The ``device``
    argument is accepted but not used by the calculation.
    """
    bytes_per_mb = 1024 * 1024
    total_bytes = 0
    for tensor in model.parameters():
        total_bytes += tensor.nelement() * tensor.element_size()
    for tensor in model.buffers():
        total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / bytes_per_mb


def _read_json(path):
    """Parse the JSON file at ``path``, decoding it as UTF-8."""
    # Explicit encoding: open() otherwise uses the locale's default, which
    # on Windows is typically a legacy code page and can fail on (or
    # mis-decode) non-ASCII text in the JSON files.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_model_files(model_dir, model_type):
    """Return ``(model_path, model_info_path)`` for the requested model type.

    For "initial_only" the second naming convention is used as a fallback
    without checking that it exists; for any other type a FileNotFoundError
    is raised when none of the known weight file names is present.
    """
    if model_type == "initial_only":
        # Try both naming conventions
        model_path = os.path.join(model_dir, "model_initial_only.pt")
        if not os.path.exists(model_path):
            model_path = os.path.join(model_dir, "model_initial.pt")

        # Try both naming conventions for info file
        model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
        if not os.path.exists(model_info_path):
            model_info_path = os.path.join(model_dir, "model_info_initial.json")
        return model_path, model_info_path

    # Try multiple naming conventions for the full model
    model_filenames = ["model_refined.pt", "model.pt", "model_full.pt"]
    for filename in model_filenames:
        candidate = os.path.join(model_dir, filename)
        if os.path.exists(candidate):
            return candidate, os.path.join(model_dir, "model_info.json")

    raise FileNotFoundError(f"No model file found in {model_dir}. Looked for: {', '.join(model_filenames)}")


def load_exported_model(model_dir, model_type="full"):
    """
    Load the exported model and metadata with correct precision.

    The model class is taken from ``model_code.py`` inside ``model_dir``,
    instantiated, filled from the saved state dict, moved to CUDA when
    available (CPU otherwise), converted to float16 and put in eval mode.

    Args:
        model_dir: Directory containing the model files (metadata.json,
            thresholds.json, model_code.py and the weights file)
        model_type: "full" or "initial_only"; any value other than
            "initial_only" is treated as the full model

    Returns:
        model, thresholds, metadata — the loaded model plus the parsed
        contents of thresholds.json and metadata.json

    Raises:
        FileNotFoundError: if no weights file, metadata.json,
            thresholds.json or model_code.py can be found
        ImportError: if model_code.py lacks the required classes
        Exception: any error raised while building or loading the model is
            printed with its traceback and re-raised unchanged
    """
    print(f"Loading {model_type} model from {model_dir}")

    # Make sure we have the absolute path to the model directory
    model_dir = os.path.abspath(model_dir)
    print(f"Absolute model path: {model_dir}")

    # Paths of the files every model type needs
    metadata_path = os.path.join(model_dir, "metadata.json")
    thresholds_path = os.path.join(model_dir, "thresholds.json")

    print(f"Looking for thresholds at: {thresholds_path}")

    # Check platform and Flash Attention status
    windows_system = is_windows()
    flash_attn_installed = check_flash_attention()

    # Add a specific warning for Windows users trying to use the full model without Flash Attention
    if windows_system and model_type == "full" and not flash_attn_installed:
        print("Note: On Windows without Flash Attention, the full model will not work")
        print("      which may produce less accurate results.")
        print("      Consider using the 'initial_only' model for better performance on Windows.")

    # Determine file paths based on model type
    model_path, model_info_path = _resolve_model_files(model_dir, model_type)

    # Check for required files (the model-info file is optional)
    for file_path in (metadata_path, thresholds_path, model_path):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Required file {file_path} not found")

    # Load metadata
    metadata = _read_json(metadata_path)

    # Load model code
    model_code = load_model_code(model_dir)

    # Create dataset
    dummy_dataset = DummyDataset(metadata)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model info, falling back to defaults when the file is absent
    if os.path.exists(model_info_path):
        model_info = _read_json(model_info_path)
        print("Loaded model info:", model_info)
        tag_context_size = model_info.get('tag_context_size', 256)
        num_heads = model_info.get('num_heads', 16)
    else:
        print("Model info not found, using defaults")
        tag_context_size = 256
        num_heads = 16

    try:
        # Check if InitialOnlyImageTagger class exists
        has_initial_only_class = hasattr(model_code, 'InitialOnlyImageTagger')

        if model_type == "initial_only" and has_initial_only_class:
            # Create the lightweight model
            model = model_code.InitialOnlyImageTagger(
                total_tags=metadata['total_tags'],
                dataset=dummy_dataset,
                pretrained=False
            )
        else:
            if model_type == "initial_only":
                # Fallback to using ImageTagger for initial-only if the specific class isn't available
                print("InitialOnlyImageTagger class not found. Using ImageTagger as fallback.")
            # Create the full model
            model = model_code.ImageTagger(
                total_tags=metadata['total_tags'],
                dataset=dummy_dataset,
                pretrained=False,
                tag_context_size=tag_context_size,
                num_heads=num_heads
            )

        # Load state dict
        # NOTE(review): torch.load unpickles the checkpoint, and
        # load_model_code above executes model_code.py, so model_dir must
        # come from a trusted source. Consider weights_only=True if the
        # supported PyTorch versions allow it.
        state_dict = torch.load(model_path, map_location=device)

        # Try loading with strict=True first, then fall back to strict=False
        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model loaded with strict=True")
        except Exception as e:
            print(f"Warning: Strict loading failed: {str(e)}")
            print("Attempting to load with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✓ Model loaded with strict=False")

        # Ensure model is in half precision to match training conditions
        model = model.to(device=device, dtype=torch.float16)
        model.eval()

        # Check parameter dtype
        param_dtype = next(model.parameters()).dtype
        print(f"Model loaded successfully on {device} with precision {param_dtype}")
        print(f"Model memory usage: {estimate_model_memory_usage(model, device):.2f} MB")

    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        raise

    # Load thresholds
    thresholds = _read_json(thresholds_path)

    return model, thresholds, metadata