anfruizhu commited on
Commit
f0ad0ed
·
verified ·
1 Parent(s): 2a9ca8f

Update bgremover.py

Browse files
Files changed (1) hide show
  1. bgremover.py +15 -13
bgremover.py CHANGED
@@ -351,20 +351,22 @@ class DamageClassifier():
351
  model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
352
  model_filepath = model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
353
  model = models.resnet18(weights='IMAGENET1K_V1')
354
-
355
- #Add fully connected layer at the end with num_classes as output
356
- num_ftrs = model.fc.in_features
357
- model.fc = nn.Linear(num_ftrs, 4)
358
 
359
- if torch.cuda.is_available():
360
- model.load_state_dict(torch.load(model_filepath))
361
- model.cuda()
362
- else:
363
- model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
364
- model.eval()
365
-
366
- self.model = model
367
- self.model_name = model_name
 
 
 
 
 
 
368
 
369
  return
370
 
 
351
  model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
352
  model_filepath = model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
353
  model = models.resnet18(weights='IMAGENET1K_V1')
 
 
 
 
354
 
355
+ if not model is None:
356
+
357
+ #Add fully connected layer at the end with num_classes as output
358
+ num_ftrs = model.fc.in_features
359
+ model.fc = nn.Linear(num_ftrs, 4)
360
+
361
+ if torch.cuda.is_available():
362
+ model.load_state_dict(torch.load(model_filepath))
363
+ model.cuda()
364
+ else:
365
+ model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
366
+ model.eval()
367
+
368
+ self.model = model
369
+ self.model_name = model_name
370
 
371
  return
372