biplab2008 commited on
Commit
cb606d8
·
verified ·
1 Parent(s): 62300a8

Update cnn3d_model.py

Browse files
Files changed (1) hide show
  1. cnn3d_model.py +2 -2
cnn3d_model.py CHANGED
@@ -300,13 +300,13 @@ def load_model():
300
  args = NetData(cnndata, lindata)
301
 
302
  # weight file
303
- #weight_file = 'cnn3d_epoch_300.pt'
304
 
305
  # CNN3D model
306
  # device = 'cuda' if torch.cuda.is_available() else 'cpu'
307
  device = torch.device('cpu')
308
  cnn3d = CNN3D(args).to(device)
309
- #cnn3d.load_state_dict(torch.load(os.path.join(base_path,'weights',weight_file), map_location=device))
310
  cnn3d.eval()
311
  #print(cnn3d)
312
 
 
300
  args = NetData(cnndata, lindata)
301
 
302
  # weight file
303
+ weight_file = 'cnn3d_epoch_300.pt'
304
 
305
  # CNN3D model
306
  # device = 'cuda' if torch.cuda.is_available() else 'cpu'
307
  device = torch.device('cpu')
308
  cnn3d = CNN3D(args).to(device)
309
+ cnn3d.load_state_dict(torch.load(weight_file, map_location=device))
310
  cnn3d.eval()
311
  #print(cnn3d)
312