import spaces
# from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import torch
import gradio as gr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the detected device so the app also runs on CPU-only machines
owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
english_candidate_labels = ["hat", "sunglass", "hair band", "glove", "arm sleeve", "watch", "singlet", "t-shirts", "energy gel", "half pants", "socks", "shoes", "ear phone"]
korean_candidate_labels = ["모자", "선글라스", "헤어밴드", "장갑", "팔토시", "시계", "싱글렛", "티셔츠", "에너지젤", "쇼츠바지", "양말", "신발", "이어폰"]
english_candidate_labels_string = ",".join(english_candidate_labels)
# Create a dictionary mapping English labels to Korean labels
label_mapping = dict(zip(english_candidate_labels, korean_candidate_labels))
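# Illustrative lookups (assumes the two lists above stay aligned):
#   label_mapping["hat"]   -> "모자"
#   label_mapping["shoes"] -> "신발"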
@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    if model == "dino":
        # NOTE: the "dino" branch requires un-commenting the Grounding DINO
        # model/processor loads above. Grounding DINO expects a single
        # "."-separated prompt string rather than a list of queries.
        queries = ""
        for query in text_queries:
            queries += f"{query}. "
        # numpy images are (height, width, channels)
        height, width = img.shape[:2]
        target_sizes = [(height, width)]
        inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
                                                                        box_threshold=score_threshold,
                                                                        target_sizes=target_sizes)
    elif model == "owl":
        # The OWLv2 processor pads inputs to a square, so boxes are rescaled
        # against the longer image side
        size = max(img.shape[:2])
        target_sizes = torch.Tensor([[size, size]])
        inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owl_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        box = [int(i) for i in box.tolist()]
        if score < score_threshold:
            continue
        if model == "owl":
            # OWL returns label indices into the query list
            label = text_queries[label.cpu().item()]
            result_labels.append((box, label))
        elif model == "dino":
            # Grounding DINO returns the matched text span itself
            if label != "":
                result_labels.append((box, label))
    return result_labels
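# A minimal sketch of calling infer() directly, outside Gradio. "sample.jpg"
# is a hypothetical local file, not part of this Space:
#   import numpy as np
#   from PIL import Image
#   img = np.array(Image.open("sample.jpg"))
#   detections = infer(img, ["hat", "shoes"], 0.16, "owl")
#   # detections is a list of ([x1, y1, x2, y2], label) pairs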
# def query_image(img, text_queries, owl_threshold, dino_threshold):
def query_image(img, text_queries, owl_threshold, flag_output_korean):
    text_queries = [q.strip() for q in text_queries.split(",")]
    owl_output = infer(img, text_queries, owl_threshold, "owl")
    # dino_output = infer(img, text_queries, dino_threshold, "dino")
    # Optionally map English labels to Korean for display
    owl_output_final = []
    if flag_output_korean:
        for box, label in owl_output:
            # Fall back to the English label for queries outside the preset list
            kor_label = label_mapping.get(label, label)
            owl_output_final.append((box, kor_label))
    else:
        owl_output_final = owl_output
    # return (img, owl_output), (img, dino_output)
    return (img, owl_output_final)
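# gr.AnnotatedImage renders a (image, annotations) tuple, where each annotation
# is a ((x1, y1, x2, y2), label) pair, i.e. the (box, label) tuples built above.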
owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
# dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
# dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    # inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Candidate Labels", value=english_candidate_labels_string),
        owl_threshold,
        gr.Checkbox(label="Output labels in Korean"),
    ],
    # outputs=[owl_output, dino_output],
    outputs=[owl_output],
    title="OWLv2 Demo",
    description="Zero-shot object detection with [OWLv2](https://huggingface.co/google/owlv2-base-patch16). Upload an image and enter the comma-separated objects you want to find, or try one of the examples. Adjust the threshold to filter out low-confidence predictions.",
    # examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
    # examples=[["./rs_sample1.jpg", english_candidate_labels_string, 0.16, 0.12], ["./rs_sample2.jpg", english_candidate_labels_string, 0.13, 0.10]]
    # The last value in each example row feeds the Korean-output checkbox
    examples=[["./rs_sample1.jpg", english_candidate_labels_string, 0.16, False], ["./rs_sample2.jpg", english_candidate_labels_string, 0.13, False]]
)
demo.launch(debug=True)