|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
from torchvision.models import resnet50 |
|
from pathlib import Path |
|
import logging |
|
import warnings |
|
# Silence every warning category for the whole process. This is a blanket
# filter: it also hides deprecation notices from the libraries imported above.
warnings.filterwarnings(action='ignore')

# Root logging at INFO so the load/prediction messages emitted below are shown.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# Trained checkpoint and label list. Both paths are relative, so they resolve
# against the current working directory of the process.
MODEL_PATH: Path = Path('src/model_10.pth')
CLASSES_PATH: Path = Path('src/classes.txt')

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE: torch.device = torch.device(_device_name)
|
|
|
|
|
# Inference-time preprocessing, applied in list order to each input image
# before it is batched and sent to the model.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Per-channel statistics; the three values per list imply 3-channel input.
    # These match the ImageNet mean/std conventionally used with torchvision
    # ResNets -- presumably what the checkpoint was trained with; confirm
    # against the training pipeline.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]

transform = transforms.Compose(_preprocess_steps)
|
|
|
def load_classes(path=None):
    """Read the class labels, one per line, from a text file.

    Args:
        path: Optional file to read instead of the module-level
            ``CLASSES_PATH``. Accepts anything ``open`` accepts (``str`` or
            ``Path``). Defaults to ``None``, meaning ``CLASSES_PATH``.

    Returns:
        list: Label strings in file order with surrounding whitespace
        removed. Blank lines are skipped, so a stray empty line (for example
        an extra newline at the end of the file) does not become a phantom
        empty class and inflate the class count.

    Raises:
        OSError: If the file cannot be opened.
    """
    source = CLASSES_PATH if path is None else path
    # Explicit encoding: relying on the platform default would mis-decode
    # non-ASCII label names on systems whose default is not UTF-8.
    with open(source, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [label for label in stripped if label]
|
|
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by *checkpoint*, unwrapping containers."""
    payload = checkpoint
    if isinstance(payload, dict):
        # Container keys are tried in this priority order; a dict with none of
        # them is treated as the state dict itself.
        for key in ("model", "state_dict", "model_state_dict"):
            if key in payload:
                payload = payload[key]
                break
    # A whole pickled module (bare or inside a container) has no .items();
    # take its parameters instead of failing later.
    if isinstance(payload, torch.nn.Module):
        payload = payload.state_dict()
    return payload


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of *state_dict* without leading ``module.``/``model.`` key prefixes."""
    prefixes = ("module.", "model.")
    cleaned = {}
    for key, value in state_dict.items():
        name = key
        # Only *leading* segments are removed (str.replace would also rewrite
        # e.g. "submodule." in the middle of a key). Repeat until stable so
        # stacked wrappers such as "model.module.conv1.weight" are handled.
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if name.startswith(prefix):
                    name = name[len(prefix):]
                    changed = True
        cleaned[name] = value
    return cleaned


def load_model():
    """Build a ResNet50 classifier and load the trained weights into it.

    The network is created without pretrained weights and its final ``fc``
    layer is replaced so the number of outputs equals the number of labels
    returned by ``load_classes()``. The checkpoint at ``MODEL_PATH`` is then
    loaded onto ``DEVICE``, unwrapped, and applied non-strictly.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        Exception: Any failure while building the model or reading/applying
            the checkpoint is logged with its traceback and re-raised
            unchanged.
    """
    try:
        model = resnet50(weights=None)
        num_classes = len(load_classes())
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source, or pass
        # weights_only=True on torch versions that support it.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

        state_dict = _strip_wrapper_prefixes(_extract_state_dict(checkpoint))

        # strict=False is kept as a deliberate best-effort load, but the
        # mismatches are reported: otherwise a checkpoint whose keys do not
        # line up would leave layers randomly initialised without any hint.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            logger.warning(
                "Checkpoint is missing %d model parameter(s): %s",
                len(result.missing_keys),
                result.missing_keys,
            )
        if result.unexpected_keys:
            logger.warning(
                "Checkpoint has %d parameter(s) the model does not use: %s",
                len(result.unexpected_keys),
                result.unexpected_keys,
            )

        model.to(DEVICE)
        model.eval()

        logger.info("Model loaded successfully")
        return model

    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception("Error loading model: %s", e)
        raise
|
|
|
|
|
# Loaded once at import time and shared by every prediction request.
# The labels are read first; the model is built and its weights loaded second.
CLASSES: list = load_classes()
MODEL: torch.nn.Module = load_model()
|
|
|
def predict_image(image):
    """Classify an image and report the most likely classes.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` means no image was supplied.

    Returns:
        tuple: ``(predicted_class, predictions_text)``. ``predicted_class`` is
        the label with the highest probability; ``predictions_text`` holds up
        to three ``"<label>: <confidence>%"`` lines, most likely first. When
        no image is given, or when anything fails, a pair of human-readable
        messages is returned instead -- this function never raises.
    """
    if image is None:
        return "No image provided", "Please upload an image"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # The normalisation step is defined for three channels, so grayscale,
        # RGBA or palette images must be converted before preprocessing.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Add the batch dimension the model expects and move to its device.
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            logits = MODEL(input_tensor)[0]
            probabilities = torch.nn.functional.softmax(logits, dim=0)

        # topk raises when k exceeds the number of classes, so cap it.
        k = min(3, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k=k)

        # Convert to plain Python numbers once; CLASSES is indexed with ints.
        ranked = list(zip(top_prob.tolist(), top_indices.tolist()))
        predictions_text = "\n".join(
            f"{CLASSES[idx]}: {prob * 100:.2f}%" for prob, idx in ranked
        )
        predicted_class = CLASSES[ranked[0][1]]

        logger.info("Predicted class: %s", predicted_class)
        logger.info("Top 3 predictions:\n%s", predictions_text)

        return predicted_class, predictions_text

    except Exception as e:
        # UI boundary: report the failure in the output boxes instead of
        # crashing the request, but keep the traceback in the log.
        logger.exception("Prediction error: %s", e)
        return "Error in prediction", str(e)
|
|
|
|
|
# Offer only the bundled example images that are actually present on disk.
# Checking each file (not just the directory) means a partially populated
# examples/ folder never hands Gradio a path that does not exist.
_EXAMPLE_IMAGES = ["examples/example1.jpg", "examples/example2.jpg"]
_available_examples = [[p] for p in _EXAMPLE_IMAGES if Path(p).is_file()]

# Web UI: one image input, two text outputs fed by predict_image's return pair.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top 3 Predictions", lines=3),
    ],
    title="ResNet50 Image Classifier",
    description=(
        "Upload an image to classify.\n"
        "The model will predict the class and show confidence scores for the top 3 predictions."
    ),
    # No usable example files -> pass None, as the original did when the
    # examples directory was absent.
    examples=_available_examples or None,
    theme=gr.themes.Base(),
)
|
|
|
|
|
def main():
    """Launch the Gradio interface defined in ``iface``."""
    iface.launch()


# Start the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()