Hector Lopez commited on
Commit
bc3d4e9
·
1 Parent(s): f890c24

Properly load of the detection model

Browse files
Files changed (2) hide show
  1. app.py +0 -2
  2. model.py +16 -19
app.py CHANGED
@@ -9,8 +9,6 @@ from model import get_model, predict, prepare_prediction, predict_class
9
  DET_CKPT = 'efficientDet_icevision.ckpt'
10
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
11
 
12
-
13
-
14
  st.subheader('Upload Custom Image')
15
 
16
  image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
 
9
  DET_CKPT = 'efficientDet_icevision.ckpt'
10
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
11
 
 
 
12
  st.subheader('Upload Custom Image')
13
 
14
  image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
model.py CHANGED
@@ -1,30 +1,28 @@
1
  from io import BytesIO
2
  from typing import Union
3
  from icevision import *
 
 
 
 
4
  import collections
5
  import PIL
6
  import torch
7
  import numpy as np
8
  import torchvision
9
 
10
- from classifier import transform_image
11
-
12
- import icevision.models.ross.efficientdet
13
-
14
- MODEL_TYPE = icevision.models.ross.efficientdet
15
 
16
  def get_model(checkpoint_path : str):
17
- extra_args = {}
18
- backbone = MODEL_TYPE.backbones.d0
19
- # The efficientdet model requires an img_size parameter
20
- extra_args['img_size'] = 512
21
-
22
- model = MODEL_TYPE.model(backbone=backbone(pretrained=True),
23
- num_classes=2,
24
- **extra_args)
25
-
26
- ckpt = get_checkpoint(checkpoint_path)
27
- model.load_state_dict(ckpt)
28
 
29
  return model
30
 
@@ -76,7 +74,7 @@ def prepare_prediction(pred_dict, threshold):
76
 
77
  return boxes, image
78
 
79
- def predict_class(model, image, bboxes):
80
  preds = []
81
 
82
  for bbox in bboxes:
@@ -84,13 +82,12 @@ def predict_class(model, image, bboxes):
84
  bbox = np.array(bbox).astype(int)
85
  cropped_img = PIL.Image.fromarray(img).crop(bbox)
86
  cropped_img = np.array(cropped_img)
87
- #cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
88
 
89
  tran_image = transform_image(cropped_img, 224)
90
  tran_image = tran_image.transpose(2, 0, 1)
91
  tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
92
  print(tran_image.shape)
93
- y_preds = model(tran_image)
94
  preds.append(y_preds.softmax(1).detach().numpy())
95
 
96
  preds = np.concatenate(preds).argmax(1)
 
1
  from io import BytesIO
2
  from typing import Union
3
  from icevision import *
4
+ from icevision.models.checkpoint import model_from_checkpoint
5
+ from classifier import transform_image
6
+ from icevision.models import ross
7
+
8
  import collections
9
  import PIL
10
  import torch
11
  import numpy as np
12
  import torchvision
13
 
14
+ MODEL_TYPE = ross.efficientdet
 
 
 
 
15
 
16
  def get_model(checkpoint_path : str):
17
+ checkpoint_and_model = model_from_checkpoint(
18
+ checkpoint_path,
19
+ model_name='ross.efficientdet',
20
+ backbone_name='d0',
21
+ img_size=512,
22
+ classes=['Waste'],
23
+ revise_keys=[(r'^model\.', '')])
24
+
25
+ model = checkpoint_and_model['model']
 
 
26
 
27
  return model
28
 
 
74
 
75
  return boxes, image
76
 
77
+ def predict_class(classifier, image, bboxes):
78
  preds = []
79
 
80
  for bbox in bboxes:
 
82
  bbox = np.array(bbox).astype(int)
83
  cropped_img = PIL.Image.fromarray(img).crop(bbox)
84
  cropped_img = np.array(cropped_img)
 
85
 
86
  tran_image = transform_image(cropped_img, 224)
87
  tran_image = tran_image.transpose(2, 0, 1)
88
  tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
89
  print(tran_image.shape)
90
+ y_preds = classifier(tran_image)
91
  preds.append(y_preds.softmax(1).detach().numpy())
92
 
93
  preds = np.concatenate(preds).argmax(1)