mwmathis commited on
Commit
c0f6831
Β·
1 Parent(s): 6db6a87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -146
app.py CHANGED
@@ -20,154 +20,13 @@ from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w
20
  from detection_utils import predict_md, crop_animal_detections, predict_dlc
21
  from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
22
 
23
- # import pdb
24
- #########################################
25
- # Input params - Global vars
26
 
27
- MD_models_dict = {'md_v5a': "MD_models/md_v5a.0.0.pt", #
28
- 'md_v5b': "MD_models/md_v5b.0.0.pt"}
29
-
30
- # DLC models target dirs
31
- DLC_models_dict = {'supernanimal_quadruped': "DLC_models/supernanimal_quadruped/",
32
- 'supernanimal_topviewmouse': "DLC_models/supernanimal_topviewmouse/"
33
- }
34
-
35
- def predict_pipeline(img_input,
36
- mega_model_input,
37
- dlc_model_input_str,
38
- flag_dlc_only,
39
- flag_show_str_labels,
40
- bbox_likelihood_th,
41
- kpts_likelihood_th,
42
- font_style,
43
- font_size,
44
- keypt_color,
45
- marker_size,
46
- ):
47
-
48
- if not flag_dlc_only:
49
- ############################################################
50
- # ### Run Megadetector
51
- md_results = predict_md(img_input,
52
- MD_models_dict[mega_model_input], #mega_model_input,
53
- size=640) #Image.fromarray(results.imgs[0])
54
-
55
- ################################################################
56
- # Obtain animal crops for bboxes with confidence above th
57
- list_crops = crop_animal_detections(img_input,
58
- md_results,
59
- bbox_likelihood_th)
60
-
61
- ############################################################
62
- ## Get DLC model and label map
63
-
64
- # If model is found: do not download (previous execution is likely within same day)
65
- # TODO: can we ask the user whether to reload dlc model if a directory is found?
66
- if os.path.isdir(DLC_models_dict[dlc_model_input_str]) and \
67
- len(os.listdir(DLC_models_dict[dlc_model_input_str])) > 0:
68
- path_to_DLCmodel = DLC_models_dict[dlc_model_input_str]
69
- else:
70
- path_to_DLCmodel = DownloadModel(dlc_model_input_str,
71
- DLC_models_dict[dlc_model_input_str])
72
-
73
- # extract map label ids to strings
74
- pose_cfg_path = os.path.join(DLC_models_dict[dlc_model_input_str],
75
- 'pose_cfg.yaml')
76
- with open(pose_cfg_path, "r") as stream:
77
- pose_cfg_dict = yaml.safe_load(stream)
78
- map_label_id_to_str = dict([(k,v) for k,v in zip([el[0] for el in pose_cfg_dict['all_joints']], # pose_cfg_dict['all_joints'] is a list of one-element lists,
79
- pose_cfg_dict['all_joints_names'])])
80
-
81
- ##############################################################
82
- # Run DLC and visualise results
83
- dlc_proc = Processor()
84
-
85
- # if required: ignore MD crops and run DLC on full image [mostly for testing]
86
- if flag_dlc_only:
87
- # compute kpts on input img
88
- list_kpts_per_crop = predict_dlc([np.asarray(img_input)],
89
- kpts_likelihood_th,
90
- path_to_DLCmodel,
91
- dlc_proc)
92
- # draw kpts on input img #fix!
93
- draw_keypoints_on_image(img_input,
94
- list_kpts_per_crop[0], # a numpy array with shape [num_keypoints, 2].
95
- map_label_id_to_str,
96
- flag_show_str_labels,
97
- use_normalized_coordinates=False,
98
- font_style=font_style,
99
- font_size=font_size,
100
- keypt_color=keypt_color,
101
- marker_size=marker_size)
102
-
103
- donw_file = save_results_only_dlc(list_kpts_per_crop[0], map_label_id_to_str,dlc_model_input_str)
104
-
105
- return img_input, donw_file
106
-
107
- else:
108
- # Compute kpts for each crop
109
- list_kpts_per_crop = predict_dlc(list_crops,
110
- kpts_likelihood_th,
111
- path_to_DLCmodel,
112
- dlc_proc)
113
-
114
- # resize input image to match megadetector output
115
- img_background = img_input.resize((md_results.ims[0].shape[1],
116
- md_results.ims[0].shape[0]))
117
-
118
- # draw keypoints on each crop and paste to background img
119
- for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops,
120
- list_kpts_per_crop)):
121
-
122
- img_crop = Image.fromarray(np_crop)
123
-
124
- # Draw keypts on crop
125
- draw_keypoints_on_image(img_crop,
126
- kpts_crop, # a numpy array with shape [num_keypoints, 2].
127
- map_label_id_to_str,
128
- flag_show_str_labels,
129
- use_normalized_coordinates=False, # if True, then I should use md_results.xyxyn for list_kpts_crop
130
- font_style=font_style,
131
- font_size=font_size,
132
- keypt_color=keypt_color,
133
- marker_size=marker_size)
134
-
135
- # Paste crop in original image
136
- img_background.paste(img_crop,
137
- box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
138
-
139
- # Plot bbox
140
- bb_per_animal = md_results.xyxy[0].tolist()[ic]
141
- pred = md_results.xyxy[0].tolist()[ic][4]
142
- if bbox_likelihood_th < pred:
143
- draw_bbox_w_text(img_background,
144
- bb_per_animal,
145
- font_style=font_style,
146
- font_size=font_size) # TODO: add selectable color for bbox?
147
-
148
-
149
- # Save detection results as json
150
- download_file = save_results_as_json(md_results,list_kpts_per_crop,map_label_id_to_str, bbox_likelihood_th,dlc_model_input_str,mega_model_input)
151
-
152
- return img_background, download_file
153
 
154
- #########################################################
155
- # Define user interface and launch
156
- inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict.keys()),
157
- list(DLC_models_dict.keys()))
158
- outputs = gradio_outputs_for_MD_DLC()
159
- [gr_title,
160
- gr_description,
161
- examples] = gradio_description_and_examples()
162
 
163
- # launch
164
- demo = gr.Interface(predict_pipeline,
165
- inputs=inputs,
166
- outputs=outputs,
167
- title=gr_title,
168
- description=gr_description,
169
- examples=examples,
170
- theme="huggingface")
171
 
172
- demo.launch(enable_queue=True, share=True)
173
 
 
20
  from detection_utils import predict_md, crop_animal_detections, predict_dlc
21
  from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
22
 
23
+ # directly from huggingface
 
 
24
 
25
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ model = gr.Interface.load("mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped",
28
+ )
 
 
 
 
 
 
29
 
30
+ model.launch()
 
 
 
 
 
 
 
31
 
 
32