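"""Gradio demo: ImageNet classification with a ResNet50 trained with Composer.

Downloads a checkpoint from the Hugging Face Hub, converts its Composer-style
state dict to the standard torchvision ResNet50 layout, and serves top-5
predictions through a Gradio interface. Falls back to torchvision's pretrained
ResNet50 if the custom checkpoint cannot be loaded.
"""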
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import torchvision.models as models
import traceback

# Load ImageNet class labels
with open('imagenet_classes.txt', 'r') as f:
    categories = [line.strip() for line in f.readlines()]

# Define image preprocessing
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

def convert_state_dict(state_dict):
    """Convert Composer state dict to standard ResNet state dict."""
    print("Original state dict keys:", list(state_dict.keys())[:5], "...")
    
    new_state_dict = {}
    for key, value in state_dict.items():
        # Remove 'module.' prefix if it exists
        if key.startswith('module.'):
            key = key[7:]  # Remove first 7 characters ('module.')
            
        # Skip Composer blur-filter buffers that have no counterpart in torchvision's ResNet50
        if 'blur_filter' in key or 'filt2d' in key:
            continue
            
        # Map wrapped conv weights ('.conv.weight') back to the standard '.weight' naming
        if '.conv.weight' in key:
            key = key.replace('.conv.weight', '.weight')
            
        new_state_dict[key] = value
        
        # Print shape information for debugging
        print(f"Layer: {key}, Shape: {value.shape}")
        
    print("\nConverted state dict keys:", list(new_state_dict.keys())[:5], "...")
    return new_state_dict

# Load model from Hugging Face Hub
def load_model():
    try:
        repo_id = "satyanayak/imagenet-resnet50-composer-model"
        filename = "model.pt"
        
        print(f"Attempting to load model from {repo_id}/{filename}")
        
        # Download the model file
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Model downloaded to: {model_path}")
        
        # Initialize standard ResNet50
        print("Initializing ResNet50 model...")
        model = models.resnet50(weights=None)
        
        # Print model structure
        print("\nModel structure:")
        for name, module in model.named_children():
            print(f"{name}: {module.__class__.__name__}")
        
        # Load and convert the state dict
        print("\nLoading state dict...")
        state_dict = torch.load(
            model_path, 
            map_location=torch.device('cpu'),
            weights_only=True
        )
        
        print("\nConverting state dict...")
        converted_state_dict = convert_state_dict(state_dict)
        
        # Load the converted state dict (strict=False tolerates keys dropped during conversion)
        print("\nLoading weights into model...")
        missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False)
        
        if missing_keys:
            print("\nMissing keys:", missing_keys)
        if unexpected_keys:
            print("\nUnexpected keys:", unexpected_keys)
        
        model.eval()
        print("\nModel loaded successfully!")
        return model
        
    except Exception as e:
        print(f"\nError loading custom model: {str(e)}")
        print("Stack trace:", e.__traceback__)
        print("Falling back to pretrained ResNet50")
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.eval()
        return model

# Prediction function with debugging
def predict(input_image):
    try:
        # Convert the incoming numpy array (RGB) from Gradio into a PIL Image
        input_image = Image.fromarray(input_image)
        print(f"Input image size: {input_image.size}")
        
        # Preprocess the image
        input_tensor = preprocess_image(input_image)
        print(f"Preprocessed tensor shape: {input_tensor.shape}")
        
        # Make prediction
        with torch.no_grad():
            output = model(input_tensor)
            print(f"Raw output shape: {output.shape}")
            print(f"Raw output values (first 5): {output[0][:5]}")
            
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            print(f"Probability values (first 5): {probabilities[:5]}")
            
        # Get top 5 predictions
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        
        # Print debugging info
        print("\nTop 5 predictions:")
        for i in range(5):
            print(f"{categories[top5_catid[i]]}: {float(top5_prob[i]):.4f}")
        
        # Create result dictionary
        results = {categories[top5_catid[i]]: float(top5_prob[i]) for i in range(5)}
        
        return results
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        return {"error": str(e)}

# Load the model globally
model = load_model()

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=[["images/dog.jpg"], ["images/cat.jpg"]],
    title="ImageNet Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories."
)

# Launch the interface
iface.launch()