hasibzunair commited on
Commit
0b85db2
1 Parent(s): 7c33f91
Files changed (2) hide show
  1. app.py +1 -1
  2. pipeline/timm_utils/tuple.py +2 -2
app.py CHANGED
@@ -45,7 +45,7 @@ model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
45
  normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
46
  model.to(DEVICE)
47
  print("Loading weights from {}".format("./models/msl_c_voc.pth"))
48
- model.load_state_dict(torch.load("./models/msl_c_voc.pth"))
49
 
50
  # Inference!
51
  def inference(img_path):
 
45
  normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
46
  model.to(DEVICE)
47
  print("Loading weights from {}".format("./models/msl_c_voc.pth"))
48
+ model.load_state_dict(torch.load("./models/msl_c_voc.pth"), map_location=torch.device('cpu'))
49
 
50
  # Inference!
51
  def inference(img_path):
pipeline/timm_utils/tuple.py CHANGED
@@ -3,8 +3,8 @@
3
  Hacked together by / Copyright 2020 Ross Wightman
4
  """
5
  from itertools import repeat
6
- import collections.abc as container_abcs
7
- #from torch._six import container_abcs
8
 
9
  # From PyTorch internals
10
  def _ntuple(n):
 
3
  Hacked together by / Copyright 2020 Ross Wightman
4
  """
5
  from itertools import repeat
6
+ #import collections.abc as container_abcs
7
+ from torch._six import container_abcs
8
 
9
  # From PyTorch internals
10
  def _ntuple(n):