|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
from torchvision.models import resnet50 |
|
import os |
|
import logging |
|
from typing import Optional, Union |
|
import numpy as np |
|
from pathlib import Path |
|
|
|
|
|
# Application-wide logging: records at INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Filesystem layout, anchored on this file's directory rather than the CWD.
BASE_DIR = Path(__file__).resolve().parent
MODELS_DIR, EXAMPLES_DIR = BASE_DIR / "models", BASE_DIR / "examples"
STATIC_DIR = BASE_DIR.joinpath("static", "uploaded")

# Created eagerly at import time so the directory always exists.
STATIC_DIR.mkdir(parents=True, exist_ok=True)

# Model weights, label file, and the device inference runs on.
MODEL_PATH = MODELS_DIR.joinpath("resnet_50.pth")
CLASSES_PATH = BASE_DIR.joinpath("classes.txt")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
def load_class_labels(path: Optional[Path] = None) -> Optional[list[str]]:
    """
    Load class labels, one per line, from a text file.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped. A stray empty line (e.g. a double newline at the end of the file)
    therefore cannot add a phantom class. The length of this list sizes the
    classifier head, so a phantom class would break weight loading.

    Args:
        path: File to read. When ``None`` (the default), ``CLASSES_PATH`` is used.

    Returns:
        The labels in file order, or ``None`` if the file is missing, unreadable
        or not valid UTF-8 (the error is logged).
    """
    labels_path = CLASSES_PATH if path is None else path
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
        # mangle non-ASCII labels written as UTF-8.
        with open(labels_path, 'r', encoding='utf-8') as f:
            return [label for label in (line.strip() for line in f) if label]
    except (OSError, UnicodeDecodeError) as e:
        # OSError covers FileNotFoundError; its message includes the path.
        logger.error("Error loading class labels: %s", e)
        return None
|
|
|
|
|
# Labels are resolved once at import. The model head is sized from this list,
# and every prediction indexes into it.
CLASS_NAMES = load_class_labels()

if CLASS_NAMES is None:
    raise RuntimeError("Failed to load class labels from classes.txt")

# An empty label list would build a zero-width classifier head, and every
# prediction would then fail. Refuse to start instead.
if not CLASS_NAMES:
    raise RuntimeError("classes.txt contains no class labels")
|
|
|
|
|
# Process-wide cache for the loaded network; populated only by a successful load_model().
model = None


def load_model() -> Optional[torch.nn.Module]:
    """
    Load the ResNet50 classifier, caching it in the module-level ``model``.

    The process:
    - builds an uninitialised ResNet50;
    - replaces its ``fc`` layer with one sized to ``len(CLASS_NAMES)``;
    - loads the weights from ``MODEL_PATH``;
    - moves the network to ``DEVICE`` and switches it to eval mode.

    Later calls return the cached instance.

    Returns:
        The ready-to-use model, or ``None`` if loading failed (the error is
        logged). A failed load is NOT cached, so the next call retries from
        scratch instead of returning a half-initialised network.
    """
    global model

    if model is not None:
        return model

    try:
        if not MODEL_PATH.exists():
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

        logger.info("Loading model on %s", DEVICE)

        # Build into a local name. Assigning the global before the weights load
        # succeeds would make later calls return a randomly initialised network.
        # Calling resnet50() with no arguments gives an uninitialised network on
        # both old (pretrained=False) and new (weights=None) torchvision.
        net = resnet50()
        net.fc = torch.nn.Linear(net.fc.in_features, len(CLASS_NAMES))

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

        # Accept a bare state dict or a training checkpoint that wraps it
        # under 'state_dict'. The isinstance guard avoids a TypeError on
        # non-dict objects.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        # Checkpoints saved from a DataParallel wrapper prefix every key with
        # 'module.'. A plain ResNet has no such submodule, so stripping it is safe.
        state_dict = {
            key.removeprefix('module.'): value
            for key, value in checkpoint.items()
        }

        net.load_state_dict(state_dict)
        net.to(DEVICE)
        net.eval()

    except Exception as e:
        logger.error("Error loading model: %s", e)
        return None

    model = net
    logger.info("Model loaded successfully")
    return model
|
|
|
# Fixed 224x224 resize plus the standard ImageNet mean/std normalisation.
# Built once at import instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def preprocess_image(image: Union[np.ndarray, Image.Image]) -> Optional[torch.Tensor]:
    """
    Turn an input image into a normalised, batched tensor on ``DEVICE``.

    Args:
        image: A numpy array (converted with ``Image.fromarray``) or a PIL
            image in any mode.

    Returns:
        A ``(1, 3, 224, 224)`` float tensor, or ``None`` if conversion failed
        (the error is logged).
    """
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # The 3-entry Normalize and the ResNet stem both need exactly three
        # channels. Grayscale ('L'), palette ('P') and RGBA images would
        # otherwise produce 1- or 4-channel tensors and fail.
        image = image.convert("RGB")

        return _PREPROCESS(image).unsqueeze(0).to(DEVICE)

    except Exception as e:
        logger.error("Error preprocessing image: %s", e)
        return None
|
|
|
def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
    """
    Classify one image for the Gradio interface.

    Args:
        image: The uploaded image as a numpy array, or ``None`` when nothing
            was provided.

    Returns:
        ``(predicted_class, confidences)``. ``confidences`` maps up to five
        class names (fewer if the model has fewer classes) to their softmax
        probability. On any failure the first element is an ``"Error..."``
        message and the dict is empty, so the UI shows the error instead of
        crashing.
    """
    if image is None:
        return "Error: No image provided", {}

    try:
        net = load_model()
        if net is None:
            return "Error: Failed to load model", {}

        input_tensor = preprocess_image(image)
        if input_tensor is None:
            return "Error: Failed to preprocess image", {}

        with torch.no_grad():
            logits = net(input_tensor)
            # Single-image batch: take row 0 and normalise across classes.
            probabilities = torch.nn.functional.softmax(logits[0], dim=0)

        predicted_class = CLASS_NAMES[torch.argmax(probabilities).item()]

        # topk raises when k exceeds the number of classes, so cap it.
        k = min(5, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k=k)

        confidences = {
            CLASS_NAMES[idx]: float(prob)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }

        return predicted_class, confidences

    except Exception as e:
        # Keep the traceback in the server log; the UI only gets the message.
        logger.exception("Prediction error: %s", e)
        return f"Error during prediction: {e}", {}
|
|
|
def get_example_list(directory: Optional[Path] = None) -> list[list[str]]:
    """
    Collect example images for the Gradio ``examples`` gallery.

    The original glob pattern ``f'*.{ext}'`` with ``ext='.jpg'`` produced
    ``'*..jpg'``, which matched nothing. Files are now matched on their suffix,
    compared case-insensitively, so ``photo.JPG`` is included too.

    Args:
        directory: Folder to scan (non-recursively). When ``None`` (the
            default), ``EXAMPLES_DIR`` is used.

    Returns:
        One single-element ``[path]`` list per ``.jpg``/``.jpeg``/``.png``
        file, sorted by path. This nested shape is used for a one-input
        interface. Returns ``[]`` if the folder is missing or cannot be read
        (read errors are logged).
    """
    allowed_suffixes = {'.jpg', '.jpeg', '.png'}
    examples_dir = EXAMPLES_DIR if directory is None else directory

    try:
        # A missing examples folder is a normal, non-error configuration.
        if not examples_dir.is_dir():
            return []

        examples = sorted(
            entry for entry in examples_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in allowed_suffixes
        )
        return [[str(entry)] for entry in examples]

    except OSError as e:
        logger.error("Error loading examples: %s", e)
        return []
|
|
|
|
|
try:
    # Gather the Interface settings in one mapping, then build it in one call.
    # Arguments are evaluated in the same order as a direct keyword call.
    _interface_options = dict(
        fn=predict,
        # predict() receives the upload as a numpy array.
        inputs=gr.Image(type="numpy", label="Upload Image"),
        # Matches predict()'s (predicted_class, confidences) return tuple.
        outputs=[
            gr.Label(label="Predicted Class", num_top_classes=1),
            gr.Label(label="Top 5 Predictions", num_top_classes=5),
        ],
        title="Image Classification with ResNet50",
        description=(
            "Upload an image to classify:\n"
            "The model will predict the class and show top 5 confidence scores."
        ),
        examples=get_example_list(),
        # Ask Gradio to cache example outputs (see Gradio docs for when caching runs).
        cache_examples=True,
        theme=gr.themes.Base(),
    )
    iface = gr.Interface(**_interface_options)

except Exception as e:
    # Log for context, then let the import fail: the app is unusable without a UI.
    logger.error(f"Error creating Gradio interface: {str(e)}")
    raise
|
|
|
if __name__ == "__main__":
    try:
        # Load the model before serving. If this fails, stop here rather than
        # start a server whose every request returns "Failed to load model".
        if load_model() is None:
            raise RuntimeError(f"Model could not be loaded from {MODEL_PATH}")

        # Binds on all interfaces (0.0.0.0) at port 7860, without a public share link.
        iface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
        )

    except Exception as e:
        logger.error("Error launching application: %s", e)
        # Exit non-zero so process supervisors and CI see the failure.
        raise SystemExit(1) from e