Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import PIL
|
|
6 |
from PIL import Image
|
7 |
|
8 |
# Cargamos el learner
|
9 |
-
|
10 |
|
11 |
# Definimos una funci贸n que se encarga de llevar a cabo las predicciones
|
12 |
def predict(img):
|
@@ -14,7 +14,7 @@ def predict(img):
|
|
14 |
size = 384
|
15 |
class_map = ClassMap(['kangaroo'])
|
16 |
infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size),tfms.A.Normalize()])
|
17 |
-
pred_dict = models.torchvision.faster_rcnn.end2end_detect(img, infer_tfms,
|
18 |
#return pred_dict['img']
|
19 |
return img
|
20 |
|
|
|
6 |
from PIL import Image
|
7 |
|
8 |
# Cargamos el learner
|
9 |
+
learner = torch.load('fasterRCNNKangaroo_obligatorio.pth',map_location='cpu')
|
10 |
|
11 |
# Definimos una funci贸n que se encarga de llevar a cabo las predicciones
|
12 |
def predict(img):
|
|
|
14 |
size = 384
|
15 |
class_map = ClassMap(['kangaroo'])
|
16 |
infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size),tfms.A.Normalize()])
|
17 |
+
pred_dict = models.torchvision.faster_rcnn.end2end_detect(img, infer_tfms, learner.to("cpu"), class_map=class_map, detection_threshold=0.5)
|
18 |
#return pred_dict['img']
|
19 |
return img
|
20 |
|