"""Gradio demo: ImageNet classification with a Composer-trained ResNet-50.

The checkpoint is downloaded from the Hugging Face Hub and converted to a plain
torchvision ResNet-50 state dict. If that fails, the app falls back to
torchvision's pretrained ResNet-50 (``IMAGENET1K_V2``) weights.
"""

import os
import traceback

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Load ImageNet class labels: one label per line, line index == model output index.
with open('imagenet_classes.txt', 'r', encoding='utf-8') as f:
    categories = [line.strip() for line in f]

# Standard ImageNet evaluation preprocessing, built once instead of per request.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalized (1, 3, 224, 224) batch tensor.

    Args:
        image: A PIL image. It is converted to RGB first so grayscale or RGBA
            uploads still yield the 3 channels that ``Normalize`` expects.

    Returns:
        A float tensor with a leading batch dimension of 1.
    """
    return _PREPROCESS(image.convert('RGB')).unsqueeze(0)


def convert_state_dict(state_dict):
    """Convert Composer state dict to standard ResNet state dict.

    - Strips a leading ``module.`` prefix from every key.
    - Drops blur-filter entries (keys containing ``blur_filter`` or ``filt2d``),
      which have no counterpart in ``torchvision.models.resnet50``.
    - Renames ``<name>.conv.weight`` to ``<name>.weight`` so blur-wrapped
      convolutions map onto plain convolution keys.

    Args:
        state_dict: Mapping of parameter/buffer names to tensors.

    Returns:
        A new dict intended for ``models.resnet50().load_state_dict``.
    """
    print("Original state dict keys:", list(state_dict.keys())[:5], "...")

    prefix = 'module.'
    new_state_dict = {}
    for key, value in state_dict.items():
        # Remove the wrapper prefix if present.
        if key.startswith(prefix):
            key = key[len(prefix):]

        # Blur filter layers are not part of a plain ResNet-50; skip them.
        if 'blur_filter' in key or 'filt2d' in key:
            continue

        # Conv layers wrapped with blur keep their weight under `.conv`; flatten it.
        if '.conv.weight' in key:
            key = key.replace('.conv.weight', '.weight')

        new_state_dict[key] = value
        print(f"Layer: {key}, Shape: {value.shape}")

    print("\nConverted state dict keys:", list(new_state_dict.keys())[:5], "...")
    return new_state_dict


def _load_custom_model():
    """Download, convert and load the Composer checkpoint; raise on any failure."""
    repo_id = "satyanayak/imagenet-resnet50-composer-model"
    filename = "model.pt"
    print(f"Attempting to load model from {repo_id}/{filename}")

    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Model downloaded to: {model_path}")

    print("Initializing ResNet50 model...")
    model = models.resnet50(weights=None)

    print("\nModel structure:")
    for name, module in model.named_children():
        print(f"{name}: {module.__class__.__name__}")

    print("\nLoading state dict...")
    # weights_only=True restricts unpickling of the downloaded file to plain
    # tensors/containers instead of arbitrary Python objects.
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )

    print("\nConverting state dict...")
    converted_state_dict = convert_state_dict(state_dict)

    print("\nLoading weights into model...")
    missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False)
    if missing_keys:
        print("\nMissing keys:", missing_keys)
    if unexpected_keys:
        print("\nUnexpected keys:", unexpected_keys)

    # strict=False tolerates extra checkpoint entries, but a missing parameter
    # would silently stay randomly initialised and produce garbage predictions.
    # BatchNorm's `num_batches_tracked` counter is not used in eval mode.
    missing_weights = [k for k in missing_keys if not k.endswith('num_batches_tracked')]
    if missing_weights:
        raise RuntimeError(
            f"Checkpoint is missing {len(missing_weights)} model weights, "
            f"e.g. {missing_weights[:5]}"
        )

    model.eval()
    print("\nModel loaded successfully!")
    return model


def load_model():
    """Load the Composer ResNet-50 from the Hub, falling back to torchvision weights.

    Returns:
        A torchvision ResNet-50 in eval mode. If downloading, converting or
        loading the custom checkpoint fails (including when conversion leaves
        model weights missing), the torchvision ``IMAGENET1K_V2`` pretrained
        ResNet-50 is returned instead.
    """
    try:
        return _load_custom_model()
    except Exception as e:  # top-level boundary: any failure means "use fallback"
        print(f"\nError loading custom model: {str(e)}")
        traceback.print_exc()
        print("Falling back to pretrained ResNet50")
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.eval()
        return model


def predict(input_image):
    """Classify an image and return its top-5 ImageNet labels with probabilities.

    Args:
        input_image: Pixel array from ``gr.Image()`` (numpy by default), or
            ``None`` when the user submitted without an image.

    Returns:
        Dict mapping label -> softmax probability for the 5 most likely
        classes, the format ``gr.Label`` renders.

    Raises:
        gr.Error: If no image was given or inference fails, so the message is
            shown in the UI instead of being fed to ``gr.Label`` as a "label".
    """
    if input_image is None:
        raise gr.Error("Please upload an image to classify.")

    try:
        # Image.fromarray only wraps the array; it performs no channel reordering.
        image = Image.fromarray(input_image)
        print(f"Input image size: {image.size}")

        input_tensor = preprocess_image(image)
        print(f"Preprocessed tensor shape: {input_tensor.shape}")

        with torch.no_grad():
            output = model(input_tensor)
        print(f"Raw output shape: {output.shape}")
        print(f"Raw output values (first 5): {output[0][:5]}")

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        print(f"Probability values (first 5): {probabilities[:5]}")

        top5_prob, top5_catid = torch.topk(probabilities, 5)
        top5 = [
            (categories[cat_id], prob)
            for cat_id, prob in zip(top5_catid.tolist(), top5_prob.tolist())
        ]

        print("\nTop 5 predictions:")
        for label, prob in top5:
            print(f"{label}: {prob:.4f}")

        return dict(top5)
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        traceback.print_exc()
        raise gr.Error(f"Prediction failed: {e}") from e


# Load the model once at import time; `predict` reads this global on every request.
model = load_model()

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=[["images/dog.jpg"], ["images/cat.jpg"]],
    title="ImageNet Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
)

# Launch the interface
iface.launch()