Juliofc commited on
Commit
543f089
1 Parent(s): 021f0c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -11,14 +11,18 @@ def load_model():
11
  num_ftrs = model.classifier[6].in_features
12
  model.classifier[6] = nn.Linear(num_ftrs, len(class_names)) # Adjust this for your number of classes
13
 
14
- # Load the model weights
15
- model.load_state_dict(torch.load('model_alexnet.pth'))
16
- model.eval() # Set to evaluation mode
17
-
18
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
19
  model.to(device)
20
  return model, device
21
 
 
22
  model_alexnet, device = load_model()
23
 
24
  # Image transformations
 
11
  num_ftrs = model.classifier[6].in_features
12
  model.classifier[6] = nn.Linear(num_ftrs, len(class_names)) # Adjust this for your number of classes
13
 
14
+ # Correctly map the model to CPU if CUDA is not available
 
 
 
15
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+ if device == torch.device('cpu'):
17
+ model.load_state_dict(torch.load('model_alexnet.pth', map_location=device))
18
+ else:
19
+ model.load_state_dict(torch.load('model_alexnet.pth'))
20
+
21
+ model.eval() # Set to evaluation mode
22
  model.to(device)
23
  return model, device
24
 
25
+
26
  model_alexnet, device = load_model()
27
 
28
  # Image transformations