Spaces:
Sleeping
Sleeping
import timm | |
import torch.nn as nn | |
class ExampleModel(nn.Module): | |
def __init__(self, num_classes): | |
super(ExampleModel, self).__init__() | |
self.base_model = timm.create_model("efficientnet_b0", pretrained=True) | |
self.features = nn.Sequential(*list(self.base_model.children())[:-1]) | |
network_out_size = 1280 | |
self.classifier = nn.Sequential( | |
nn.Flatten(), nn.Linear(network_out_size, num_classes) | |
) | |
def forward(self, x): | |
x = self.features(x) | |
output = self.classifier(x) | |
return output | |