Hector Lopez commited on
Commit
cd4c90e
·
1 Parent(s): f08398e

First version

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. model.py +69 -0
  3. requirements.txt +2 -0
  4. web_app.py +41 -0
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from icevision import *
2
+ import collections
3
+ import cv2
4
+ import PIL
5
+ import torch
6
+ import numpy as np
7
+ import torchvision
8
+
9
+ MODEL_TYPE = models.ross.efficientdet
10
+
11
+ def get_model(checkpoint_path):
12
+ extra_args = {}
13
+ backbone = MODEL_TYPE.backbones.d0
14
+ # The efficientdet model requires an img_size parameter
15
+ extra_args['img_size'] = 512
16
+
17
+ model = MODEL_TYPE.model(backbone=backbone(pretrained=True),
18
+ num_classes=2,
19
+ **extra_args)
20
+
21
+ ckpt = get_checkpoint(checkpoint_path)
22
+ model.load_state_dict(ckpt)
23
+
24
+ return model
25
+
26
+ def get_checkpoint(checkpoint_path):
27
+ ckpt = torch.load('checkpoint.ckpt')
28
+
29
+ fixed_state_dict = collections.OrderedDict()
30
+
31
+ for k, v in ckpt['state_dict'].items():
32
+ new_k = k[6:]
33
+ fixed_state_dict[new_k] = v
34
+
35
+ return fixed_state_dict
36
+
37
+ def predict(model, image):
38
+ img = PIL.Image.open(image)
39
+ class_map = ClassMap(classes=['Waste'])
40
+ transforms = tfms.A.Adapter([
41
+ *tfms.A.resize_and_pad(512),
42
+ tfms.A.Normalize()
43
+ ])
44
+
45
+ pred_dict = MODEL_TYPE.end2end_detect(img,
46
+ transforms,
47
+ model,
48
+ class_map=class_map,
49
+ detection_threshold=0.5,
50
+ return_as_pil_img=False,
51
+ return_img=True,
52
+ display_bbox=False,
53
+ display_score=False,
54
+ display_label=False)
55
+
56
+ return pred_dict
57
+
58
+ def prepare_prediction(pred_dict):
59
+ boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
60
+ boxes = torch.stack(boxes)
61
+
62
+ scores = torch.as_tensor(pred_dict['detection']['scores'])
63
+ labels = torch.as_tensor(pred_dict['detection']['label_ids'])
64
+ image = np.array(pred_dict['img'])
65
+
66
+ fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, 0.1)
67
+ boxes = boxes[fixed_boxes, :]
68
+
69
+ return boxes, image
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ icevision[full]
2
+ matplotlib
web_app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import cv2
5
+ import PIL
6
+
7
+ from model import get_model, predict, prepare_prediction
8
+
9
+ print('Creating the model')
10
+ model = get_model('checkpoint.ckpt')
11
+
12
+ def plot_img_no_mask(image, boxes):
13
+ # Show image
14
+ boxes = boxes.cpu().detach().numpy().astype(np.int32)
15
+ fig, ax = plt.subplots(1, 1, figsize=(12, 6))
16
+
17
+ for i, box in enumerate(boxes):
18
+ [x1, y1, x2, y2] = np.array(box).astype(int)
19
+ # Si no se hace la copia da error en cv2.rectangle
20
+ image = np.array(image).copy()
21
+
22
+ pt1 = (x1, y1)
23
+ pt2 = (x2, y2)
24
+ cv2.rectangle(image, pt1, pt2, (220,0,0), thickness=5)
25
+ plt.axis('off')
26
+ ax.imshow(image)
27
+ fig.savefig("img.png", bbox_inches='tight')
28
+
29
+ image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
30
+
31
+ if image_file is not None:
32
+ print(image_file)
33
+ print('Getting predictions')
34
+ pred_dict = predict(model, image_file)
35
+ print('Fixing the preds')
36
+ boxes, image = prepare_prediction(pred_dict)
37
+ print('Plotting')
38
+ plot_img_no_mask(image, boxes)
39
+
40
+ img = PIL.Image.open('img.png')
41
+ st.image(img,width=750)