File size: 6,192 Bytes
7ece8c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import os
import torch

from torch.utils.data import DataLoader
from config import get_config_universal
from dataset import DataSet
from datasetbuilder import DataSetBuilder
from test import Test
from visualization.steamlit_plot import plot_kinematic_predictions

# Dataset whose configuration drives data loading and model evaluation.
dataset_name = 'camargo'
config = get_config_universal(dataset_name)

# Sidebar sensor label -> list of IMU locations. The chosen list is copied
# into config['selected_sensors'] before the data is loaded.
# NOTE(review): the three-sensor entry is ordered foot/shank/thigh while the
# other entries run thigh -> foot; confirm this order matches what the
# checkpoints were trained with.
sensor_options = {
    'Thigh & Shank & Foot': ['foot', 'shank', 'thigh'],
    'Thigh & Shank': ['thigh', 'shank'],
    'Thigh & Foot': ['thigh', 'foot'],
    'Shank & Foot': ['shank', 'foot'],
    'Thigh': ['thigh'],
    'Shank': ['shank'],
    'Foot': ['foot'],
}

@st.cache
def fetch_data(config):
    """Load the dataset, segment both splits and wrap the test split in a DataLoader.

    The training split is only turned into a ``DataSetBuilder`` so that the
    scaler it fits can be reused to transform the test split.

    Args:
        config: configuration mapping; this function reads
            ``'data_transformer'`` and ``'batch_size'`` and passes the whole
            mapping on to ``DataSet``.

    Returns:
        ``(test_dataloader, kihadataset_test)``: an unshuffled DataLoader over
        the test split, and the segmented test-split dict (keys ``'x'``,
        ``'y'``, ``'labels'``).
    """
    handler = DataSet(config, load_dataset=True)
    train_split, test_split = handler.run_dataset_split_loop()

    # Segment each split in place: train first, then test.
    for split in (train_split, test_split):
        segmented = handler.run_segmentation(split['x'], split['y'], split['labels'])
        split['x'], split['y'], split['labels'] = segmented

    transform = config['data_transformer']
    train_builder = DataSetBuilder(train_split['x'], train_split['y'], train_split['labels'],
                                   transform_method=transform, scaler=None, noise=None)
    # Re-use the scaler fitted on the training data for the test data.
    test_builder = DataSetBuilder(test_split['x'], test_split['y'], test_split['labels'],
                                  transform_method=transform, scaler=train_builder.scaler,
                                  noise=None)
    loader = DataLoader(dataset=test_builder, batch_size=config['batch_size'], shuffle=False)
    return loader, test_split

# @st.cache()
def fetch_model(sensor_name, model_name):
    """Load the pre-trained checkpoint for a sidebar sensor/model selection.

    The file name is ``<model>_g1g2rardsasd_g1g2rardsasd.pt`` for the full
    'Thigh & Shank & Foot' configuration and
    ``<sensors>_<model>_g1g2rardsasd_g1g2rardsasd.pt`` for every other
    sensor subset. Files are read from ``./caches/trained_model/v05``.

    Args:
        sensor_name: sidebar sensor label, e.g. ``'Thigh & Shank'``.
        model_name: sidebar model label: ``'CNNLSTM'``, ``'BiLSTM'`` or
            ``'BioMAT'``.

    Returns:
        The object deserialised by ``torch.load`` (presumably a full model
        saved with ``torch.save(model)`` -- TODO confirm).

    Raises:
        KeyError: if either label is not one of the known options.
    """
    model_names = {'CNNLSTM': 'hernandez2021cnnlstm', 'BiLSTM': 'bilstm', 'BioMAT': 'transformertsai'}
    sensor_names = {'Thigh & Shank & Foot': 'thighshankfoot',
                    'Thigh & Shank': 'thighshank',
                    'Thigh & Foot': 'thighfoot',
                    'Shank & Foot': 'shankfoot',
                    'Thigh': 'thigh',
                    'Shank': 'shank',
                    'Foot': 'foot'}
    sensor_tag = sensor_names[sensor_name]
    model_file = model_names[model_name] + '_g1g2rardsasd_g1g2rardsasd.pt'
    if sensor_tag != 'thighshankfoot':
        # Checkpoints for sensor subsets carry the sensor combination as a prefix.
        model_file = sensor_tag + '_' + model_file
    st.write(model_file)
    # Without map_location, a checkpoint saved from a GPU cannot be loaded on a
    # CPU-only host. Remap to the CPU only in that case so that GPU hosts keep
    # the previous behaviour (map_location=None is torch.load's default).
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    model = torch.load(os.path.join('./caches/trained_model/v05', model_file), map_location=map_location)
    return model

# @st.cache
def fetch_predictions(model, dataloader=None):
    """Run ``model`` over the test data and return host-side NumPy results.

    Args:
        model: model returned by ``fetch_model``.
        dataloader: DataLoader to evaluate. When omitted, this falls back to
            the module-level ``test_dataloader`` assigned by the 'Predict'
            handler, so existing ``fetch_predictions(model)`` calls keep
            working. Passing it explicitly removes that hidden dependency.

    Returns:
        ``(y_pred, y_true, loss)``: predictions and targets as NumPy arrays
        (detached and copied to the CPU), plus the loss reported by
        ``Test.run_testing``.
    """
    test_handler = Test()
    if dataloader is None:
        # Backward-compatible default: the global set at script level.
        dataloader = test_dataloader
    y_pred, y_true, loss = test_handler.run_testing(config, model, test_dataloader=dataloader)
    # clone() so the returned arrays never share memory with the output tensors.
    y_true = y_true.detach().cpu().clone().numpy()
    y_pred = y_pred.detach().cpu().clone().numpy()
    return y_pred, y_true, loss

# ---------------------------------------------------------------------------
# Page layout and sidebar controls. The widget calls below must stay in this
# order: Streamlit lays the page out in call order.
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide")
st.title('BioMAT:Biomechanical Multi-Activity Transformer Model for Joint Kinematic Prediction From IMUs')

# Sensor configuration: the sidebar labels are the keys of sensor_options.
st.sidebar.title('Sensor Configuration')
selected_sensor = st.sidebar.selectbox('Pick one', list(sensor_options))
config['selected_sensors'] = sensor_options[selected_sensor]
print(config)

# Model architecture: labels understood by fetch_model().
st.sidebar.title('Model Configuration')
model_choices = ['CNNLSTM', 'BiLSTM', 'BioMAT']
selected_model = st.sidebar.selectbox('Pick one', model_choices)

# Subject whose predictions are plotted (AB23 .. AB25).
st.sidebar.title('Subject')
selected_subject = st.sidebar.slider('Pick a Subject Number', min_value=23, max_value=25, step=1)

# Activities to include in the plot.
st.sidebar.title('Activity')
activity_choices = ['LevelGround Walking', 'Ramp Ascent', 'Ramp Descent', 'Stair Ascent', 'Stair Descent']
selected_activities = st.sidebar.multiselect('Pick Output Activities', activity_choices)

# Forwarded to plot_kinematic_predictions as selected_index_to_plot.
index_to_plot = st.sidebar.number_input('Enter a number between 0 and 5', min_value=0, max_value=5)

# ---------------------------------------------------------------------------
# Prediction pipeline, run only when 'Predict' is clicked. test_dataloader must
# remain a module-level name: fetch_predictions() reads it as a global.
# ---------------------------------------------------------------------------
if st.sidebar.button('Predict'):
    with st.spinner('Data is loading...'):
        test_dataloader, kihadataset_test = fetch_data(config)
    st.success('Data is loaded!')

    with st.spinner('Model is loading...'):
        model = fetch_model(selected_sensor, selected_model)
    st.success('Model is loaded!')

    with st.spinner('Prediction ...'):
        y_pred, y_true, loss = fetch_predictions(model)
    st.success('Prediction is Completed!')

    st.write('plot ...')
    subject = f'AB{selected_subject}'
    fig = plot_kinematic_predictions(
        y_true,
        y_pred,
        kihadataset_test['labels'],
        subject,
        selected_activities=selected_activities,
        selected_index_to_plot=index_to_plot,
    )
    st.plotly_chart(fig, use_container_width=True)