CarVision / model.py
=
initial commit
cb5412b
raw
history blame contribute delete
559 Bytes
import torch
import torchvision
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torch import nn
def create_effnet_v2_model(weights_path, num_classes: int = 3):
    """Build an EfficientNetV2-S classifier and load saved weights into it.

    The backbone is constructed without pretrained weights; every parameter
    comes from the state dict stored at ``weights_path``. The preprocessing
    transform is taken from ``EfficientNet_V2_S_Weights.DEFAULT`` so inputs
    are prepared the same way as for the reference torchvision weights.

    Args:
        weights_path: Path (or file-like object) accepted by ``torch.load``,
            pointing at a saved ``state_dict`` for this architecture with a
            ``(Dropout, Linear)`` classification head.
        num_classes: Output size of the replacement classification head.
            Must match the head stored in the checkpoint.

    Returns:
        ``(model, transforms)``: the model with weights loaded onto the CPU,
        and the preprocessing transform from ``weights.transforms()``.
        The model is not switched to eval mode; call ``model.eval()`` before
        running inference.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the model (e.g. wrong ``num_classes``).
        pickle.UnpicklingError: from ``torch.load`` if the checkpoint holds
            objects outside the ``weights_only`` allow-list.
    """
    weights = EfficientNet_V2_S_Weights.DEFAULT
    transforms = weights.transforms()

    model = efficientnet_v2_s()
    # Take the head's input width from the stock classifier rather than
    # hard-coding 1280, so the new head always matches the backbone output.
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.0),  # keeps the (Dropout, Linear) layout the checkpoint expects
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model, transforms