Update app.py
Browse files
app.py
CHANGED
@@ -20,154 +20,13 @@ from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w
|
|
20 |
from detection_utils import predict_md, crop_animal_detections, predict_dlc
|
21 |
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
|
22 |
|
23 |
-
#
|
24 |
-
#########################################
|
25 |
-
# Input params - Global vars
|
26 |
|
27 |
-
|
28 |
-
'md_v5b': "MD_models/md_v5b.0.0.pt"}
|
29 |
-
|
30 |
-
# DLC models target dirs
|
31 |
-
DLC_models_dict = {'supernanimal_quadruped': "DLC_models/supernanimal_quadruped/",
|
32 |
-
'supernanimal_topviewmouse': "DLC_models/supernanimal_topviewmouse/"
|
33 |
-
}
|
34 |
-
|
35 |
-
def predict_pipeline(img_input,
|
36 |
-
mega_model_input,
|
37 |
-
dlc_model_input_str,
|
38 |
-
flag_dlc_only,
|
39 |
-
flag_show_str_labels,
|
40 |
-
bbox_likelihood_th,
|
41 |
-
kpts_likelihood_th,
|
42 |
-
font_style,
|
43 |
-
font_size,
|
44 |
-
keypt_color,
|
45 |
-
marker_size,
|
46 |
-
):
|
47 |
-
|
48 |
-
if not flag_dlc_only:
|
49 |
-
############################################################
|
50 |
-
# ### Run Megadetector
|
51 |
-
md_results = predict_md(img_input,
|
52 |
-
MD_models_dict[mega_model_input], #mega_model_input,
|
53 |
-
size=640) #Image.fromarray(results.imgs[0])
|
54 |
-
|
55 |
-
################################################################
|
56 |
-
# Obtain animal crops for bboxes with confidence above th
|
57 |
-
list_crops = crop_animal_detections(img_input,
|
58 |
-
md_results,
|
59 |
-
bbox_likelihood_th)
|
60 |
-
|
61 |
-
############################################################
|
62 |
-
## Get DLC model and label map
|
63 |
-
|
64 |
-
# If model is found: do not download (previous execution is likely within same day)
|
65 |
-
# TODO: can we ask the user whether to reload dlc model if a directory is found?
|
66 |
-
if os.path.isdir(DLC_models_dict[dlc_model_input_str]) and \
|
67 |
-
len(os.listdir(DLC_models_dict[dlc_model_input_str])) > 0:
|
68 |
-
path_to_DLCmodel = DLC_models_dict[dlc_model_input_str]
|
69 |
-
else:
|
70 |
-
path_to_DLCmodel = DownloadModel(dlc_model_input_str,
|
71 |
-
DLC_models_dict[dlc_model_input_str])
|
72 |
-
|
73 |
-
# extract map label ids to strings
|
74 |
-
pose_cfg_path = os.path.join(DLC_models_dict[dlc_model_input_str],
|
75 |
-
'pose_cfg.yaml')
|
76 |
-
with open(pose_cfg_path, "r") as stream:
|
77 |
-
pose_cfg_dict = yaml.safe_load(stream)
|
78 |
-
map_label_id_to_str = dict([(k,v) for k,v in zip([el[0] for el in pose_cfg_dict['all_joints']], # pose_cfg_dict['all_joints'] is a list of one-element lists,
|
79 |
-
pose_cfg_dict['all_joints_names'])])
|
80 |
-
|
81 |
-
##############################################################
|
82 |
-
# Run DLC and visualise results
|
83 |
-
dlc_proc = Processor()
|
84 |
-
|
85 |
-
# if required: ignore MD crops and run DLC on full image [mostly for testing]
|
86 |
-
if flag_dlc_only:
|
87 |
-
# compute kpts on input img
|
88 |
-
list_kpts_per_crop = predict_dlc([np.asarray(img_input)],
|
89 |
-
kpts_likelihood_th,
|
90 |
-
path_to_DLCmodel,
|
91 |
-
dlc_proc)
|
92 |
-
# draw kpts on input img #fix!
|
93 |
-
draw_keypoints_on_image(img_input,
|
94 |
-
list_kpts_per_crop[0], # a numpy array with shape [num_keypoints, 2].
|
95 |
-
map_label_id_to_str,
|
96 |
-
flag_show_str_labels,
|
97 |
-
use_normalized_coordinates=False,
|
98 |
-
font_style=font_style,
|
99 |
-
font_size=font_size,
|
100 |
-
keypt_color=keypt_color,
|
101 |
-
marker_size=marker_size)
|
102 |
-
|
103 |
-
donw_file = save_results_only_dlc(list_kpts_per_crop[0], map_label_id_to_str,dlc_model_input_str)
|
104 |
-
|
105 |
-
return img_input, donw_file
|
106 |
-
|
107 |
-
else:
|
108 |
-
# Compute kpts for each crop
|
109 |
-
list_kpts_per_crop = predict_dlc(list_crops,
|
110 |
-
kpts_likelihood_th,
|
111 |
-
path_to_DLCmodel,
|
112 |
-
dlc_proc)
|
113 |
-
|
114 |
-
# resize input image to match megadetector output
|
115 |
-
img_background = img_input.resize((md_results.ims[0].shape[1],
|
116 |
-
md_results.ims[0].shape[0]))
|
117 |
-
|
118 |
-
# draw keypoints on each crop and paste to background img
|
119 |
-
for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops,
|
120 |
-
list_kpts_per_crop)):
|
121 |
-
|
122 |
-
img_crop = Image.fromarray(np_crop)
|
123 |
-
|
124 |
-
# Draw keypts on crop
|
125 |
-
draw_keypoints_on_image(img_crop,
|
126 |
-
kpts_crop, # a numpy array with shape [num_keypoints, 2].
|
127 |
-
map_label_id_to_str,
|
128 |
-
flag_show_str_labels,
|
129 |
-
use_normalized_coordinates=False, # if True, then I should use md_results.xyxyn for list_kpts_crop
|
130 |
-
font_style=font_style,
|
131 |
-
font_size=font_size,
|
132 |
-
keypt_color=keypt_color,
|
133 |
-
marker_size=marker_size)
|
134 |
-
|
135 |
-
# Paste crop in original image
|
136 |
-
img_background.paste(img_crop,
|
137 |
-
box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
|
138 |
-
|
139 |
-
# Plot bbox
|
140 |
-
bb_per_animal = md_results.xyxy[0].tolist()[ic]
|
141 |
-
pred = md_results.xyxy[0].tolist()[ic][4]
|
142 |
-
if bbox_likelihood_th < pred:
|
143 |
-
draw_bbox_w_text(img_background,
|
144 |
-
bb_per_animal,
|
145 |
-
font_style=font_style,
|
146 |
-
font_size=font_size) # TODO: add selectable color for bbox?
|
147 |
-
|
148 |
-
|
149 |
-
# Save detection results as json
|
150 |
-
download_file = save_results_as_json(md_results,list_kpts_per_crop,map_label_id_to_str, bbox_likelihood_th,dlc_model_input_str,mega_model_input)
|
151 |
-
|
152 |
-
return img_background, download_file
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict.keys()),
|
157 |
-
list(DLC_models_dict.keys()))
|
158 |
-
outputs = gradio_outputs_for_MD_DLC()
|
159 |
-
[gr_title,
|
160 |
-
gr_description,
|
161 |
-
examples] = gradio_description_and_examples()
|
162 |
|
163 |
-
|
164 |
-
demo = gr.Interface(predict_pipeline,
|
165 |
-
inputs=inputs,
|
166 |
-
outputs=outputs,
|
167 |
-
title=gr_title,
|
168 |
-
description=gr_description,
|
169 |
-
examples=examples,
|
170 |
-
theme="huggingface")
|
171 |
|
172 |
-
demo.launch(enable_queue=True, share=True)
|
173 |
|
|
|
20 |
from detection_utils import predict_md, crop_animal_detections, predict_dlc
|
21 |
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
|
22 |
|
23 |
+
# directly from huggingface
|
|
|
|
|
24 |
|
25 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
model = gr.Interface.load("mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped",
|
28 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
model.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
|
|
32 |
|