# CarVision / app.py
# Source revision: initial commit cb5412b (969 bytes).
import torch
import torchvision
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torch import nn
from PIL import Image
from model import create_effnet_v2_model
# Output labels. predict() pairs class_names[i] with probability column i of
# the model output, so this order must match the order used at training time.
class_names = ['Honda', 'Hyundai', 'Toyota']

# Build the classifier with one output per label. create_effnet_v2_model (from
# the project's `model` module) returns the network together with the image
# transform that predict() applies to each input; it presumably loads the
# checkpoint named below — see model.create_effnet_v2_model to confirm.
effnet_v2, transforms = create_effnet_v2_model(
    num_classes=len(class_names),
    weights_path="efficient_net_s_carvision_3.pth",
)
def predict(model, image_path, device):
    """Classify a single image and return a probability for each class.

    Uses the module-level ``transforms`` (preprocessing) and ``class_names``
    (labels, indexed like the model's output columns).

    Args:
        model: ``torch.nn.Module`` classifier. Assumed to return logits of shape
            (1, >= len(class_names)) for a single-image batch — TODO confirm.
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        device: ``torch.device`` to run inference on. The model is moved to
            this device so that it matches the input tensor.

    Returns:
        dict mapping each label in ``class_names`` to its softmax probability
        as a Python ``float``.
    """
    # Put the model in eval mode BEFORE the forward pass, so dropout is disabled
    # and batch-norm uses (and does not update) its running statistics.
    model.eval()
    # No-op if already there; otherwise avoids a model/input device mismatch.
    model.to(device)

    # Convert to 3-channel RGB so grayscale / RGBA / palette images work with
    # the 3-channel preprocessing. The context manager closes the file handle.
    with Image.open(image_path) as img:
        image = transforms(img.convert("RGB")).unsqueeze(0).to(device)

    # The forward pass itself must run under inference mode so that no autograd
    # graph is recorded.
    with torch.inference_mode():
        logits = model(image)
        probs = torch.softmax(logits, dim=1)[0].cpu()

    return {name: float(probs[i]) for i, name in enumerate(class_names)}
if __name__ == "__main__":
    # Quick manual check: classify a bundled example image on the CPU. Guarded so
    # that importing this module does not trigger a prediction and print output.
    print(predict(effnet_v2, "examples/Toyota_Tacoma_2017_36_18_270_35_6_75_70_212_19_RWD_5_4_Pickup_xQa.jpg", torch.device("cpu")))
    # To run on a GPU instead:
    # print(predict(effnet_v2, "test.jpg", torch.device("cuda:0")))