Spaces:
Sleeping
Sleeping
Update bgremover.py
Browse files- bgremover.py +15 -13
bgremover.py
CHANGED
@@ -351,20 +351,22 @@ class DamageClassifier():
|
|
351 |
model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
|
352 |
model_filepath = model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
|
353 |
model = models.resnet18(weights='IMAGENET1K_V1')
|
354 |
-
|
355 |
-
#Add fully connected layer at the end with num_classes as output
|
356 |
-
num_ftrs = model.fc.in_features
|
357 |
-
model.fc = nn.Linear(num_ftrs, 4)
|
358 |
|
359 |
-
if
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
model.
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
return
|
370 |
|
|
|
351 |
model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
|
352 |
model_filepath = model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
|
353 |
model = models.resnet18(weights='IMAGENET1K_V1')
|
|
|
|
|
|
|
|
|
354 |
|
355 |
+
if not model is None:
|
356 |
+
|
357 |
+
#Add fully connected layer at the end with num_classes as output
|
358 |
+
num_ftrs = model.fc.in_features
|
359 |
+
model.fc = nn.Linear(num_ftrs, 4)
|
360 |
+
|
361 |
+
if torch.cuda.is_available():
|
362 |
+
model.load_state_dict(torch.load(model_filepath))
|
363 |
+
model.cuda()
|
364 |
+
else:
|
365 |
+
model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
|
366 |
+
model.eval()
|
367 |
+
|
368 |
+
self.model = model
|
369 |
+
self.model_name = model_name
|
370 |
|
371 |
return
|
372 |
|