hasibzunair commited on
Commit
f02cbe5
1 Parent(s): 0b85db2
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -45,7 +45,8 @@ model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
45
  normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
46
  model.to(DEVICE)
47
  print("Loading weights from {}".format("./models/msl_c_voc.pth"))
48
- model.load_state_dict(torch.load("./models/msl_c_voc.pth"), map_location=torch.device('cpu'))
 
49
 
50
  # Inference!
51
  def inference(img_path):
@@ -64,7 +65,6 @@ def inference(img_path):
64
 
65
  # Predict
66
  result = []
67
- model.eval()
68
  with torch.no_grad():
69
  image = image.to(DEVICE)
70
  logit = model(image).squeeze(0)
 
45
  normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
46
  model.to(DEVICE)
47
  print("Loading weights from {}".format("./models/msl_c_voc.pth"))
48
+ model.load_state_dict(torch.load("./models/msl_c_voc.pth", map_location=torch.device("cpu")))
49
+ model.eval()
50
 
51
  # Inference!
52
  def inference(img_path):
 
65
 
66
  # Predict
67
  result = []
 
68
  with torch.no_grad():
69
  image = image.to(DEVICE)
70
  logit = model(image).squeeze(0)