File size: 559 Bytes
cb5412b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torchvision

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torch import nn


def create_effnet_v2_model(weights_path, num_classes: int = 3):
    """Build an EfficientNetV2-S classifier and load saved weights into it.

    The backbone is created with ``efficientnet_v2_s()`` (called with no
    arguments), its classifier head is replaced by a ``Dropout`` +
    ``Linear`` pair producing ``num_classes`` outputs, and the state dict
    stored at ``weights_path`` is then loaded into the whole model.

    Args:
        weights_path: Location of the saved state dict, passed straight to
            ``torch.load``.
        num_classes: Number of output units of the new classifier head.

    Returns:
        A ``(model, transforms)`` tuple: the model with the loaded state
        dict, and the preprocessing transforms obtained from
        ``EfficientNet_V2_S_Weights.DEFAULT.transforms()``.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the stored
            state dict does not match the model (e.g. a different
            ``num_classes`` than the checkpoint was saved with).
    """
    preprocessing = EfficientNet_V2_S_Weights.DEFAULT.transforms()

    model = efficientnet_v2_s()

    # Read the feature width from the head being replaced rather than
    # hard-coding it, so the new Linear layer always matches the backbone.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.0),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    # NOTE(security): torch.load unpickles the file, so only pass checkpoints
    # from a trusted source. On torch versions that support it, consider
    # weights_only=True -- TODO confirm the minimum torch version in use.
    state_dict = torch.load(f=weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    return model, preprocessing