Spaces:
Sleeping
Sleeping
File size: 699 Bytes
6692a2b 4f12f52 6692a2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
from torch import nn
def create_effnet_b2(num_classes: int = 3, seed: int = 42):
    """Build a pretrained EfficientNet-B2 with a frozen backbone and a new head.

    Loads the default EfficientNet-B2 weights and freezes every pretrained
    parameter. It then replaces the classifier with a fresh
    ``Dropout -> Linear`` head. The head is created after the freeze, so only
    the head's parameters are trainable.

    Args:
        num_classes: Number of outputs of the new linear head. Must be >= 1.
        seed: Passed to ``torch.manual_seed`` immediately before the new head
            is built, so the head's initial weights are reproducible. Note
            that this reseeds torch's global RNG as a side effect.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 model and
        the preprocessing transforms bundled with the pretrained weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    eff_weights = EfficientNet_B2_Weights.DEFAULT
    efficientnet_transform = eff_weights.transforms()

    # `weights` is keyword-only in torchvision >= 0.13; passing it
    # positionally only works through a deprecated compatibility shim.
    effnet_model = efficientnet_b2(weights=eff_weights)

    # Freeze all pretrained parameters; the replacement head created below
    # gets fresh, trainable parameters.
    for param in effnet_model.parameters():
        param.requires_grad = False

    # Take the feature width from the stock head's final Linear layer instead
    # of hard-coding it (1408 for B2), so it cannot drift from the backbone.
    in_features = effnet_model.classifier[-1].in_features

    # Seed right before building the head so its random init is reproducible.
    torch.manual_seed(seed=seed)
    effnet_model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return effnet_model, efficientnet_transform
|