# Author: satyanayak
# Last commit 067ff37: "update the model file"
import os
import traceback

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
# Load ImageNet class labels: one label per line, where the line's index is
# the class id used to look up the model's output (see predict()).
# An explicit encoding keeps decoding independent of the platform locale.
with open('imagenet_classes.txt', 'r', encoding='utf-8') as f:
    categories = [line.strip() for line in f]
# Define image preprocessing
def preprocess_image(image):
    """Turn a PIL image into a normalised single-image batch tensor.

    The pipeline resizes the shorter side to 256, center-crops 224x224,
    converts to a float tensor and normalises each channel with the mean/std
    values used for torchvision's ImageNet models.

    Args:
        image: A ``PIL.Image.Image`` in any mode. Non-RGB images (grayscale,
            RGBA, palette, ...) are converted to RGB first, because the
            normalisation step expects exactly three channels.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    # ToTensor keeps the source channel count: 'L' gives 1 channel and 'RGBA'
    # gives 4, and either makes the 3-channel Normalize below fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension the model expects.
    return transform(image).unsqueeze(0)
def convert_state_dict(state_dict):
    """Rename Composer checkpoint keys so they line up with torchvision's ResNet.

    Each entry is rewritten in this order:

    * one leading ``'module.'`` wrapper prefix is removed;
    * entries whose key contains ``'blur_filter'`` or ``'filt2d'`` are
      discarded (presumably fixed blur kernels added during training, with
      no counterpart in a plain ResNet -- confirm against the checkpoint);
    * every ``'.conv.weight'`` becomes ``'.weight'``, so a wrapped
      convolution's weight maps onto the bare conv layer's key.

    Values are passed through as-is (same objects, no copies). The first few
    keys before and after conversion, and every kept key with its shape, are
    printed for debugging.

    Args:
        state_dict: Mapping from parameter/buffer names to tensors.

    Returns:
        A new ``dict`` holding the renamed, filtered entries.
    """
    wrapper_prefix = 'module.'
    dropped_markers = ('blur_filter', 'filt2d')

    def rename(raw_key):
        # Returns the torchvision-style key, or None when the entry is dropped.
        if raw_key.startswith(wrapper_prefix):
            raw_key = raw_key[len(wrapper_prefix):]
        if any(marker in raw_key for marker in dropped_markers):
            return None
        return raw_key.replace('.conv.weight', '.weight')

    print("Original state dict keys:", list(state_dict.keys())[:5], "...")

    renamed = {}
    for original_key, tensor in state_dict.items():
        target_key = rename(original_key)
        if target_key is None:
            continue
        renamed[target_key] = tensor
        # Per-entry shape dump helps diagnose key/shape mismatches.
        print(f"Layer: {target_key}, Shape: {tensor.shape}")

    print("\nConverted state dict keys:", list(renamed.keys())[:5], "...")
    return renamed
# Load model from Hugging Face Hub
def load_model():
    """Build a ResNet-50 carrying the custom Composer-trained weights.

    Downloads ``model.pt`` from the ``satyanayak/imagenet-resnet50-composer-model``
    Hub repo and loads it as a tensor-only state dict. It renames the keys with
    ``convert_state_dict`` and loads them into a torchvision ResNet-50 built
    without pretrained weights.

    If any step fails, it falls back to torchvision's pretrained
    ``IMAGENET1K_V2`` ResNet-50 and prints the error and full traceback.
    Failures include the download, unpickling, and a checkpoint that leaves
    model parameters uninitialised.

    Returns:
        A ResNet-50 ``nn.Module`` switched to eval mode.
    """
    try:
        repo_id = "satyanayak/imagenet-resnet50-composer-model"
        filename = "model.pt"
        print(f"Attempting to load model from {repo_id}/{filename}")
        # Download the model file
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Model downloaded to: {model_path}")

        # Initialize standard ResNet50 (random init; real weights come below)
        print("Initializing ResNet50 model...")
        model = models.resnet50(weights=None)
        print("\nModel structure:")
        for name, module in model.named_children():
            print(f"{name}: {module.__class__.__name__}")

        # Load and convert the state dict. weights_only=True restricts
        # unpickling to tensors and plain containers, so the downloaded file
        # cannot execute arbitrary code on load.
        print("\nLoading state dict...")
        state_dict = torch.load(
            model_path,
            map_location=torch.device('cpu'),
            weights_only=True
        )
        print("\nConverting state dict...")
        converted_state_dict = convert_state_dict(state_dict)

        # strict=False tolerates extra checkpoint entries (reported below).
        print("\nLoading weights into model...")
        missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False)
        if missing_keys:
            print("\nMissing keys:", missing_keys)
        if unexpected_keys:
            print("\nUnexpected keys:", unexpected_keys)

        # A missing key means that parameter/buffer kept its random
        # initialisation, so the model would silently emit garbage. Treat it
        # as a load failure so the pretrained fallback is used instead.
        # BatchNorm's num_batches_tracked counters do not affect eval-mode
        # outputs, so their absence is tolerated.
        uninitialized = [k for k in missing_keys if not k.endswith('num_batches_tracked')]
        if uninitialized:
            raise RuntimeError(
                f"{len(uninitialized)} model keys missing from checkpoint, "
                f"e.g. {uninitialized[:5]}"
            )

        model.eval()
        print("\nModel loaded successfully!")
        return model
    except Exception as e:
        # Deliberate best-effort: any failure degrades to the stock weights
        # rather than taking the app down.
        print(f"\nError loading custom model: {str(e)}")
        print("Stack trace:", traceback.format_exc())
        print("Falling back to pretrained ResNet50")
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.eval()
        return model
# Prediction function with debugging
def predict(input_image):
    """Classify one image and return its top-5 ImageNet labels.

    Uses the module-level ``model`` and ``categories``.

    Args:
        input_image: The value delivered by ``gr.Image()``. This is
            presumably an ``HxWxC`` uint8 numpy array (confirm against the
            Gradio version in use). A ``PIL.Image.Image`` is also accepted.
            ``None`` means the user submitted without an image.

    Returns:
        A ``{label: probability}`` dict with up to five entries, in descending
        probability order, as expected by ``gr.Label``.

    Raises:
        gr.Error: if no image was supplied or inference fails. Gradio
            displays the message in the UI.
    """
    # Raised outside the try so it is not re-wrapped below.
    if input_image is None:
        raise gr.Error("Please upload an image to classify.")
    try:
        # Wrap the array as a PIL image. convert("RGB") only normalises
        # grayscale/RGBA/palette inputs to the three channels preprocessing
        # expects; no BGR->RGB swap happens or is needed here.
        if isinstance(input_image, Image.Image):
            pil_image = input_image
        else:
            pil_image = Image.fromarray(input_image)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        print(f"Input image size: {pil_image.size}")

        input_tensor = preprocess_image(pil_image)
        print(f"Preprocessed tensor shape: {input_tensor.shape}")

        with torch.no_grad():
            output = model(input_tensor)
        print(f"Raw output shape: {output.shape}")
        print(f"Raw output values (first 5): {output[0][:5]}")

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        print(f"Probability values (first 5): {probabilities[:5]}")

        # Guard topk against a head with fewer than five classes.
        k = min(5, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)

        print("\nTop 5 predictions:")
        results = {}
        for prob, cat_id in zip(top_prob.tolist(), top_catid.tolist()):
            label = categories[cat_id]
            print(f"{label}: {prob:.4f}")
            # NOTE: duplicate label strings in the class list collapse into
            # one dict entry (last one wins).
            results[label] = prob
        return results
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        # A {"error": msg} dict is not a valid gr.Label payload (confidences
        # must be numeric); gr.Error surfaces the message in the UI instead.
        raise gr.Error(f"Prediction failed: {e}") from e
# Load the model once at import time. predict() reads this module-level
# global on every request, so it must exist even when the module is imported.
model = load_model()

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=[["images/dog.jpg"], ["images/cat.jpg"]],
    title="ImageNet Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module from tests or another app does not start a web server.
if __name__ == "__main__":
    iface.launch()