Hector Lopez committed · Commit cd4c90e · 1 Parent(s): f08398e

First version

Files changed:
- .gitattributes +1 -0
- model.py +69 -0
- requirements.txt +2 -0
- web_app.py +41 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
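The added pattern routes *.ckpt files through Git LFS, so the model checkpoint the app expects (checkpoint.ckpt) can be stored as an LFS object rather than directly in the Git history.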
model.py
ADDED
@@ -0,0 +1,69 @@
from icevision import *
import collections
import cv2
import PIL
import torch
import numpy as np
import torchvision

MODEL_TYPE = models.ross.efficientdet

def get_model(checkpoint_path):
    extra_args = {}
    backbone = MODEL_TYPE.backbones.d0
    # The efficientdet model requires an img_size parameter
    extra_args['img_size'] = 512

    model = MODEL_TYPE.model(backbone=backbone(pretrained=True),
                             num_classes=2,
                             **extra_args)

    ckpt = get_checkpoint(checkpoint_path)
    model.load_state_dict(ckpt)

    return model

def get_checkpoint(checkpoint_path):
    ckpt = torch.load(checkpoint_path)

    # The training checkpoint prefixes every weight key (the first 6
    # characters, e.g. 'model.'), so strip it before loading the bare model.
    fixed_state_dict = collections.OrderedDict()

    for k, v in ckpt['state_dict'].items():
        new_k = k[6:]
        fixed_state_dict[new_k] = v

    return fixed_state_dict

def predict(model, image):
    img = PIL.Image.open(image)
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize()
    ])

    pred_dict = MODEL_TYPE.end2end_detect(img,
                                          transforms,
                                          model,
                                          class_map=class_map,
                                          detection_threshold=0.5,
                                          return_as_pil_img=False,
                                          return_img=True,
                                          display_bbox=False,
                                          display_score=False,
                                          display_label=False)

    return pred_dict

def prepare_prediction(pred_dict):
    boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
    boxes = torch.stack(boxes)

    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])
    image = np.array(pred_dict['img'])

    # Apply per-label NMS (IoU threshold 0.1) and keep only the surviving boxes.
    fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, 0.1)
    boxes = boxes[fixed_boxes, :]

    return boxes, image
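Taken together, these helpers form the full inference path: build the EfficientDet model, load the cleaned-up checkpoint, run end-to-end detection, and post-process with NMS. A minimal sketch of how they might be exercised outside the web app, assuming a local checkpoint.ckpt and a hypothetical test image sample.jpg:

# Minimal local smoke test for the helpers above (file paths are assumptions).
from model import get_model, predict, prepare_prediction

model = get_model('checkpoint.ckpt')
pred_dict = predict(model, 'sample.jpg')     # PIL.Image.open accepts a path or a file-like object
boxes, image = prepare_prediction(pred_dict)
print(boxes.shape, image.shape)              # (num_kept_boxes, 4) and an H x W x 3 array

Note that prepare_prediction calls torch.stack on the list of boxes, so this sketch assumes at least one detection above the 0.5 threshold.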
requirements.txt
ADDED
@@ -0,0 +1,2 @@
icevision[full]
matplotlib
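icevision[full] is expected to pull in torch, torchvision, and the Albumentations transforms used in model.py, so those are not pinned separately; streamlit is not listed, presumably because the Space's Streamlit SDK provides it.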
web_app.py
ADDED
@@ -0,0 +1,41 @@
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import cv2
import PIL

from model import get_model, predict, prepare_prediction

print('Creating the model')
model = get_model('checkpoint.ckpt')

def plot_img_no_mask(image, boxes):
    # Show image
    boxes = boxes.cpu().detach().numpy().astype(np.int32)
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    # Without the copy, cv2.rectangle raises an error
    # (one copy is enough for all the boxes).
    image = np.array(image).copy()

    for i, box in enumerate(boxes):
        [x1, y1, x2, y2] = np.array(box).astype(int)

        pt1 = (x1, y1)
        pt2 = (x2, y2)
        cv2.rectangle(image, pt1, pt2, (220, 0, 0), thickness=5)

    plt.axis('off')
    ax.imshow(image)
    fig.savefig("img.png", bbox_inches='tight')

image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"])

if image_file is not None:
    print(image_file)
    print('Getting predictions')
    pred_dict = predict(model, image_file)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict)
    print('Plotting')
    plot_img_no_mask(image, boxes)

    img = PIL.Image.open('img.png')
    st.image(img, width=750)
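Locally, the app would typically be launched with streamlit run web_app.py after installing the requirements. Note that get_model('checkpoint.ckpt') runs at import time and the checkpoint itself is not part of this commit, so checkpoint.ckpt has to be uploaded alongside these scripts for the app to start.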