Spaces:
Runtime error
Runtime error
Hector Lopez
commited on
Commit
·
bc3d4e9
1
Parent(s):
f890c24
Properly load of the detection model
Browse files
app.py
CHANGED
@@ -9,8 +9,6 @@ from model import get_model, predict, prepare_prediction, predict_class
|
|
9 |
DET_CKPT = 'efficientDet_icevision.ckpt'
|
10 |
CLASS_CKPT = 'class_ViT_taco_7_class.pth'
|
11 |
|
12 |
-
|
13 |
-
|
14 |
st.subheader('Upload Custom Image')
|
15 |
|
16 |
image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
|
|
|
9 |
DET_CKPT = 'efficientDet_icevision.ckpt'
|
10 |
CLASS_CKPT = 'class_ViT_taco_7_class.pth'
|
11 |
|
|
|
|
|
12 |
st.subheader('Upload Custom Image')
|
13 |
|
14 |
image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
|
model.py
CHANGED
@@ -1,30 +1,28 @@
|
|
1 |
from io import BytesIO
|
2 |
from typing import Union
|
3 |
from icevision import *
|
|
|
|
|
|
|
|
|
4 |
import collections
|
5 |
import PIL
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
import torchvision
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
import icevision.models.ross.efficientdet
|
13 |
-
|
14 |
-
MODEL_TYPE = icevision.models.ross.efficientdet
|
15 |
|
16 |
def get_model(checkpoint_path : str):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
ckpt = get_checkpoint(checkpoint_path)
|
27 |
-
model.load_state_dict(ckpt)
|
28 |
|
29 |
return model
|
30 |
|
@@ -76,7 +74,7 @@ def prepare_prediction(pred_dict, threshold):
|
|
76 |
|
77 |
return boxes, image
|
78 |
|
79 |
-
def predict_class(
|
80 |
preds = []
|
81 |
|
82 |
for bbox in bboxes:
|
@@ -84,13 +82,12 @@ def predict_class(model, image, bboxes):
|
|
84 |
bbox = np.array(bbox).astype(int)
|
85 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
86 |
cropped_img = np.array(cropped_img)
|
87 |
-
#cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
|
88 |
|
89 |
tran_image = transform_image(cropped_img, 224)
|
90 |
tran_image = tran_image.transpose(2, 0, 1)
|
91 |
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
|
92 |
print(tran_image.shape)
|
93 |
-
y_preds =
|
94 |
preds.append(y_preds.softmax(1).detach().numpy())
|
95 |
|
96 |
preds = np.concatenate(preds).argmax(1)
|
|
|
1 |
from io import BytesIO
|
2 |
from typing import Union
|
3 |
from icevision import *
|
4 |
+
from icevision.models.checkpoint import model_from_checkpoint
|
5 |
+
from classifier import transform_image
|
6 |
+
from icevision.models import ross
|
7 |
+
|
8 |
import collections
|
9 |
import PIL
|
10 |
import torch
|
11 |
import numpy as np
|
12 |
import torchvision
|
13 |
|
14 |
+
MODEL_TYPE = ross.efficientdet
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def get_model(checkpoint_path : str):
|
17 |
+
checkpoint_and_model = model_from_checkpoint(
|
18 |
+
checkpoint_path,
|
19 |
+
model_name='ross.efficientdet',
|
20 |
+
backbone_name='d0',
|
21 |
+
img_size=512,
|
22 |
+
classes=['Waste'],
|
23 |
+
revise_keys=[(r'^model\.', '')])
|
24 |
+
|
25 |
+
model = checkpoint_and_model['model']
|
|
|
|
|
26 |
|
27 |
return model
|
28 |
|
|
|
74 |
|
75 |
return boxes, image
|
76 |
|
77 |
+
def predict_class(classifier, image, bboxes):
|
78 |
preds = []
|
79 |
|
80 |
for bbox in bboxes:
|
|
|
82 |
bbox = np.array(bbox).astype(int)
|
83 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
84 |
cropped_img = np.array(cropped_img)
|
|
|
85 |
|
86 |
tran_image = transform_image(cropped_img, 224)
|
87 |
tran_image = tran_image.transpose(2, 0, 1)
|
88 |
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
|
89 |
print(tran_image.shape)
|
90 |
+
y_preds = classifier(tran_image)
|
91 |
preds.append(y_preds.softmax(1).detach().numpy())
|
92 |
|
93 |
preds = np.concatenate(preds).argmax(1)
|