import traceback

import torch
import torchvision.models as models
import torchvision.transforms as transforms
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

# Load the ImageNet class labels (one name per line)
with open('imagenet_classes.txt', 'r') as f:
    categories = [line.strip() for line in f]
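# Note: the file is assumed to hold the 1000 ImageNet-1k class names in
# class-index order, so that categories[i] corresponds to output logit i.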

# Define image preprocessing
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)
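
# The returned tensor has shape (1, 3, 224, 224); the mean/std values above are
# the standard ImageNet normalization statistics used by torchvision's
# pretrained ResNet weights.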

def convert_state_dict(state_dict):
    """Convert a Composer state dict to a standard torchvision ResNet state dict."""
    print("Original state dict keys:", list(state_dict.keys())[:5], "...")
    new_state_dict = {}
    for key, value in state_dict.items():
        # Strip the 'module.' prefix if it exists
        if key.startswith('module.'):
            key = key[len('module.'):]
        # Skip blur filter buffers, which have no counterpart in torchvision's ResNet
        if 'blur_filter' in key or 'filt2d' in key:
            continue
        # Rename conv layers that were wrapped with a blur filter
        if '.conv.weight' in key:
            key = key.replace('.conv.weight', '.weight')
        new_state_dict[key] = value
        # Print shape information for debugging
        print(f"Layer: {key}, Shape: {value.shape}")
    print("\nConverted state dict keys:", list(new_state_dict.keys())[:5], "...")
    return new_state_dict
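
# A small, self-contained illustration of convert_state_dict (a sketch only:
# the key names below are made-up examples of what a Composer/BlurPool-style
# checkpoint might contain, not keys taken from the actual checkpoint file).
def _demo_convert_state_dict():
    example = {
        'module.conv1.weight': torch.zeros(64, 3, 7, 7),
        'module.layer1.0.blur_filter': torch.zeros(1, 1, 3, 3),
    }
    converted = convert_state_dict(example)
    # The 'module.' prefix is stripped and the blur filter entry is dropped,
    # so only 'conv1.weight' remains.
    assert set(converted.keys()) == {'conv1.weight'}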

# Load the model from the Hugging Face Hub
def load_model():
    try:
        repo_id = "satyanayak/imagenet-resnet50-composer-model"
        filename = "model.pt"
        print(f"Attempting to load model from {repo_id}/{filename}")
        # Download the model file
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Model downloaded to: {model_path}")
        # Initialize a standard ResNet50
        print("Initializing ResNet50 model...")
        model = models.resnet50(weights=None)
        # Print the model structure
        print("\nModel structure:")
        for name, module in model.named_children():
            print(f"{name}: {module.__class__.__name__}")
        # Load and convert the state dict
        print("\nLoading state dict...")
        state_dict = torch.load(
            model_path,
            map_location=torch.device('cpu'),
            weights_only=True
        )
        print("\nConverting state dict...")
        converted_state_dict = convert_state_dict(state_dict)
        # Load the converted state dict; strict=False tolerates any remaining
        # mismatches between the Composer checkpoint and torchvision's ResNet50
        print("\nLoading weights into model...")
        missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False)
        if missing_keys:
            print("\nMissing keys:", missing_keys)
        if unexpected_keys:
            print("\nUnexpected keys:", unexpected_keys)
        model.eval()
        print("\nModel loaded successfully!")
        return model
    except Exception as e:
        print(f"\nError loading custom model: {e}")
        traceback.print_exc()
        print("Falling back to pretrained ResNet50")
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.eval()
        return model
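
# Note: if the custom checkpoint cannot be loaded, load_model falls back to
# torchvision's stock IMAGENET1K_V2 weights (after printing the error), so
# predictions may not come from the Composer-trained model in that case.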

# Prediction function with debugging output
def predict(input_image):
    try:
        # Gradio passes the image as an RGB numpy array; convert it to a PIL Image
        # (convert('RGB') also guards against grayscale or RGBA uploads)
        input_image = Image.fromarray(input_image).convert('RGB')
        print(f"Input image size: {input_image.size}")
        # Preprocess the image
        input_tensor = preprocess_image(input_image)
        print(f"Preprocessed tensor shape: {input_tensor.shape}")
        # Make the prediction
        with torch.no_grad():
            output = model(input_tensor)
            print(f"Raw output shape: {output.shape}")
            print(f"Raw output values (first 5): {output[0][:5]}")
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            print(f"Probability values (first 5): {probabilities[:5]}")
        # Get the top 5 predictions
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        # Print debugging info
        print("\nTop 5 predictions:")
        for i in range(5):
            print(f"{categories[top5_catid[i]]}: {float(top5_prob[i]):.4f}")
        # Create the result dictionary
        results = {}
        for i in range(5):
            results[categories[top5_catid[i]]] = float(top5_prob[i])
        return results
    except Exception as e:
        print(f"Error during prediction: {e}")
        return {"error": str(e)}

# Load the model globally so it is ready before the first request
model = load_model()

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=[["images/dog.jpg"], ["images/cat.jpg"]],
    title="ImageNet Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories."
)

# Launch the interface
iface.launch()