Spaces:
Running
Running
Mathis Petrovich
commited on
Commit
·
5e4fa5e
1
Parent(s):
bdb661d
device
Browse files
app.py
CHANGED
@@ -56,7 +56,7 @@ EXAMPLES = [
|
|
56 |
"A person is taking the stairs",
|
57 |
"Someone is doing jumping jacks",
|
58 |
"The person walked forward and is picking up his toolbox",
|
59 |
-
"The person angrily punching the air"
|
60 |
]
|
61 |
|
62 |
# Show closest text in the training
|
@@ -94,6 +94,7 @@ CSS = """
|
|
94 |
|
95 |
DEFAULT_TEXT = "A person is "
|
96 |
|
|
|
97 |
def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
98 |
# Don't show the mirrored version of HumanMl3D
|
99 |
if "M" in keyid:
|
@@ -128,13 +129,15 @@ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
|
128 |
"text": text,
|
129 |
"keyid": keyid,
|
130 |
"babel_id": babel_id,
|
131 |
-
"path": path
|
132 |
}
|
133 |
|
134 |
return data
|
135 |
|
136 |
|
137 |
-
def retrieve(
|
|
|
|
|
138 |
unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
|
139 |
keyids = np.concatenate([all_keyids[s] for s in splits])
|
140 |
|
@@ -169,7 +172,7 @@ def get_video_html(data, video_id, width=700, height=700):
|
|
169 |
path = data["path"]
|
170 |
|
171 |
trim = f"#t={start},{end}"
|
172 |
-
title = f
|
173 |
|
174 |
Corresponding text: {text}
|
175 |
|
@@ -177,18 +180,18 @@ HumanML3D keyid: {keyid}
|
|
177 |
|
178 |
BABEL keyid: {babel_id}
|
179 |
|
180 |
-
AMASS path: {path}
|
181 |
|
182 |
# class="wrap default svelte-gjihhp hide"
|
183 |
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
184 |
# width="{width}" height="{height}"
|
185 |
-
video_html = f
|
186 |
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
187 |
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
|
188 |
<source src="{url}{trim}" type="video/mp4">
|
189 |
Your browser does not support the video tag.
|
190 |
</video>
|
191 |
-
|
192 |
return video_html
|
193 |
|
194 |
|
@@ -208,16 +211,18 @@ def retrieve_component(retrieve_function, text, splits_choice, nvids, n_componen
|
|
208 |
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
209 |
# get n_component exactly if asked less
|
210 |
# pad with dummy blocks
|
211 |
-
htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
|
212 |
return htmls
|
213 |
|
214 |
|
215 |
if not os.path.exists("data"):
|
216 |
-
gdown.download_folder(
|
217 |
-
|
|
|
|
|
218 |
|
219 |
|
220 |
-
device = torch.device(
|
221 |
|
222 |
# LOADING
|
223 |
model = load_model(device)
|
@@ -229,7 +234,9 @@ h3d_index = load_json("amass-annotations/humanml3d.json")
|
|
229 |
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
230 |
|
231 |
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
|
232 |
-
retrieve_function = partial(
|
|
|
|
|
233 |
|
234 |
# DEMO
|
235 |
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
@@ -242,33 +249,48 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
|
|
242 |
with gr.Row():
|
243 |
with gr.Column(scale=3):
|
244 |
with gr.Column(scale=2):
|
245 |
-
text = gr.Textbox(
|
246 |
-
|
|
|
|
|
|
|
|
|
247 |
with gr.Column(scale=1):
|
248 |
-
btn = gr.Button("Retrieve", variant=
|
249 |
-
clear = gr.Button("Clear", variant=
|
250 |
|
251 |
with gr.Row():
|
252 |
with gr.Column(scale=1):
|
253 |
-
splits_choice = gr.Radio(
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
256 |
|
257 |
with gr.Column(scale=1):
|
258 |
# nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
|
259 |
-
nvideo_slider = gr.Radio(
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
262 |
|
263 |
with gr.Column(scale=2):
|
|
|
264 |
def retrieve_example(text, splits_choice, nvideo_slider):
|
265 |
return retrieve_and_show(text, splits_choice, nvideo_slider)
|
266 |
|
267 |
-
examples = gr.Examples(
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
|
|
|
|
|
|
|
|
272 |
|
273 |
i = -1
|
274 |
# should indent
|
@@ -294,16 +316,28 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
|
|
294 |
show_progress=False,
|
295 |
postprocess=False,
|
296 |
queue=False,
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
text.submit(
|
305 |
-
|
306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
|
308 |
def clear_videos():
|
309 |
return [None for x in range(24)] + [DEFAULT_TEXT]
|
|
|
56 |
"A person is taking the stairs",
|
57 |
"Someone is doing jumping jacks",
|
58 |
"The person walked forward and is picking up his toolbox",
|
59 |
+
"The person angrily punching the air",
|
60 |
]
|
61 |
|
62 |
# Show closest text in the training
|
|
|
94 |
|
95 |
DEFAULT_TEXT = "A person is "
|
96 |
|
97 |
+
|
98 |
def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
99 |
# Don't show the mirrored version of HumanMl3D
|
100 |
if "M" in keyid:
|
|
|
129 |
"text": text,
|
130 |
"keyid": keyid,
|
131 |
"babel_id": babel_id,
|
132 |
+
"path": path,
|
133 |
}
|
134 |
|
135 |
return data
|
136 |
|
137 |
|
138 |
+
def retrieve(
|
139 |
+
model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8
|
140 |
+
):
|
141 |
unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
|
142 |
keyids = np.concatenate([all_keyids[s] for s in splits])
|
143 |
|
|
|
172 |
path = data["path"]
|
173 |
|
174 |
trim = f"#t={start},{end}"
|
175 |
+
title = f"""Score = {score}
|
176 |
|
177 |
Corresponding text: {text}
|
178 |
|
|
|
180 |
|
181 |
BABEL keyid: {babel_id}
|
182 |
|
183 |
+
AMASS path: {path}"""
|
184 |
|
185 |
# class="wrap default svelte-gjihhp hide"
|
186 |
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
187 |
# width="{width}" height="{height}"
|
188 |
+
video_html = f"""
|
189 |
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
190 |
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
|
191 |
<source src="{url}{trim}" type="video/mp4">
|
192 |
Your browser does not support the video tag.
|
193 |
</video>
|
194 |
+
"""
|
195 |
return video_html
|
196 |
|
197 |
|
|
|
211 |
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
212 |
# get n_component exactly if asked less
|
213 |
# pad with dummy blocks
|
214 |
+
htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
|
215 |
return htmls
|
216 |
|
217 |
|
218 |
if not os.path.exists("data"):
|
219 |
+
gdown.download_folder(
|
220 |
+
"https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
|
221 |
+
use_cookies=False,
|
222 |
+
)
|
223 |
|
224 |
|
225 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
226 |
|
227 |
# LOADING
|
228 |
model = load_model(device)
|
|
|
234 |
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
235 |
|
236 |
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
|
237 |
+
retrieve_function = partial(
|
238 |
+
retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids
|
239 |
+
)
|
240 |
|
241 |
# DEMO
|
242 |
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
|
|
249 |
with gr.Row():
|
250 |
with gr.Column(scale=3):
|
251 |
with gr.Column(scale=2):
|
252 |
+
text = gr.Textbox(
|
253 |
+
placeholder="Type the motion you want to search with a sentence",
|
254 |
+
show_label=True,
|
255 |
+
label="Text prompt",
|
256 |
+
value=DEFAULT_TEXT,
|
257 |
+
)
|
258 |
with gr.Column(scale=1):
|
259 |
+
btn = gr.Button("Retrieve", variant="primary")
|
260 |
+
clear = gr.Button("Clear", variant="secondary")
|
261 |
|
262 |
with gr.Row():
|
263 |
with gr.Column(scale=1):
|
264 |
+
splits_choice = gr.Radio(
|
265 |
+
["All motions", "Unseen motions"],
|
266 |
+
label="Gallery of motion",
|
267 |
+
value="All motions",
|
268 |
+
info="The motion gallery is coming from HumanML3D",
|
269 |
+
)
|
270 |
|
271 |
with gr.Column(scale=1):
|
272 |
# nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
|
273 |
+
nvideo_slider = gr.Radio(
|
274 |
+
[4, 8, 12, 16, 24],
|
275 |
+
label="Videos",
|
276 |
+
value=8,
|
277 |
+
info="Number of videos to display",
|
278 |
+
)
|
279 |
|
280 |
with gr.Column(scale=2):
|
281 |
+
|
282 |
def retrieve_example(text, splits_choice, nvideo_slider):
|
283 |
return retrieve_and_show(text, splits_choice, nvideo_slider)
|
284 |
|
285 |
+
examples = gr.Examples(
|
286 |
+
examples=[[x, None, None] for x in EXAMPLES],
|
287 |
+
inputs=[text, splits_choice, nvideo_slider],
|
288 |
+
examples_per_page=20,
|
289 |
+
run_on_click=False,
|
290 |
+
cache_examples=False,
|
291 |
+
fn=retrieve_example,
|
292 |
+
outputs=[],
|
293 |
+
)
|
294 |
|
295 |
i = -1
|
296 |
# should indent
|
|
|
316 |
show_progress=False,
|
317 |
postprocess=False,
|
318 |
queue=False,
|
319 |
+
).then(fn=retrieve_example, inputs=examples.inputs, outputs=videos)
|
320 |
+
|
321 |
+
btn.click(
|
322 |
+
fn=retrieve_and_show,
|
323 |
+
inputs=[text, splits_choice, nvideo_slider],
|
324 |
+
outputs=videos,
|
325 |
+
)
|
326 |
+
text.submit(
|
327 |
+
fn=retrieve_and_show,
|
328 |
+
inputs=[text, splits_choice, nvideo_slider],
|
329 |
+
outputs=videos,
|
330 |
+
)
|
331 |
+
splits_choice.change(
|
332 |
+
fn=retrieve_and_show,
|
333 |
+
inputs=[text, splits_choice, nvideo_slider],
|
334 |
+
outputs=videos,
|
335 |
+
)
|
336 |
+
nvideo_slider.change(
|
337 |
+
fn=retrieve_and_show,
|
338 |
+
inputs=[text, splits_choice, nvideo_slider],
|
339 |
+
outputs=videos,
|
340 |
+
)
|
341 |
|
342 |
def clear_videos():
|
343 |
return [None for x in range(24)] + [DEFAULT_TEXT]
|
load.py
CHANGED
@@ -20,10 +20,7 @@ def load_keyids(split):
|
|
20 |
|
21 |
|
22 |
def load_keyids_splits(splits):
|
23 |
-
return {
|
24 |
-
split: load_keyids(split)
|
25 |
-
for split in splits
|
26 |
-
}
|
27 |
|
28 |
|
29 |
def load_unit_motion_embs(split, device):
|
@@ -33,16 +30,17 @@ def load_unit_motion_embs(split, device):
|
|
33 |
|
34 |
|
35 |
def load_unit_motion_embs_splits(splits, device):
|
36 |
-
return {
|
37 |
-
split: load_unit_motion_embs(split, device)
|
38 |
-
for split in splits
|
39 |
-
}
|
40 |
|
41 |
|
42 |
def load_model(device):
|
43 |
text_params = {
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
46 |
}
|
47 |
"unit_motion_embs"
|
48 |
model = TMR_textencoder(**text_params)
|
@@ -50,4 +48,4 @@ def load_model(device):
|
|
50 |
# load values for the transformer only
|
51 |
model.load_state_dict(state_dict, strict=False)
|
52 |
model = model.eval()
|
53 |
-
return model
|
|
|
20 |
|
21 |
|
22 |
def load_keyids_splits(splits):
|
23 |
+
return {split: load_keyids(split) for split in splits}
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
def load_unit_motion_embs(split, device):
|
|
|
30 |
|
31 |
|
32 |
def load_unit_motion_embs_splits(splits, device):
|
33 |
+
return {split: load_unit_motion_embs(split, device) for split in splits}
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
def load_model(device):
|
37 |
text_params = {
|
38 |
+
"latent_dim": 256,
|
39 |
+
"ff_size": 1024,
|
40 |
+
"num_layers": 6,
|
41 |
+
"num_heads": 4,
|
42 |
+
"activation": "gelu",
|
43 |
+
"modelpath": "distilbert-base-uncased",
|
44 |
}
|
45 |
"unit_motion_embs"
|
46 |
model = TMR_textencoder(**text_params)
|
|
|
48 |
# load values for the transformer only
|
49 |
model.load_state_dict(state_dict, strict=False)
|
50 |
model = model.eval()
|
51 |
+
return model.to(device)
|