File size: 4,560 Bytes
a12075d
 
 
258887b
a12075d
 
 
258887b
a12075d
258887b
 
 
 
 
 
 
17bf45d
 
 
 
 
 
a12075d
258887b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a12075d
258887b
 
 
a12075d
258887b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a12075d
258887b
 
 
 
 
 
 
 
a12075d
258887b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a12075d
258887b
 
 
 
 
 
 
 
 
a12075d
258887b
 
a12075d
258887b
 
17bf45d
 
 
258887b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# NOTE(review): not referenced anywhere in the visible code (datasets are
# fetched from the Hugging Face Hub in load_data) — confirm before removing.
data_path = "./data/"

import pandas as pd
import datasets

# Streamlit drives the UI; motion-capture data is downloaded from the Hugging Face Hub, not read from a local CSV.
import streamlit as st
# Dataset repository names under the "cyberorigin" Hub organisation; the same
# list also supplies the options of the dataset selectbox.
dataset_names = [
    'Fold_towels',
    'Pipette',
    'Take_the_item',
    'Twist_the_tube',
]

def trajectory_period(frame):
    """Return the elapsed time covered by one recording.

    The period is the last ``timestamp`` minus the first, so rows are assumed
    to be in chronological order. An empty recording contributes 0.0 instead
    of raising ``IndexError`` from ``.iloc``.

    Args:
        frame: DataFrame with a numeric ``timestamp`` column.

    Returns:
        Final timestamp minus initial timestamp, or 0.0 when there are no rows.
    """
    timestamps = frame["timestamp"]
    if timestamps.empty:
        return 0.0
    return timestamps.iloc[-1] - timestamps.iloc[0]


def load_data():
    """Download every dataset in ``dataset_names`` and total their durations.

    Each name is loaded with ``datasets.load_dataset("cyberorigin/<name>")``
    and its ``train`` split is converted to a pandas DataFrame.

    Returns:
        A tuple ``(all_datasets, total_period)``: ``all_datasets`` maps each
        dataset name to its DataFrame, and ``total_period`` is the sum of the
        per-recording periods computed by ``trajectory_period``.
    """
    print("Loading data")
    all_datasets = {}
    for name in dataset_names:
        print("Loading dataset: ", name)
        train_split = datasets.load_dataset(f"cyberorigin/{name}")['train']
        all_datasets[name] = pd.DataFrame(train_split)
    total_period = sum(trajectory_period(frame) for frame in all_datasets.values())
    return all_datasets, total_period

@st.fragment
def visualize(data):
    """Show the selected dataset's video and an animated 3-D marker plot.

    Runs as a Streamlit fragment, so interacting with the selectbox reruns
    only this function rather than the whole script.

    Args:
        data: Mapping from each entry of ``dataset_names`` to its DataFrame, as
            returned by ``load_data``. Each DataFrame is expected to hold a
            numeric ``timestamp`` column plus ``<part>_x``, ``<part>_y`` and
            ``<part>_z`` columns for every name in ``body_part_names`` below.
    """
    import numpy as np
    import plotly.graph_objects as go

    dataset_option = st.selectbox(
        'Select a dataset:',
        dataset_names
    )
    # The video is fetched from a fixed path inside the dataset's Hub repository.
    st.video(f"https://huggingface.co/datasets/cyberorigin/{dataset_option}/resolve/main/Video/video.mp4")
    motion_capture_data = data[dataset_option]
    if motion_capture_data.empty:
        # Without at least one row there is no initial frame to draw.
        st.warning(f"Dataset {dataset_option} has no motion-capture frames to display.")
        return

    body_part_names = [
        'Left Shoulder',
        'Right Upper Arm',
        'Left Lower Leg',
        'Spine1',
        'Right Upper Leg',
        'Spine3',
        'Right Lower Arm',
        'Left Foot',
        'Right Lower Leg',
        'Right Shoulder',
        'Left Hand',
        'Left Upper Leg',
        'Right Foot',
        'Spine',
        'Spine2',
        'Left Lower Arm',
        'Left Toe',
        'Neck',
        'Right Hand',
        'Right Toe',
        'Head',
        'Left Upper Arm',
        'Hips',
    ]

    # Row k of each array holds one coordinate of every body part at frame k.
    # Indexing plain ndarrays per frame is far cheaper than DataFrame.iloc,
    # which builds a fresh Series on every call.
    xs = motion_capture_data[[part + "_x" for part in body_part_names]].to_numpy(dtype=float)
    ys = motion_capture_data[[part + "_y" for part in body_part_names]].to_numpy(dtype=float)
    zs = motion_capture_data[[part + "_z" for part in body_part_names]].to_numpy(dtype=float)
    times = motion_capture_data["timestamp"]
    n_frames = len(motion_capture_data)

    def skeleton_trace(k):
        """Marker trace with one point per body part at frame ``k``."""
        return go.Scatter3d(
            x=xs[k],
            y=ys[k],
            z=zs[k],
            mode='markers',
            marker=dict(size=5, color='blue'),
        )

    def fixed_axis(values):
        """Axis spanning all frames, so the view does not rescale during playback."""
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return dict(autorange=True)
        return dict(range=[float(finite.min()), float(finite.max())], autorange=False)

    frames = [go.Frame(data=[skeleton_trace(k)], name=str(k)) for k in range(n_frames)]

    layout = go.Layout(
        title='Motion Capture Visualization',
        scene=dict(
            xaxis=fixed_axis(xs),
            yaxis=fixed_axis(ys),
            zaxis=fixed_axis(zs),
        ),
        updatemenus=[{
            'buttons': [
                {
                    'args': [None, {'frame': {'duration': 1, 'redraw': True}, 'fromcurrent': True}],
                    'label': 'Play',
                    'method': 'animate'
                },
                {
                    'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
                    'label': 'Pause',
                    'method': 'animate'
                }
            ],
            'direction': 'left',
            'pad': {'r': 10, 't': 87},
            'showactive': True,
            'type': 'buttons',
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top'
        }],
        sliders=[{
            'active': 0,
            'steps': [{
                # Label with the recorded timestamp so the "Time:" readout is
                # a time, not a frame index; frame names stay str(k).
                'label': f"{t:.2f}",
                'method': 'animate',
                'args': [
                    [str(k)],
                    # 3-D traces are redrawn rather than tweened, so no
                    # transition time is needed.
                    {'mode': 'immediate', 'frame': {'duration': 300, 'redraw': True}, 'transition': {'duration': 0}}
                ]
            } for k, t in enumerate(times)],
            'currentvalue': {
                'prefix': 'Time: ',
                'suffix': ' s',
                'visible': True,
                'xanchor': 'right'
            },
            'pad': {'b': 10},
            'len': 0.9,
            'x': 0.1,
            'y': 0,
        }]
    )

    fig = go.Figure(data=[skeleton_trace(0)], frames=frames, layout=layout)
    st.plotly_chart(fig)


st.title("CyberOrigin Data Visualization")
# Streamlit re-executes this script on every full rerun (page load, new
# session). Caching the result keeps the Hub downloads and DataFrame
# conversion in load_data from repeating each time.
data, period = st.cache_data(load_data)()
# display the total period of the data up to 2 decimal places
st.write("Total period of data: ", round(period, 2), " seconds")
visualize(data)