|
import torch
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import gradio as gr
|
|
from resnet_model import ResNet50
|
|
from utils import load_checkpoint
|
|
import ast
|
|
|
|
|
|
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, wrap it in DataParallel and move it to `device` before
# restoring weights. NOTE(review): the DataParallel wrap presumably matches how
# the checkpoint was saved ("module."-prefixed state-dict keys) — confirm
# against utils.load_checkpoint.
model = torch.nn.DataParallel(ResNet50()).to(device)

checkpoint_path = "checkpoint.pth"

# load_checkpoint returns four values; only the restored model is used here.
# None is passed in the second (presumably optimizer) slot since this script
# only runs inference.
model, _, _, _ = load_checkpoint(model, None, checkpoint_path)

# Inference mode for layers whose behavior differs between train and eval.
model.eval()
|
|
|
|
|
|
# Standard ImageNet per-channel (R, G, B) mean and standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image: resize to a fixed 224x224
# (aspect ratio is not preserved), convert to a float tensor, then normalize
# each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
|
# Load the class-index -> label mapping. The file is expected to contain a
# Python dict literal (predict() calls .get() on the result); ast.literal_eval
# parses literals only and never executes code. The encoding is explicit so
# parsing does not depend on the platform's locale default.
with open("imagenet1000_clsidx_to_labels.txt", encoding="utf-8") as f:
    class_labels = ast.literal_eval(f.read())
|
|
|
|
|
|
def predict(image):
|
|
image = transform(image).unsqueeze(0).to(device)
|
|
with torch.no_grad():
|
|
outputs = model(image)
|
|
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
|
|
top5_prob, top5_catid = torch.topk(probabilities, 5)
|
|
|
|
results = []
|
|
for i in range(top5_prob.size(0)):
|
|
class_index = top5_catid[i].item()
|
|
class_label = class_labels.get(class_index, "Unknown")
|
|
prob = top5_prob[i].item() * 100
|
|
results.append(f"{class_label}: {prob:.2f}%")
|
|
|
|
return "\n".join(results)
|
|
|
|
|
|
# Web UI: one image upload (delivered to predict() as a PIL image) and a
# plain-text output showing the formatted predictions.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ResNet 50 Image Classifier",
)

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse predict()) does not launch the web app as a side effect.
if __name__ == "__main__":
    iface.launch()