# Hugging Face Space status when this file was captured: "Build error".
import gradio as gr
import PIL
# `import PIL` alone does not load the Image submodule; import it explicitly
# instead of relying on another package having imported it first.
import PIL.Image
import torch
from fastai.vision.all import models
# icevision's `models` (not fastai's) provides `models.torchvision.faster_rcnn`;
# importing it last makes it the `models` binding the rest of the file uses.
from icevision.all import ClassMap, models, tfms
# Detector classes. icevision's ClassMap presumably prepends an implicit
# "background" entry by default (TODO confirm), so len(class_map) is the class
# count the Faster R-CNN head was trained with. The name is the label text the
# detector reports for each box.
class_map = ClassMap(["raccoon"])

# Fine-tuned weights. map_location="cpu" lets a checkpoint saved from a CUDA
# session load on a CPU-only host; without it torch.load tries to restore the
# tensors onto the device they were saved from and fails when it is absent.
state_dict = torch.load("fasterRCNNRaccoons.pth", map_location="cpu")

model = models.torchvision.faster_rcnn.model(
    # Backbone entries are called to build the backbone (in recent icevision
    # releases they are config objects; assumed here since predict uses
    # end2end_detect - TODO confirm the installed version). pretrained=False
    # skips downloading ImageNet weights that load_state_dict overwrites.
    backbone=models.torchvision.faster_rcnn.backbones.resnet50_fpn(pretrained=False),
    num_classes=len(class_map),
)
model.load_state_dict(state_dict)
# Inference only: torchvision detection models expect targets and return
# losses in training mode, so switch to evaluation mode once up front.
model.eval()

# Side length (pixels) images are resized and padded to before inference;
# presumably the training resolution - TODO confirm.
size = 384
def predict(img_file):
    """Run the raccoon detector on one image and return the annotated image.

    Args:
        img_file: The image to analyse. Accepted forms:
            - a ``PIL.Image.Image`` (Gradio image input with ``type="pil"``);
            - an array exposing ``__array_interface__``, e.g. the numpy array
              a Gradio image input produces by default;
            - anything ``PIL.Image.open`` accepts (a path or binary file object).

    Returns:
        The ``"img"`` entry of icevision's ``end2end_detect`` result: the input
        image with the predicted boxes drawn on it.
    """
    # Normalise every accepted input form to an RGB PIL image. The ResNet
    # backbone expects 3 channels, so greyscale/RGBA/palette inputs are
    # converted rather than passed through.
    # PIL images also expose __array_interface__, so check for them first.
    if isinstance(img_file, PIL.Image.Image):
        img = img_file.convert("RGB")
    elif hasattr(img_file, "__array_interface__"):
        img = PIL.Image.fromarray(img_file).convert("RGB")
    else:
        # convert() loads the pixels, so the opened file can be closed here.
        with PIL.Image.open(img_file) as opened:
            img = opened.convert("RGB")

    # Resize/pad to the model's input size and normalise pixel values.
    infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
    pred_dict = models.torchvision.faster_rcnn.end2end_detect(
        img, infer_tfms, model.to("cpu"), class_map=class_map, detection_threshold=0.5
    )
    return pred_dict["img"]
# Gradio UI for the detector. `gr.inputs` / `gr.outputs` are Gradio's legacy
# component namespaces (deprecated in 3.0, removed in 4.0); pin a compatible
# gradio version or port to `gr.Image` when upgrading.
demo = gr.Interface(
    fn=predict,
    # Pass the full-resolution upload as a PIL image. predict already resizes
    # and pads to the model's input size; a fixed `shape` would first
    # crop/shrink the photo to 128x128 and throw away detail.
    inputs=gr.inputs.Image(type="pil"),
    # predict returns an image with boxes drawn on it, not class scores, so
    # it must be shown with an Image output rather than a Label.
    outputs=gr.outputs.Image(type="pil"),
    examples=["raccoon-27.jpg", "raccoon-20.jpg"],
)
demo.launch(share=False)