import torch
import torchvision
from torch import nn
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights


def create_effnet_v2_model(weights_path, num_classes=3):
    """Build an EfficientNetV2-S classifier and load saved weights into it.

    The torchvision backbone is constructed without pretrained weights. Its
    classification head is replaced by a ``Dropout`` + ``Linear`` head with
    ``num_classes`` outputs. The full ``state_dict`` stored at
    ``weights_path`` is then loaded (mapped onto the CPU).

    Args:
        weights_path: Path or file-like object accepted by ``torch.load``.
            It must contain a ``state_dict`` for this exact architecture,
            including the replaced head.
        num_classes: Number of output logits of the new classification head.
            Must match the head shape stored in the checkpoint.

    Returns:
        A ``(model, transforms)`` tuple:
            * ``model`` is the network with weights loaded. This function does
              not call ``.eval()``, so callers running inference should do so.
            * ``transforms`` are the preprocessing transforms bundled with
              ``EfficientNet_V2_S_Weights.DEFAULT``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the model (e.g. a wrong ``num_classes``).
    """
    weights = EfficientNet_V2_S_Weights.DEFAULT
    transforms = weights.transforms()

    # Build the bare architecture. All parameters come from weights_path below,
    # so the ImageNet checkpoint is not downloaded.
    model = efficientnet_v2_s()

    # Derive the feature width from the stock head instead of hard-coding it,
    # so the new head always fits the backbone it is attached to.
    in_features = model.classifier[-1].in_features

    # Dropout(0.0) is an identity op, but it keeps the Linear layer at index 1.
    # Its parameters are therefore named "classifier.1.*", presumably the
    # layout the checkpoint was saved with.
    model.classifier = nn.Sequential(
        nn.Dropout(0.0),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    # torch.load uses pickle. weights_only=True restricts it to tensors and
    # plain containers, which is all a state_dict needs, so a tampered file
    # cannot execute arbitrary code during loading.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model, transforms