import spaces
# from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import torch
import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the OWLv2 detector and its processor on the available device
owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")

# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")

english_candidate_labels = ["hat", "sunglass", "hair band", "glove", "arm sleeve", "watch", "singlet", "t-shirts", "energy gel", "half pants", "socks", "shoes", "ear phone"]
korean_candidate_labels = ["모자", "썬글라스", "헤어밴드", "장갑", "팔토시", "시계", "싱글렛", "티셔츠", "에너지젤", "쇼츠바지", "양말", "신발", "이어폰"]
english_candidate_labels_string = ",".join(english_candidate_labels)

# Dictionary mapping each English label to its Korean equivalent
label_mapping = dict(zip(english_candidate_labels, korean_candidate_labels))

@spaces.GPU
def infer(img, text_queries, score_threshold, model):
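  """Run zero-shot detection with the selected model ("owl" or "dino") and
  return a list of (box, label) pairs above the score threshold."""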
  
  if model == "dino":
    queries=""
    for query in text_queries:
      queries += f"{query}. "

    # img is an H x W x C numpy array, so shape[:2] is (height, width)
    height, width = img.shape[:2]

    target_sizes = [(height, width)]
    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = dino_model(**inputs)
      outputs.logits = outputs.logits.cpu()
      outputs.pred_boxes = outputs.pred_boxes.cpu()
      results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
                                                                    box_threshold=score_threshold,
                                                                    target_sizes=target_sizes)
  elif model == "owl":
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = owl_model(**inputs)
      outputs.logits = outputs.logits.cpu()
      outputs.pred_boxes = outputs.pred_boxes.cpu()
      results = owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=score_threshold)
      
  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
  result_labels = []
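  # Keep only detections above the threshold; OWL returns label indices into
  # text_queries, while Grounding DINO returns the matched text span directly.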

  for box, score, label in zip(boxes, scores, labels):
      box = [int(i) for i in box.tolist()]
      if score < score_threshold:
          continue
      if model == "owl":
        label = text_queries[label.cpu().item()]
        result_labels.append((box, label))
      elif model == "dino":
        if label != "":
            result_labels.append((box, label))
  return result_labels

# def query_image(img, text_queries, owl_threshold, dino_threshold):
def query_image(img, text_queries, owl_threshold, flag_output_korean):
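    """Parse the comma-separated label string, run OWLv2 detection, and
    optionally map the detected labels to Korean for display."""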
    text_queries = [query.strip() for query in text_queries.split(",")]
    owl_output = infer(img, text_queries, owl_threshold, "owl")
    # dino_output = infer(img, text_queries, dino_threshold, "dino")

    # Optionally replace English labels with their Korean equivalents for display
    owl_output_final = []
    if flag_output_korean:
        for box, label in owl_output:
            kor_label = label_mapping.get(label, label)  # fall back to the English label if unmapped
            owl_output_final.append((box, kor_label))
        
    else:
        owl_output_final = owl_output
        
    # return (img, owl_output), (img, dino_output)
    return (img, owl_output_final)


owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
# dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
# dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    # inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
    inputs=[
        gr.Image(label="Input Image"), 
        gr.Textbox(label="Candidate Labels", value=english_candidate_labels_string), 
        owl_threshold,
        gr.Checkbox(label="Output labels Korean")
    ],
    # outputs=[owl_output, dino_output],
    outputs=[owl_output],
    title="OWLv2 Demo",
    description="Compare two state-of-the-art zero-shot object detection models [OWLv2](https://huggingface.co/google/owlv2-base-patch16) . Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in each model.",
    # examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
    # examples=[["./rs_sample1.jpg", english_candidate_labels_string, 0.16, 0.12], ["./rs_sample2.jpg", english_candidate_labels_string, 0.13, 0.10]]
    examples=[["./rs_sample1.jpg", english_candidate_labels_string, 0.16, 0.12], ["./rs_sample2.jpg", english_candidate_labels_string, 0.13, False]]
)
demo.launch(debug=True)