update to support gradio 4+
- app.py +24 -7
- requirements.txt +1 -1
- utils/load_model.py +9 -2
- utils/predict.py +10 -3
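
The changes below all track the Gradio 4 API break: the per-component `update()` class methods (gr.Textbox.update, gr.HTML.update, ...) were removed, and an event handler now returns a fresh component instance whose constructor arguments Gradio applies as property updates. A minimal sketch of the before/after pattern (component and handler names here are illustrative, not taken from the app):

    import gradio as gr

    def on_edit():
        # Gradio 3.x style (removed in 4.x):
        #     return gr.Textbox.update(value="edited", visible=True)
        # Gradio 4.x style: return a component instance instead; its set
        # properties are applied to the existing component in the UI.
        return gr.Textbox(value="edited", visible=True, interactive=True)

    with gr.Blocks() as demo:
        box = gr.Textbox(visible=False)
        gr.Button("Edit").click(fn=on_edit, inputs=None, outputs=box)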
app.py
CHANGED
@@ -1,4 +1,4 @@
-
+
 import io
 import os
 debug = False
@@ -29,7 +29,7 @@ PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
 IMAGES_FOLDER = "data/images"
 # XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
 IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
-CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
+CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt').to(DEVICE)
 CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
 CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
 # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
@@ -269,12 +269,20 @@ def update_selected_image(event: gr.SelectData):
     descs = {k: descs[k] for k in ORDERED_PARTS}
     custom_text = [custom_class_name] + list(descs.values())
     descriptions = ";\n".join(custom_text)
-    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    # textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    textbox = gr.Textbox(value=descriptions,
+                         lines=12,
+                         visible=True,
+                         label="XCLIP descriptions",
+                         interactive=True,
+                         info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
+                         show_label=False)
     # modified_exp = gr.HTML().update(value="", visible=True)
     return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
 
 def on_edit_button_click_xclip():
-    empty_exp = gr.HTML.update(visible=False)
+    # empty_exp = gr.HTML.update(visible=False)
+    empty_exp = gr.HTML(visible=False)
 
     # Populate the textbox with current descriptions
     descs = XCLIP_DESC[current_predicted_class.state]
@@ -282,7 +290,14 @@ def on_edit_button_click_xclip():
     descs = {k: descs[k] for k in ORDERED_PARTS}
     custom_text = ["class name: custom"] + list(descs.values())
     descriptions = ";\n".join(custom_text)
-    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    # textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    textbox = gr.Textbox(value=descriptions,
+                         lines=12,
+                         visible=True,
+                         label="XCLIP descriptions",
+                         interactive=True,
+                         info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
+                         show_label=False)
 
     return textbox, empty_exp
 
@@ -350,10 +365,12 @@ def on_predict_button_click_xclip(textbox_input: str):
     custom_pred_markdown = f"""
     ### <span style='color:{custom_color}'> {custom_label} {custom_pred_score:.4f}</span>
     """
-    textbox = gr.Textbox.update(visible=False)
+    # textbox = gr.Textbox.update(visible=False)
+    textbox = gr.Textbox(visible=False)
     # return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
 
-    modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
+    # modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
+    modified_exp = gr.HTML(value=modified_explanation, visible=True)
     return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
 
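One change in app.py is not about Gradio at all: CUB_DESC_EMBEDS is now moved to DEVICE right after loading, so the precomputed text embeddings live on the same device as the tensors they are later combined with. A hedged side note: torch.load restores tensors onto the device they were saved from, which can fail on a CPU-only machine unless map_location is passed. The sketch below loads defensively and moves once (the map_location argument and the DEVICE definition are assumptions; the commit adds neither):

    import torch

    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # assumed definition

    # Load onto CPU regardless of where the tensor was saved, then move once.
    embeds = torch.load("data/text_embeddings/cub_200_desc.pt", map_location="cpu")
    embeds = embeds.to(DEVICE)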
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
 torch
 torchvision
-gradio
+gradio
 numpy
 Pillow
 transformers
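The gradio requirement itself stays unpinned, so the 4.x API is assumed implicitly. If reproducibility matters, a lower bound would make that assumption explicit (a suggestion, not part of this commit):

    gradio>=4.0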
utils/load_model.py
CHANGED
@@ -1,12 +1,19 @@
 
 
-
+try:
+    import spaces
+    gpu_decorator = spaces.GPU
+except ImportError:
+    # Define a no-operation decorator as fallback
+    def gpu_decorator(func):
+        return func
+
 import torch
 from transformers import OwlViTProcessor, OwlViTForObjectDetection
 
 from .model import OwlViTForClassification
 
-@
+@gpu_decorator
 def load_xclip(device: str = "cuda:0",
                n_classes: int = 183,
                use_teacher_logits: bool = False,
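The try/except block makes the module import cleanly both on a ZeroGPU Space, where spaces.GPU marks a function for on-demand GPU allocation, and locally, where the spaces package may be absent. One caveat: spaces.GPU can also be used in parameterized form (e.g. @spaces.GPU(duration=120)), and the plain no-op fallback would break under that form because it would be called with keyword arguments instead of a function. A sketch of a fallback that tolerates both usages (an extension, not what the commit does):

    try:
        import spaces
        gpu_decorator = spaces.GPU
    except ImportError:
        def gpu_decorator(func=None, **kwargs):
            # No-op fallback covering both @gpu_decorator and
            # @gpu_decorator(duration=...) outside HF Spaces.
            if func is None:
                return lambda f: f
            return func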
utils/predict.py
CHANGED
@@ -1,4 +1,11 @@
-
+try:
+    import spaces
+    gpu_decorator = spaces.GPU
+except ImportError:
+    # Define a no-operation decorator as fallback
+    def gpu_decorator(func):
+        return func
+
 import PIL
 import torch
 
@@ -30,7 +37,7 @@ def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: l
     # text_embeds = torch.cat(text_embeds, dim=0)
     # text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
     # return text_embeds.to(device)
-@
+@gpu_decorator
 def xclip_pred(new_desc: dict,
                new_part_mask: dict,
                new_class: str,
@@ -76,7 +83,7 @@ def xclip_pred(new_desc: dict,
         n_classes = 201
         query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
         new_class_embed = model.owlvit.get_text_features(**query_tokens)
-        query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
+        query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0).to(device)
         modified_class_idx = 200
     else:
         n_classes = 200
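The trailing .to(device) ensures the concatenated query embeddings end up on the model's device before downstream use. Note that torch.cat itself still requires all inputs on one device; if cub_embeds (now preloaded onto the GPU in app.py) and new_class_embed ever diverged, the cat would raise before .to(device) runs. A minimal repro of that failure mode (shapes are illustrative):

    import torch

    a = torch.randn(200, 512)  # e.g. precomputed class embeddings on CPU
    if torch.cuda.is_available():
        b = torch.randn(1, 512, device="cuda:0")  # freshly encoded on GPU
        try:
            torch.cat([a, b], dim=0)  # fails: tensors on different devices
        except RuntimeError as err:
            print(err)  # "Expected all tensors to be on the same device..."
        merged = torch.cat([a.to("cuda:0"), b], dim=0)  # move first, then concat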