# File size: 4,609 bytes (metadata from the web file viewer; not part of the program).
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50
from pathlib import Path
import logging
import warnings
# Suppress every warning for the lifetime of the process.
warnings.filterwarnings('ignore')

# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Checkpoint and class-list locations (relative to the working directory).
_SRC_DIR = Path('src')
MODEL_PATH = _SRC_DIR / 'model_10.pth'
CLASSES_PATH = _SRC_DIR / 'classes.txt'

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
# Inference-time preprocessing; per the original note, this mirrors the
# transforms used during training.
# NOTE(review): the mean/std look like the usual ImageNet channel statistics;
# confirm against the training script.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)
def load_classes(path=None):
    """Read class names, one per line, from a text file.

    Blank lines (including a trailing newline at the end of the file) are
    skipped so they do not turn into empty class names and inflate the class
    count used to size the classifier head.

    Args:
        path: File to read. When omitted, the module-level ``CLASSES_PATH``
            is used.

    Returns:
        A list of class names in file order, with surrounding whitespace
        removed.

    Raises:
        OSError: If the file cannot be opened.
    """
    if path is None:
        path = CLASSES_PATH
    # Explicit encoding: do not depend on the platform's locale default.
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [name for name in stripped if name]
# Keys under which a checkpoint dict may nest the actual parameter mapping.
_CHECKPOINT_STATE_KEYS = ("model", "state_dict", "model_state_dict")
# Leading key prefixes added by wrapper objects that must be removed.
_WRAPPER_PREFIXES = ("module.", "model.")


def _extract_state_dict(checkpoint):
    """Return the parameter mapping nested in *checkpoint*, or *checkpoint* itself."""
    if isinstance(checkpoint, dict):
        for key in _CHECKPOINT_STATE_KEYS:
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of *state_dict* whose keys have leading wrapper prefixes removed."""
    cleaned = {}
    for key, value in state_dict.items():
        name = key
        # Only strip at the start of the key; a substring replace would also
        # rewrite names that merely contain "module." further inside.
        while name.startswith(_WRAPPER_PREFIXES):
            name = name.split(".", 1)[1]
        cleaned[name] = value
    return cleaned


def load_model():
    """Build a ResNet50 classifier and load the trained weights into it.

    The final fully connected layer is resized to the number of entries
    returned by ``load_classes()``. The checkpoint at ``MODEL_PATH`` may be a
    bare state dict or a dict nesting it under ``"model"``, ``"state_dict"``
    or ``"model_state_dict"``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        Exception: Any error raised while building or loading is logged and
            re-raised unchanged.
    """
    try:
        # Initialize model with a head sized to the class list
        model = resnet50(weights=None)
        num_classes = len(load_classes())
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        state_dict = _strip_wrapper_prefixes(_extract_state_dict(checkpoint))

        # strict=False tolerates key mismatches, so report them instead of
        # silently serving a partially (or not at all) initialised model.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                "Checkpoint is missing %d parameter(s); they keep their "
                "initial values: %s",
                len(load_result.missing_keys),
                load_result.missing_keys,
            )
        if load_result.unexpected_keys:
            logger.warning(
                "Checkpoint has %d unexpected parameter(s) that were "
                "ignored: %s",
                len(load_result.unexpected_keys),
                load_result.unexpected_keys,
            )

        model.to(DEVICE)
        model.eval()
        logger.info("Model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
# Global variables
# Both are populated once, at import time, and read by predict_image().
# load_model() re-raises on failure, so a bad checkpoint aborts the import.
CLASSES = load_classes()
MODEL = load_model()
def predict_image(image):
    """Classify an image and report the most likely classes.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` means no image was supplied.

    Returns:
        A ``(predicted_class, predictions_text)`` tuple. ``predictions_text``
        holds up to three ``"<class>: <confidence>%"`` lines, most likely
        first. When no image is given or prediction fails, the tuple carries
        a short status string and an explanatory message instead.
    """
    if image is None:
        return "No image provided", "Please upload an image"

    try:
        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # The normalisation step is configured for three channels, so
        # grayscale / RGBA / palette inputs must be converted first.
        image = image.convert("RGB")

        # Preprocess image and add the batch dimension
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Get prediction
        with torch.no_grad():
            output = MODEL(input_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # topk raises if k exceeds the number of classes, so cap it.
        top_k = min(3, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k=top_k)
        # Plain Python numbers: safe list indices and float formatting.
        top_prob = top_prob.tolist()
        top_indices = top_indices.tolist()

        predictions = [
            f"{CLASSES[idx]}: {prob * 100:.2f}%"
            for prob, idx in zip(top_prob, top_indices)
        ]
        predictions_text = "\n".join(predictions)
        predicted_class = CLASSES[top_indices[0]]

        # Log predictions
        logger.info(f"Predicted class: {predicted_class}")
        logger.info(f"Top 3 predictions:\n{predictions_text}")
        return predicted_class, predictions_text
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising,
        # but keep the traceback in the log since the exception is swallowed.
        logger.exception(f"Prediction error: {e}")
        return "Error in prediction", str(e)
# Create Gradio interface
# Offer only the example images that are actually on disk. Checking just the
# directory would still hand Gradio paths to files that may be missing.
_EXAMPLE_FILES = ["examples/example1.jpg", "examples/example2.jpg"]
_available_examples = [
    [example_path]
    for example_path in _EXAMPLE_FILES
    if Path(example_path).is_file()
]

iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top 3 Predictions", lines=3),
    ],
    title="ResNet50 Image Classifier",
    description=(
        "Upload an image to classify.\n"
        "The model will predict the class and show confidence scores for the top 3 predictions."
    ),
    # None (rather than an empty list) when no example file is present.
    examples=_available_examples or None,
    theme=gr.themes.Base(),
)
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()