Spaces:
Running
Running
File size: 4,560 Bytes
a12075d 258887b a12075d 258887b a12075d 258887b 17bf45d a12075d 258887b a12075d 258887b a12075d 258887b a12075d 258887b a12075d 258887b a12075d 258887b a12075d 258887b a12075d 258887b 17bf45d 258887b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Local data directory.
# NOTE(review): nothing in the visible code reads this value - confirm it is still needed.
data_path = "./data/"
import pandas as pd
import datasets
# Motion-capture data is fetched with `datasets.load_dataset` (see load_data), not read from a local CSV.
import streamlit as st
# Datasets offered in the selector; each name is appended to "cyberorigin/"
# when the data is loaded and when the video URL is built.
dataset_names = [
    'Fold_towels',
    'Pipette',
    'Take_the_item',
    'Twist_the_tube',
]
def load_data():
    """Load every dataset in ``dataset_names`` and total their recorded time.

    Each name is loaded as ``"cyberorigin/" + name`` through
    ``datasets.load_dataset`` and its ``'train'`` split is wrapped in a
    ``pandas.DataFrame``.

    Returns:
        A ``(all_datasets, total_period)`` tuple. ``all_datasets`` maps each
        dataset name to its DataFrame. ``total_period`` is the sum, over all
        datasets, of the last ``"timestamp"`` value minus the first one.
        Datasets with no rows contribute nothing to the total.
    """
    print("Loading data")
    all_datasets = {}
    for name in dataset_names:
        print("Loading dataset: ", name)
        all_datasets[name] = pd.DataFrame(
            datasets.load_dataset("cyberorigin/" + name)['train']
        )

    total_period = 0
    for frame in all_datasets.values():
        # An empty split has no first/last row; skip it instead of letting
        # .iloc[-1] / .iloc[0] raise IndexError.
        if frame.empty:
            continue
        # NOTE(review): assumes "timestamp" is already numeric and ordered
        # in time - confirm against the dataset schema.
        stamps = frame["timestamp"]
        total_period += stamps.iloc[-1] - stamps.iloc[0]
    return all_datasets, total_period
# Body parts whose "<name>_x", "<name>_y" and "<name>_z" columns are plotted.
_BODY_PART_NAMES = (
    'Left Shoulder',
    'Right Upper Arm',
    'Left Lower Leg',
    'Spine1',
    'Right Upper Leg',
    'Spine3',
    'Right Lower Arm',
    'Left Foot',
    'Right Lower Leg',
    'Right Shoulder',
    'Left Hand',
    'Left Upper Leg',
    'Right Foot',
    'Spine',
    'Spine2',
    'Left Lower Arm',
    'Left Toe',
    'Neck',
    'Right Hand',
    'Right Toe',
    'Head',
    'Left Upper Arm',
    'Hips',
)


def _build_motion_figure(motion_capture_data):
    """Build the animated 3D scatter figure for one non-empty recording.

    Args:
        motion_capture_data: DataFrame with one row per frame and, for every
            entry of ``_BODY_PART_NAMES``, the columns ``<name>_x``,
            ``<name>_y`` and ``<name>_z``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one animation frame per row,
        Play/Pause buttons and a frame slider.
    """
    import plotly.graph_objects as go

    # Convert each axis to a single array once; indexing a row of an array is
    # far cheaper than building a pandas Series with .iloc[k] for every frame.
    xs = motion_capture_data[[part + "_x" for part in _BODY_PART_NAMES]].to_numpy()
    ys = motion_capture_data[[part + "_y" for part in _BODY_PART_NAMES]].to_numpy()
    zs = motion_capture_data[[part + "_z" for part in _BODY_PART_NAMES]].to_numpy()
    frame_count = len(motion_capture_data)

    def scatter_at(k):
        # One marker per body part at row k.
        return go.Scatter3d(
            x=xs[k],
            y=ys[k],
            z=zs[k],
            mode='markers',
            marker=dict(size=5, color='blue'),
        )

    # Frames are named by their row index so the slider steps can target them.
    frames = [
        go.Frame(data=[scatter_at(k)], name=str(k))
        for k in range(frame_count)
    ]

    play_button = {
        'args': [None, {'frame': {'duration': 1, 'redraw': True}, 'fromcurrent': True}],
        'label': 'Play',
        'method': 'animate',
    }
    # Animating to [None] in immediate mode interrupts the running animation.
    pause_button = {
        'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
        'label': 'Pause',
        'method': 'animate',
    }

    # NOTE(review): the transition duration grows with the number of frames
    # (frame_count / 30) - confirm this is intended.
    # NOTE(review): step labels are frame indices although the slider prefix
    # reads 'Time: ' - confirm whether timestamps should be shown instead.
    slider_steps = [{
        'label': str(k),
        'method': 'animate',
        'args': [
            [str(k)],
            {'mode': 'immediate', 'frame': {'duration': 300, 'redraw': True}, 'transition': {'duration': frame_count / 30}},
        ],
    } for k in range(frame_count)]

    layout = go.Layout(
        title='Motion Capture Visualization',
        updatemenus=[{
            'buttons': [play_button, pause_button],
            'direction': 'left',
            'pad': {'r': 10, 't': 87},
            'showactive': True,
            'type': 'buttons',
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top',
        }],
        sliders=[{
            'active': 0,
            'steps': slider_steps,
            'currentvalue': {
                'prefix': 'Time: ',
                'visible': True,
                'xanchor': 'right',
            },
            'pad': {'b': 10},
            'len': 0.9,
            'x': 0.1,
            'y': 0,
        }],
    )

    # The first frame doubles as the trace shown before any animation runs.
    return go.Figure(data=[scatter_at(0)], frames=frames, layout=layout)


@st.fragment
def visualize(data):
    """Render the dataset selector, its video and the motion-capture animation.

    Args:
        data: Mapping from dataset name (an entry of ``dataset_names``) to the
            DataFrame holding that dataset's motion-capture rows.
    """
    dataset_option = st.selectbox(
        'Select a dataset:',
        dataset_names
    )

    # Show the recording's video, then the matching motion-capture animation.
    st.video(f"https://huggingface.co/datasets/cyberorigin/{dataset_option}/resolve/main/Video/video.mp4")

    motion_capture_data = data[dataset_option]
    # Without rows there is no first frame to plot; report it instead of
    # failing with an IndexError while building the figure.
    if motion_capture_data.empty:
        st.warning("No motion capture frames available for this dataset.")
        return

    st.plotly_chart(_build_motion_figure(motion_capture_data))
st.title("CyberOrigin Data Visualization")
# Streamlit re-executes this script on every full rerun. Wrapping load_data in
# st.cache_data lets later reruns reuse the first result instead of loading
# every dataset again.
data, period = st.cache_data(load_data)()
# display the total period of the data up to 2 decimal places
st.write("Total period of data: ", round(period, 2), " seconds")
visualize(data)