Spaces:
Sleeping
Sleeping
Update cnn3d_model.py
Browse files- cnn3d_model.py +2 -2
cnn3d_model.py
CHANGED
@@ -300,13 +300,13 @@ def load_model():
|
|
300 |
args = NetData(cnndata, lindata)
|
301 |
|
302 |
# weight file
|
303 |
-
|
304 |
|
305 |
# CNN3D model
|
306 |
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
307 |
device = torch.device('cpu')
|
308 |
cnn3d = CNN3D(args).to(device)
|
309 |
-
|
310 |
cnn3d.eval()
|
311 |
#print(cnn3d)
|
312 |
|
|
|
300 |
args = NetData(cnndata, lindata)
|
301 |
|
302 |
# weight file
|
303 |
+
weight_file = 'cnn3d_epoch_300.pt'
|
304 |
|
305 |
# CNN3D model
|
306 |
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
307 |
device = torch.device('cpu')
|
308 |
cnn3d = CNN3D(args).to(device)
|
309 |
+
cnn3d.load_state_dict(torch.load(weight_file, map_location=device))
|
310 |
cnn3d.eval()
|
311 |
#print(cnn3d)
|
312 |
|