File size: 6,192 Bytes
7ece8c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
import os
import torch
from torch.utils.data import DataLoader
from config import get_config_universal
from dataset import DataSet
from datasetbuilder import DataSetBuilder
from test import Test
from visualization.steamlit_plot import plot_kinematic_predictions
# Dataset shown by the demo and the configuration dict derived from it.
dataset_name = 'camargo'
config = get_config_universal(dataset_name)

# Sidebar sensor label -> sensor locations stored in config['selected_sensors'].
sensor_options = {
    'Thigh & Shank & Foot': ['foot', 'shank', 'thigh'],
    'Thigh & Shank': ['thigh', 'shank'],
    'Thigh & Foot': ['thigh', 'foot'],
    'Shank & Foot': ['shank', 'foot'],
    'Thigh': ['thigh'],
    'Shank': ['shank'],
    'Foot': ['foot'],
}
@st.cache
def fetch_data(config):
    """Load and segment the dataset, then build an unshuffled test DataLoader.

    The training split is processed as well because the test dataset reuses
    the training dataset's ``scaler``.

    Args:
        config: configuration dict; ``data_transformer`` and ``batch_size`` are
            read here and the whole dict is handed to ``DataSet``.

    Returns:
        ``(test_dataloader, kihadataset_test)``: a DataLoader over the test split
        (``shuffle=False``) and the segmented test split, indexable by
        ``'x'``, ``'y'`` and ``'labels'``.
    """
    handler = DataSet(config, load_dataset=True)
    train_split, test_split = handler.run_dataset_split_loop()

    # Segment each split in place: training first, then test.
    for split in (train_split, test_split):
        split['x'], split['y'], split['labels'] = handler.run_segmentation(
            split['x'], split['y'], split['labels'])

    transform = config['data_transformer']
    # Built only so the test dataset can share its scaler.
    train_dataset = DataSetBuilder(train_split['x'], train_split['y'], train_split['labels'],
                                   transform_method=transform, scaler=None, noise=None)
    test_dataset = DataSetBuilder(test_split['x'], test_split['y'], test_split['labels'],
                                  transform_method=transform,
                                  scaler=train_dataset.scaler, noise=None)

    loader = DataLoader(dataset=test_dataset, batch_size=config['batch_size'], shuffle=False)
    return loader, test_split
# @st.cache()
def fetch_model(sensor_name, model_name):
    """Load the pretrained checkpoint for a sensor set and model architecture.

    Args:
        sensor_name: sidebar sensor label, e.g. ``'Thigh & Shank'``.
        model_name: one of ``'CNNLSTM'``, ``'BiLSTM'`` or ``'BioMAT'``.

    Returns:
        Whatever ``torch.load`` returns for
        ``./caches/trained_model/v05/<model_file>``.

    Raises:
        KeyError: if ``sensor_name`` or ``model_name`` is not a known label.
    """
    model_names = {'CNNLSTM': 'hernandez2021cnnlstm',
                   'BiLSTM': 'bilstm',
                   'BioMAT': 'transformertsai'}
    sensor_names = {'Thigh & Shank & Foot': 'thighshankfoot',
                    'Thigh & Shank': 'thighshank',
                    'Thigh & Foot': 'thighfoot',
                    'Shank & Foot': 'shankfoot',
                    'Thigh': 'thigh',
                    'Shank': 'shank',
                    'Foot': 'foot'}
    sensor_tag = sensor_names[sensor_name]
    model_tag = model_names[model_name]
    suffix = '_g1g2rardsasd_g1g2rardsasd.pt'
    # Checkpoints for the full Thigh & Shank & Foot set have no sensor prefix.
    if sensor_tag == 'thighshankfoot':
        model_file = model_tag + suffix
    else:
        model_file = sensor_tag + '_' + model_tag + suffix
    st.write(model_file)
    # When CUDA is available, keep tensors on the device they were saved from
    # (same as before). Otherwise remap to CPU, so GPU-saved checkpoints still
    # load on CPU-only hosts instead of raising.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects fully
    # pickled models; pass weights_only=False there if loading fails.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    return torch.load(os.path.join('./caches/trained_model/v05', model_file),
                      map_location=map_location)
# @st.cache
def fetch_predictions(model, dataloader=None):
    """Run ``model`` over a test DataLoader and return numpy predictions.

    Args:
        model: model as returned by ``fetch_model``.
        dataloader: DataLoader to evaluate on. When omitted, falls back to the
            module-level ``test_dataloader``, which is the original behaviour;
            that global must already have been assigned. The module-level
            ``config`` is always read.

    Returns:
        ``(y_pred, y_true, loss)``: predictions and targets as detached numpy
        copies, plus ``loss`` exactly as ``Test.run_testing`` returns it.
    """
    if dataloader is None:
        # Backward-compatible: read the global set by the Predict handler.
        dataloader = test_dataloader
    test_handler = Test()
    y_pred, y_true, loss = test_handler.run_testing(config, model, test_dataloader=dataloader)
    # clone() so the numpy arrays do not alias the tensors returned by the tester.
    y_true = y_true.detach().cpu().clone().numpy()
    y_pred = y_pred.detach().cpu().clone().numpy()
    return y_pred, y_true, loss
# sensor_name = 'Thigh & Shank & Foot'
# config['sensor_sensor'] = sensor_options[sensor_name]
# test_dataloader, kihadataset_test = fetch_data(config)
# model = fetch_model(sensor_name, 'BioMAT')
# y_pred, y_true, loss = fetch_predictions(model)
# fig = plot_kinematic_predictions(y_true, y_pred, kihadataset_test['labels'], 'AB24',
# selected_activities= ['LevelGround Walking', 'Ramp Ascent', 'Ramp Descent', 'Stair Ascent', 'Stair Descent'],
# selected_index_to_plot=1)
# Page setup: set_page_config must run before any other Streamlit element call.
st.set_page_config(layout="wide")
st.title('BioMAT:Biomechanical Multi-Activity Transformer Model for Joint Kinematic Prediction From IMUs')

# ---- Sidebar inputs (created in display order) ----------------------------
sidebar = st.sidebar

sidebar.title('Sensor Configuration')
# Same labels, same order as the keys of sensor_options.
selected_sensor = sidebar.selectbox('Pick one', list(sensor_options))
config['selected_sensors'] = sensor_options[selected_sensor]
print(config)

sidebar.title('Model Configuration')
selected_model = sidebar.selectbox('Pick one', ['CNNLSTM', 'BiLSTM', 'BioMAT'])

sidebar.title('Subject')
selected_subject = sidebar.slider('Pick a Subject Number', min_value=23, max_value=25, step=1)

sidebar.title('Activity')
activity_choices = ['LevelGround Walking', 'Ramp Ascent', 'Ramp Descent', 'Stair Ascent', 'Stair Descent']
selected_activities = sidebar.multiselect('Pick Output Activities', activity_choices)
index_to_plot = sidebar.number_input('Enter a number between 0 and 5', min_value=0, max_value=5)

# ---- Prediction: only on the rerun triggered by clicking the button -------
if sidebar.button('Predict'):
    with st.spinner('Data is loading...'):
        # Module-level names on purpose: fetch_predictions reads test_dataloader as a global.
        test_dataloader, kihadataset_test = fetch_data(config)
    st.success('Data is loaded!')

    with st.spinner('Model is loading...'):
        model = fetch_model(selected_sensor, selected_model)
    st.success('Model is loaded!')

    with st.spinner('Prediction ...'):
        y_pred, y_true, loss = fetch_predictions(model)
    st.success('Prediction is Completed!')

    st.write('plot ...')
    subject = 'AB' + str(selected_subject)
    fig = plot_kinematic_predictions(
        y_true, y_pred, kihadataset_test['labels'], subject,
        selected_activities=selected_activities,
        selected_index_to_plot=index_to_plot,
    )
    st.plotly_chart(fig, use_container_width=True)
#
|