File size: 2,293 Bytes
69af19f
 
 
 
 
 
 
e0af351
69af19f
 
 
 
 
51488de
69af19f
51488de
69af19f
51488de
69af19f
 
 
 
 
e0af351
69af19f
 
 
51488de
69af19f
 
51488de
69af19f
51488de
 
 
69af19f
 
 
 
 
51488de
69af19f
 
51488de
69af19f
 
 
51488de
69af19f
 
 
51488de
e0af351
 
 
 
 
 
 
 
 
69af19f
e0af351
 
 
 
 
 
69af19f
e0af351
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python

import pathlib

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import visualize_object_predictions
from transformers import AutoImageProcessor, DetaForObjectDetection

# Markdown title rendered at the top of the Gradio app.
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"

# Prefer the first CUDA device; fall back to CPU when no GPU is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# DETA with a Swin-Large backbone. The processor handles image
# resizing/normalization; the model weights are moved to `device` once at startup.
MODEL_ID = "jozhang97/deta-swin-large"
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetaForObjectDetection.from_pretrained(MODEL_ID)
model.to(device)


@spaces.GPU
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Run DETA object detection on one image and return an annotated copy.

    Args:
        image_path: Filesystem path to the input image.
        threshold: Minimum confidence score for a detection to be kept.

    Returns:
        An RGB ``np.ndarray`` of the input image with predicted boxes,
        category labels, and scores drawn on it.
    """
    # Force 3-channel RGB: grayscale/RGBA/palette inputs would otherwise
    # produce arrays with the wrong channel count for the processor's
    # normalization and for sahi's visualization.
    image = PIL.Image.open(image_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # Post-processing expects (height, width); PIL's .size is (width, height).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(outputs, threshold=threshold, target_sizes=target_sizes)[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    preds = []
    for box, score, cat_id in zip(boxes, scores, cat_ids):
        # sahi expects integer [x1, y1, x2, y2] pixel coordinates.
        box = np.round(box).astype(int)
        cat_label = model.config.id2label[cat_id]
        pred = ObjectPrediction(bbox=box, category_id=cat_id, category_name=cat_label, score=score)
        preds.append(pred)

    res = visualize_object_predictions(np.asarray(image), preds)["image"]
    return res


# Build the Gradio UI: an input column (image + score-threshold slider) on
# the left, the annotated result image on the right.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.01, value=0.1)
            run_button = gr.Button()
        result = gr.Image(label="Result")
    # Clickable example gallery: every *.jpg under ./images, each paired with
    # the default 0.1 threshold. `fn=run` lets Gradio compute/cache outputs.
    gr.Examples(
        examples=[[path, 0.1] for path in sorted(pathlib.Path("images").glob("*.jpg"))],
        inputs=[image, threshold],
        outputs=result,
        fn=run,
    )

    # Wire the button to inference; also exposed as the "predict" API endpoint.
    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=result,
        api_name="predict",
    )

# Cap the request queue at 20 pending jobs before new submissions are refused.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()