Gaejoon commited on
Commit
859ac57
ยท
verified ยท
1 Parent(s): b4d5e5f

Update app.py

Browse files

output label - korean by checkbox

Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -12,6 +12,13 @@ owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemb
12
  # dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
13
  # dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
14
 
 
 
 
 
 
 
 
15
  @spaces.GPU
16
  def infer(img, text_queries, score_threshold, model):
17
 
@@ -65,13 +72,19 @@ def query_image(img, text_queries, owl_threshold, flag_output_korean):
65
  owl_output = infer(img, text_queries, owl_threshold, "owl")
66
  # dino_output = infer(img, text_queries, dino_threshold, "dino")
67
 
68
-
 
 
 
 
 
 
 
 
 
69
  # return (img, owl_output), (img, dino_output)
70
- return (img, owl_output)
71
 
72
- english_candidate_labels = ["hat", "sunglass", "hair band", "glove", "arm sleeve", "watch", "singlet", "t-shirts", "energy gel", "half pants", "socks", "shoes", "ear phone"]
73
- korean_candidate_labels = ["๋ชจ์ž", "์ฌ๊ธ€๋ผ์Šค", "ํ—ค์–ด๋ฐด๋“œ", "์žฅ๊ฐ‘", "ํŒ”ํ† ์‹œ", "์‹œ๊ณ„", "์‹ฑ๊ธ€๋ ›", "ํ‹ฐ์…”์ธ ", "์—๋„ˆ์ง€์ ค", "์‡ผ์ธ ๋ฐ”์ง€", "์–‘๋ง", "์‹ ๋ฐœ", "์ด์–ดํฐ"]
74
- english_candidate_labels_string = ",".join(english_candidate_labels)
75
 
76
  owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
77
  # dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
 
12
  # dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
13
  # dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
14
 
15
+ english_candidate_labels = ["hat", "sunglass", "hair band", "glove", "arm sleeve", "watch", "singlet", "t-shirts", "energy gel", "half pants", "socks", "shoes", "ear phone"]
16
+ korean_candidate_labels = ["๋ชจ์ž", "์ฌ๊ธ€๋ผ์Šค", "ํ—ค์–ด๋ฐด๋“œ", "์žฅ๊ฐ‘", "ํŒ”ํ† ์‹œ", "์‹œ๊ณ„", "์‹ฑ๊ธ€๋ ›", "ํ‹ฐ์…”์ธ ", "์—๋„ˆ์ง€์ ค", "์‡ผ์ธ ๋ฐ”์ง€", "์–‘๋ง", "์‹ ๋ฐœ", "์ด์–ดํฐ"]
17
+ english_candidate_labels_string = ",".join(english_candidate_labels)
18
+
19
+ # ์˜๋ฌธ ๋ ˆ์ด๋ธ”์„ ํ•œ๊ธ€ ๋ ˆ์ด๋ธ”๋กœ ๋งค์นญํ•˜๋Š” ๋”•์…”๋„ˆ๋ฆฌ ์ƒ์„ฑ
20
+ label_mapping = dict(zip(english_candidate_labels, korean_candidate_labels))
21
+
22
  @spaces.GPU
23
  def infer(img, text_queries, score_threshold, model):
24
 
 
72
  owl_output = infer(img, text_queries, owl_threshold, "owl")
73
  # dino_output = infer(img, text_queries, dino_threshold, "dino")
74
 
75
+ # add - check flag output korean
76
+ owl_output_final = []
77
+ if flag_output_korean:
78
+ for box, label in owl_output:
79
+ kor_label = label_mapping[label]
80
+ owl_output_final.append(box, kor_label)
81
+
82
+ else:
83
+ owl_output_final = owl_output
84
+
85
  # return (img, owl_output), (img, dino_output)
86
+ return (img, owl_output_final)
87
 
 
 
 
88
 
89
  owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
90
  # dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")