File size: 2,293 Bytes
69af19f e0af351 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f e0af351 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f 51488de e0af351 69af19f e0af351 69af19f e0af351 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
#!/usr/bin/env python
import pathlib
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import visualize_object_predictions
from transformers import AutoImageProcessor, DetaForObjectDetection
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"

# Hugging Face Hub checkpoint shared by the preprocessor and the detector.
MODEL_ID = "jozhang97/deta-swin-large"

# Run on the first CUDA device when one is present, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.to moves the parameters and returns the same module object.
model = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)
@spaces.GPU
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image with DETA and return an annotated copy.

    Args:
        image_path: Path of the image to analyse (the input component is
            declared with ``type="filepath"``).
        threshold: Minimum detection score to keep; forwarded to
            ``image_processor.post_process_object_detection``.

    Returns:
        The RGB image as an array with the kept detections drawn on it by
        ``sahi.utils.cv.visualize_object_predictions``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image_path is None:
        # Pressing the button with an empty image component passes None;
        # report it in the UI instead of failing inside PIL.
        raise gr.Error("Please upload an image first.")

    # Load eagerly and release the file handle. Converting to RGB keeps RGBA,
    # palette and grayscale files from reaching the processor and the drawing
    # code with a channel layout other than RGB.
    with PIL.Image.open(image_path) as src:
        image = src.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL reports (width, height); the post-processor is given (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # Index 0: a single image was processed.
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    preds = [
        ObjectPrediction(
            # Round box corners to whole pixels and pass plain Python ints.
            bbox=np.round(box).astype(int).tolist(),
            category_id=cat_id,
            category_name=model.config.id2label[cat_id],
            score=float(score),
        )
        for box, score, cat_id in zip(boxes, scores, cat_ids)
    ]
    return visualize_object_predictions(np.asarray(image), preds)["image"]
# Layout: input image, threshold slider and run button in the left column;
# the annotated result image beside them.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath", label="Input image")
            threshold = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.1,
                label="Score threshold",
            )
            run_button = gr.Button()
        result = gr.Image(label="Result")

    # One example row per .jpg under ./images (relative to the working
    # directory), each paired with a 0.1 score threshold.
    example_images = sorted(pathlib.Path("images").glob("*.jpg"))
    gr.Examples(
        examples=[[img_path, 0.1] for img_path in example_images],
        inputs=[image, threshold],
        outputs=result,
        fn=run,
    )

    # The button handler is also exposed under the API name "predict".
    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=result,
        api_name="predict",
    )
if __name__ == "__main__":
    # Enable request queuing with at most 20 pending events, then start the app.
    demo.queue(max_size=20)
    demo.launch()
|