# LMM / app.py
# Author: mingyuan — commit 69bafea ("app")
import os
import sys
import gradio as gr
import time
# Create the directory that generate() writes rendered videos into.
os.makedirs("outputs", exist_ok=True)
# Prepend the working directory to the import path so the local ``mogen``
# package imported below can be resolved (assumes the app is started from
# the repository root — TODO confirm).
sys.path.insert(0, '.')
import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.parallel import MMDataParallel
from scipy.ndimage import gaussian_filter
from IPython.display import Image
from mogen.models.utils.imagebind_wrapper import (
extract_text_feature,
extract_audio_feature,
imagebind_huge
)
from mogen.models import build_architecture
from mogen.utils.plot_utils import (
plot_3d_motion,
add_audio,
get_audio_length
)
from mogen.datasets.paramUtil import (
t2m_body_hand_kinematic_chain,
t2m_kinematic_chain
)
from mogen.datasets.utils import recover_from_ric
from mogen.datasets.pipelines import RetargetSkeleton
import requests
from huggingface_hub import hf_hub_download
from huggingface_hub import login
def load_large_files(relative_path):
    """Download ``relative_path`` from the ``mingyuan/data_hf`` dataset repo.

    The file is fetched with ``hf_hub_download``; the local path of the
    downloaded file is returned.

    Args:
        relative_path: Path of the file inside the dataset repository,
            e.g. ``"data/motionverse/statistics/mean.npy"``.

    Returns:
        str: Local filesystem path of the downloaded file.
    """
    # An empty HF_TOKEN is treated the same as an unset one.
    hf_token = os.getenv("HF_TOKEN") or None
    if hf_token is not None:
        # Only log in when a token is configured: ``login(token=None)`` falls
        # back to an interactive prompt, which would hang a headless Space.
        login(token=hf_token)
    # Pass the token explicitly too, so the download does not depend on the
    # global login state (``None`` means "use cached credentials, if any").
    return hf_hub_download(
        repo_id="mingyuan/data_hf",
        filename=relative_path,
        repo_type="dataset",
        token=hf_token,
    )
def motion_temporal_filter(motion, sigma=1):
    """Smooth a motion along its first (time) axis with a Gaussian kernel.

    All axes after the first are flattened, and every resulting column is
    filtered independently as a 1-D signal over time.

    Args:
        motion: Array whose first axis is time. Presumably ``(T, J, 3)``
            joint positions (the result is reshaped with a trailing 3).
        sigma: Standard deviation of the Gaussian kernel, in frames.

    Returns:
        np.ndarray: The smoothed motion, reshaped to ``(T, -1, 3)``.
        The input array is left unmodified.
    """
    # Local import: the module only imports ``gaussian_filter`` at top level.
    from scipy.ndimage import gaussian_filter1d

    arr = np.asarray(motion)
    flat = arr.reshape(arr.shape[0], -1)
    # One call along axis 0 matches filtering each column separately with
    # ``gaussian_filter``, but it writes into a fresh array. The old code
    # wrote through a reshape view and so mutated the caller's data.
    smoothed = gaussian_filter1d(flat, sigma=sigma, axis=0, mode="nearest")
    return smoothed.reshape(smoothed.shape[0], -1, 3)
def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=None):
    """Render a predicted motion to a video and optionally save its joints.

    Args:
        data: De-normalised model output (numpy array) that
            ``recover_from_ric`` converts into positions for 52 joints.
        kinematic_chain: Bone chains forwarded to ``plot_3d_motion``.
        result_path: Output path of the rendered video.
        npy_path: When not ``None``, the retargeted joints are also written
            here with ``np.save``.
        fps: Frame rate of the rendered video.
        sigma: Temporal smoothing strength for ``motion_temporal_filter``.
            Previously this argument was ignored and 2.5 was always used;
            ``None`` keeps that value.
    """
    if sigma is None:
        sigma = 2.5
    joints = recover_from_ric(torch.from_numpy(data).float(), 52).numpy()
    joints = motion_temporal_filter(joints, sigma=sigma)
    # Retarget onto the target skeleton loaded at start-up (module-level rtg_skl).
    joints = rtg_skl({"keypoints3d": joints, "meta_data": {"has_lhnd": True}})["keypoints3d"]
    plot_3d_motion(
        out_path=result_path,
        joints=joints,
        kinematic_chain=kinematic_chain,
        title=None,
        fps=fps)
    if npy_path is not None:
        np.save(npy_path, joints)
def create_lmm():
    """Build the LMM network from its config and load the pretrained weights.

    Uses the module-level ``load_file_list`` (checkpoint location) and
    ``device``. With ``device == 'cpu'`` the model is kept on the CPU;
    otherwise it is wrapped in ``MMDataParallel`` on GPU 0. The returned
    model is in eval mode.
    """
    checkpoint = load_file_list["lmm"]
    lmm_cfg = mmcv.Config.fromfile("configs/lmm/lmm_small_demo.py")
    lmm = build_architecture(lmm_cfg.model)
    # Weights are mapped to the CPU first; placement is decided below.
    load_checkpoint(lmm, checkpoint, map_location='cpu')
    lmm = lmm.cpu() if device == 'cpu' else MMDataParallel(lmm, device_ids=[0])
    lmm.eval()
    return lmm
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# demo can still start on CPU-only hardware (create_lmm handles 'cpu').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Where data and checkpoints come from: "hf" downloads them from the Hugging
# Face Hub; "local" reads them from a sibling ``../data_hf`` checkout.
# (Named ``data_source`` rather than ``type`` to avoid shadowing the builtin.)
data_source = "hf"

if data_source == "local":
    load_file_list = {
        "mean": "../data_hf/data/motionverse/statistics/mean.npy",
        "std": "../data_hf/data/motionverse/statistics/std.npy",
        "skeleton": "../data_hf/data/motionverse/statistics/skeleton.npy",
        "lmm": "../data_hf/data/motionverse/pretrained/lmm_small_demo.pth",
        "imagebind": "../data_hf/data/motionverse/pretrained/imagebind_huge.pth"
    }
    # Keep traffic to the local server (port 7860) from going through a proxy.
    os.environ["NO_PROXY"] = os.environ["no_proxy"] = "localhost, 127.0.0.1:7860"
else:
    src_file_list = {
        "mean": "data/motionverse/statistics/mean.npy",
        "std": "data/motionverse/statistics/std.npy",
        "skeleton": "data/motionverse/statistics/skeleton.npy",
        "lmm": "data/motionverse/pretrained/lmm_small_demo.pth",
        "imagebind": "data/motionverse/pretrained/imagebind_huge.pth"
    }
    # Download every asset and map each key to its local path.
    load_file_list = {
        key: load_large_files(remote_path)
        for key, remote_path in src_file_list.items()
    }

# Example audio clips referenced by the Gradio examples; these are local
# files in ./examples rather than downloaded assets.
load_file_list["audio_placeholder"] = "./examples/placeholder.m4a"
load_file_list["audio_surprise"] = "./examples/surprise.m4a"
load_file_list["audio_angry"] = "./examples/angry.m4a"
# Motion generator.
model_lmm = create_lmm()

# ImageBind encoder used to turn text / audio prompts into features.
model_imagebind = imagebind_huge(
    pretrained=True,
    ckpt_path=load_file_list["imagebind"],
)
model_imagebind.eval()
model_imagebind.to(device)

# Retargets predicted joints onto the skeleton stored in skeleton.npy.
rtg_skl = RetargetSkeleton(tgt_skel_file=load_file_list["skeleton"])

# Normalisation statistics used to de-normalise model predictions.
mean_path, std_path = load_file_list["mean"], load_file_list["std"]
mean, std = np.load(mean_path), np.load(std_path)
def _text_condition(text):
    """Return the text-conditioning entries of the model input dict.

    With ``text`` set, ImageBind text features are extracted and the text
    flag is 1; with ``text`` of ``None``, zero features and a 0 flag are used.
    Every tensor is placed on ``device``.
    """
    if text is None:
        return {
            'text_word_feat': torch.zeros(1, 77, 1024).to(device),
            'text_seq_feat': torch.zeros(1, 1024).to(device),
            'text_cond': torch.zeros(1).to(device),
        }
    word_feat, seq_feat = extract_text_feature([text], model_imagebind, device)
    # Sanity-check the feature shapes the model is built for.
    assert tuple(word_feat.shape[:3]) == (1, 77, 1024)
    assert tuple(seq_feat.shape[:2]) == (1, 1024)
    return {
        'text_word_feat': word_feat,
        'text_seq_feat': seq_feat,
        'text_cond': torch.ones(1).to(device),
    }


def _speech_condition(audio_path):
    """Return the speech-conditioning entries of the model input dict.

    With ``audio_path`` set, ImageBind audio features are extracted and the
    speech flag is 1; otherwise zero features and a 0 flag are used. Every
    tensor is placed on ``device``.
    """
    if audio_path is None:
        return {
            'speech_word_feat': torch.zeros(1, 229, 768).to(device),
            'speech_seq_feat': torch.zeros(1, 1024).to(device),
            'speech_cond': torch.zeros(1).to(device),
        }
    word_feat, seq_feat = extract_audio_feature([audio_path], model_imagebind, device)
    # Sanity-check the feature shapes the model is built for.
    assert tuple(word_feat.shape[:3]) == (1, 229, 768)
    assert tuple(seq_feat.shape[:2]) == (1, 1024)
    return {
        'speech_word_feat': word_feat,
        'speech_seq_feat': seq_feat,
        'speech_cond': torch.ones(1).to(device),
    }


def show_generation_result(model, text, audio_path, motion_length, result_path):
    """Generate one motion with ``model`` and render it to ``result_path``.

    Args:
        model: The LMM model (as returned by ``create_lmm``).
        text: Text prompt, or ``None``. With audio but no text, a generic
            "standing and speaking" caption is used.
        audio_path: Path of a speech clip, or ``None``. When given, it
            overrides ``motion_length`` and is muxed into the output video.
        motion_length: Number of frames to generate (at 20 fps).
        result_path: Where the rendered video is written.
    """
    fps = 20
    if audio_path is not None:
        # The clip decides the length: duration * fps + 1 frames, capped at 200.
        motion_length = min(200, int(get_audio_length(audio_path) * fps) + 1)
    # The UI value may not be a Python int; tensor shapes and slicing need one.
    motion_length = int(motion_length)

    motion = torch.zeros(1, motion_length, 669).to(device)
    # Every frame is valid; the last of the 10 mask channels is cleared
    # (per-channel meaning is defined by the model — TODO confirm).
    motion_mask = torch.ones(1, motion_length, 10).to(device)
    motion_mask[:, :, 9] = 0
    motion_metas = [{
        'meta_data': dict(framerate=fps, dataset_name="humanml3d_t2m", rotation_type="h3d_rot")
    }]

    if text is None and audio_path is not None:
        text = "A person is standing and speaking."

    model = model.to(device)
    model_input = {
        'motion': motion,
        'motion_mask': motion_mask,
        'motion_length': torch.tensor([motion_length], dtype=torch.long, device=device),
        'motion_metas': motion_metas,
        'num_intervals': 1,
    }
    # Text features are extracted before speech features, as before.
    model_input.update(_text_condition(text))
    model_input.update(_speech_condition(audio_path))
    model_input['inference_kwargs'] = {}

    with torch.no_grad():
        output = model(**model_input)[0]['pred_motion'][:motion_length]
    pred_motion = output.cpu().detach().numpy()
    # Undo the dataset normalisation using the module-level statistics.
    pred_motion = pred_motion * std + mean
    # Smoothing sigma for the rendered joints; 2.5 is the value plot_tomato
    # has applied so far, so the rendered output stays the same.
    plot_tomato(pred_motion, t2m_body_hand_kinematic_chain, result_path, None, fps, sigma=2.5)
    if audio_path is not None:
        add_audio(result_path, [audio_path])
def generate(prompt, audio_check, audio_path, length):
    """Gradio callback: generate a motion video from text and/or audio.

    Args:
        prompt: Text from the prompt textbox; empty or blank means "no text".
        audio_check: Whether the "Enable audio?" checkbox is ticked.
        audio_path: Path of the recorded/uploaded clip, or ``None``.
        length: Requested motion length in frames (overridden by audio).

    Returns:
        str: Path of the rendered ``.mp4`` under ``outputs/``.
    """
    import uuid

    # Idempotent and race-free, unlike an exists()-then-mkdir() check.
    os.makedirs("outputs", exist_ok=True)
    # The random suffix keeps two requests started in the same second from
    # writing to the same file.
    result_path = "outputs/" + str(int(time.time())) + "_" + uuid.uuid4().hex[:8] + ".mp4"
    print(audio_path)

    if not audio_check:
        audio_path = None
    if audio_path is not None and not os.path.exists(audio_path):
        audio_path = None
    # The text-only examples ship a placeholder clip. It may arrive as the
    # original .m4a or converted to .wav (the Audio component uses format='wav').
    if audio_path is not None and audio_path.endswith(("placeholder.wav", "placeholder.m4a")):
        audio_path = None

    # Treat a missing, empty, or whitespace-only prompt as "no text".
    if not prompt or not prompt.strip():
        prompt = None

    show_generation_result(model_lmm, prompt, audio_path, length, result_path)
    return result_path
# --- Gradio UI -------------------------------------------------------------
# Audio input: delivered to generate() as a file path, converted to WAV.
input_audio = gr.Audio(
    type='filepath',
    format='wav',
    label="Audio (1-10s, overwrite motion length):",
    show_label=True,
    sources=["upload", "microphone"],
    min_length=1,
    max_length=10,
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    ),
)
input_text = gr.Textbox(
    label="Text prompt:"
)
audio_check = gr.Checkbox(
    label="Enable audio? "
)

# Inputs map positionally onto generate(prompt, audio_check, audio_path, length).
demo = gr.Interface(
    fn=generate,
    inputs=[input_text, audio_check, input_audio, gr.Slider(20, 200, value=60, label="Motion length (fps 20):")],
    outputs=gr.Video(label="Video:"),
    examples=[
        ["A person walks in a circle.", False, load_file_list["audio_placeholder"], 120],
        ["A person jumps forward.", False, load_file_list["audio_placeholder"], 100],
        ["A person is stretching arms.", False, load_file_list["audio_placeholder"], 80],
        ["", True, load_file_list["audio_surprise"], 200],
        ["", True, load_file_list["audio_angry"], 200],
    ],
    title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation",
    description="\nThis is an interactive demo for LMM. For more information, feel free to visit our project page(https://github.com/mingyuan-zhang/LMM).")

# Start the server only when run as a script (e.g. ``python app.py``), so that
# importing this module does not block inside launch().
if __name__ == "__main__":
    demo.queue()
    demo.launch()