Rub茅n Escobedo commited on
Commit
61ad881
1 Parent(s): 98f448a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -1,23 +1,19 @@
1
  from fastai.vision.all import *
2
  from icevision.all import *
3
- import PIL
4
  import gradio as gr
5
 
6
- # Cargamos el learner
7
- learn = load_learner('model.pth')
8
-
9
  # Definimos una funci贸n que se encarga de llevar a cabo las predicciones
10
  def predict(img):
11
- class_map = ['Kangaroo']
12
  model = models.torchvision.faster_rcnn.model(backbone=models.torchvision.faster_rcnn.backbones.resnet18_fpn, num_classes=len(class_map))
13
  state_dict = torch.load('model.pth')
14
  model.load_state_dict(state_dict)
15
 
16
- infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size),tfms.A.Normalize()])
17
 
18
- img = PIL.Image.open('kangarooc.jpg')
19
  pred_dict = models.torchvision.faster_rcnn.end2end_detect(img, infer_tfms, model.to("cpu"), class_map=class_map, detection_threshold=0.5)
20
  return pred_dict['img']
21
 
22
  # Creamos la interfaz y la lanzamos.
23
- gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Label(num_top_classes=3),examples=['kangarooc.jpg']).launch(share=False)
 
1
  from fastai.vision.all import *
2
  from icevision.all import *
 
3
  import gradio as gr
4
 
 
 
 
5
  # Definimos una funci贸n que se encarga de llevar a cabo las predicciones
6
  def predict(img):
7
+ class_map = ClassMap(['Kangaroo'])
8
  model = models.torchvision.faster_rcnn.model(backbone=models.torchvision.faster_rcnn.backbones.resnet18_fpn, num_classes=len(class_map))
9
  state_dict = torch.load('model.pth')
10
  model.load_state_dict(state_dict)
11
 
12
+ infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(384),tfms.A.Normalize()])
13
 
14
+ img = PILImage.create(img)
15
  pred_dict = models.torchvision.faster_rcnn.end2end_detect(img, infer_tfms, model.to("cpu"), class_map=class_map, detection_threshold=0.5)
16
  return pred_dict['img']
17
 
18
  # Creamos la interfaz y la lanzamos.
19
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Image(shape=(128, 128)),examples=['kangarooc.jpg']).launch(share=False)