Upadated Inference.py

#1
by shlok123 - opened
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -97,8 +97,8 @@ def load_model(model_class, file_path):
97
  model = model_class(net_G=net_G)
98
  model.load_state_dict(torch.load(file_path, map_location=device))
99
 
100
- resnet_weights = torch.load(file_path)
101
- resnet_weights = torch.load("./model/res18-unet.pt")
102
  resnet_state_dict = resnet_weights['state_dict'] if 'state_dict' in resnet_weights else resnet_weights
103
 
104
  model_dict = model.state_dict()
 
97
  model = model_class(net_G=net_G)
98
  model.load_state_dict(torch.load(file_path, map_location=device))
99
 
100
+ resnet_weights = torch.load(file_path, map_location=device)
101
+ resnet_weights = torch.load("./model/res18-unet.pt", map_location=device)
102
  resnet_state_dict = resnet_weights['state_dict'] if 'state_dict' in resnet_weights else resnet_weights
103
 
104
  model_dict = model.state_dict()