MTECBS / pages /inference.py
yan123yan
update midpoint function
cc9df09
import os
import time
import streamlit as st
import torch
import torch.nn.functional as F
import random
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
from model.lstm import LSTMModel
from model.tcn import TCNModel
from model.tcn import move_custom_layers_to_device
from utils.lowlevel import LowLevel
from utils.highlevel import HighLevel
from utils.midpoint import MidPoint
from utils.transform import compute_gradient
# Page-level configuration, issued before any other st.* element is created:
# browser-tab title/icon, a wide layout (the page uses side-by-side columns),
# and the default sidebar state.
st.set_page_config(
    page_title="Inference",
    page_icon=":chart_with_upwards_trend:",
    layout="wide",
    initial_sidebar_state="auto",
)
def uniform_sampling(data, n_sample):
    """Thin ``data`` by a fixed stride so that roughly ``n_sample`` items remain.

    The stride is ``len(data) // n_sample``, clamped to at least 1. Asking for
    more samples than ``data`` holds (or passing empty data) therefore returns
    ``data`` unchanged instead of failing with "slice step cannot be zero".
    When ``len(data)`` is not a multiple of ``n_sample`` the result can hold a
    few more than ``n_sample`` items (the stride semantics are unchanged).

    Args:
        data: Any sliceable sequence (list, tuple, NumPy array, ...).
        n_sample: Target number of samples; must be a positive integer.

    Returns:
        ``data[::stride]``.

    Raises:
        ValueError: If ``n_sample`` is not positive.
    """
    if n_sample <= 0:
        raise ValueError(f"n_sample must be positive, got {n_sample}")
    stride = max(1, len(data) // n_sample)
    return data[::stride]
def low_level(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with the explicit symplectic (EPS) scheme and time it.

    Each step calls ``LowLevel.symplectic`` three times with sub-step weights
    ``a1, a2, a1`` (Yoshida's fourth-order triple-jump coefficients), then
    ``LowLevel.rejust``. Every 10th step the state ``[b, x, y, z, px, py, pz]``
    is recorded.

    Args:
        option_time: End time; the loop runs while ``t < option_time``, starting
            at ``t = 0.1`` with the step ``h`` returned by ``LowLevel.initial()``.
        slider_sample_orbit: Orbit index passed as ``LowLevel(j=...)``.
        progress_bar: Streamlit progress element updated during the run.

    Returns:
        ``(first_segment_s, second_segment_s, total_s, samples)``:
        - ``first_segment_s``: wall time up to step 300000. The UI labels this
          "30000 time steps", which matches when ``h == 0.1`` (TODO confirm).
        - ``second_segment_s``: wall time from step 300000 to the end.
        - ``total_s``: total wall time.
        - ``samples``: the recorded states thinned by ``uniform_sampling`` to
          about ``option_time / 100`` rows.
        If the run ends before step 300000, the whole loop is reported as the
        first segment and the second segment is 0.0.
    """
    time.sleep(0.1)
    low_level_total_start_time = time.time()
    low_level_30000_start_time = time.time()
    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    last_percentage = None
    # Set when step 300000 is reached; stays None if the run is shorter.
    low_level_2000_start_time = None
    original_low_level_data = []
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h
        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        if jn == 300000:
            low_level_30000_end_time = time.time()
            low_level_30000_execute_time = low_level_30000_end_time - low_level_30000_start_time
            low_level_2000_start_time = time.time()
        jn = jn + 1
    progress_bar.progress(100)
    low_level_2000_end_time = time.time()
    if low_level_2000_start_time is None:
        # The split point was never reached; without this fallback the
        # segment timings would be undefined (UnboundLocalError).
        low_level_30000_execute_time = low_level_2000_end_time - low_level_30000_start_time
        low_level_2000_execute_time = 0.0
    else:
        low_level_2000_execute_time = low_level_2000_end_time - low_level_2000_start_time
    low_level_total_end_time = time.time()
    low_level_total_execute_time = low_level_total_end_time - low_level_total_start_time
    # max(1, ...) keeps the sample count positive when option_time < 100.
    result = uniform_sampling(np.array(original_low_level_data), n_sample=max(1, int(option_time / 100)))
    return low_level_30000_execute_time, low_level_2000_execute_time, low_level_total_execute_time, result
def high_level(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with the ``HighLevel`` symplectic scheme and time it.

    Each step calls ``HighLevel.symplectic`` three times with sub-step weights
    ``a1, a2, a1`` (Yoshida's fourth-order triple-jump coefficients), then
    ``HighLevel.rejust``, then evaluates ``HighLevel.f`` on the new state.
    Every 10th step the state ``[b, x, y, z, px, py, pz]`` is recorded.

    NOTE(review): 'High-Level' is not among the sidebar's method options, so
    this function is currently not reachable from the UI.

    Args:
        option_time: End time; the loop runs while ``t < option_time``, starting
            at ``t = 0.1`` with the step ``h`` returned by ``HighLevel.initial()``.
        slider_sample_orbit: Orbit index passed as ``HighLevel(j=...)``.
        progress_bar: Streamlit progress element updated during the run.

    Returns:
        ``(first_segment_s, second_segment_s, total_s, samples)``:
        - ``first_segment_s``: wall time up to step 300000.
        - ``second_segment_s``: wall time from step 300000 to the end.
        - ``total_s``: total wall time.
        - ``samples``: the recorded states thinned to about
          ``option_time / 100`` rows.
        If the run ends before step 300000, the whole loop is reported as the
        first segment and the second segment is 0.0.
    """
    time.sleep(0.1)
    high_level_total_start_time = time.time()
    high_level_30000_start_time = time.time()
    highlevelhelper = HighLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = highlevelhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    last_percentage = None
    # Set when step 300000 is reached; stays None if the run is shorter.
    high_level_2000_start_time = None
    original_high_level_data = []
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h
        # NOTE(review): these outputs (including e) are unused but the call is
        # part of the timed loop; confirm it is meant to be benchmarked.
        vx, vy, vz, vpx, vpy, vpz, e = highlevelhelper.f(x, y, z, px, py, pz, b)
        if jn % 10 == 0:
            original_high_level_data.append([b, x, y, z, px, py, pz])
            # Update the bar every 10th step like the sibling integrators, and
            # only when the displayed percentage changes; per-step UI updates
            # inside the timed loop inflated the measured times.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        if jn == 300000:
            high_level_30000_end_time = time.time()
            high_level_30000_execute_time = high_level_30000_end_time - high_level_30000_start_time
            high_level_2000_start_time = time.time()
        jn = jn + 1
    progress_bar.progress(100)
    high_level_2000_end_time = time.time()
    if high_level_2000_start_time is None:
        # The split point was never reached; without this fallback the
        # segment timings would be undefined (UnboundLocalError).
        high_level_30000_execute_time = high_level_2000_end_time - high_level_30000_start_time
        high_level_2000_execute_time = 0.0
    else:
        high_level_2000_execute_time = high_level_2000_end_time - high_level_2000_start_time
    high_level_total_end_time = time.time()
    high_level_total_execute_time = high_level_total_end_time - high_level_total_start_time
    # max(1, ...) keeps the sample count positive when option_time < 100.
    result = uniform_sampling(np.array(original_high_level_data), n_sample=max(1, int(option_time / 100)))
    return high_level_30000_execute_time, high_level_2000_execute_time, high_level_total_execute_time, result
def midpoint(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with the ``MidPoint`` scheme and time it.

    Each step calls ``MidPoint.implecitsymplectic`` three times with sub-step
    weights ``a1, a2, a1`` (Yoshida's fourth-order triple-jump coefficients).
    That call is presumably an implicit midpoint step; see utils.midpoint.
    Every 10th step the state ``[b, x, y, z, px, py, pz]`` is recorded.

    Args:
        option_time: End time; the loop runs while ``t < option_time``, starting
            at ``t = 0.1`` with the step ``h`` returned by ``MidPoint.initial()``.
        slider_sample_orbit: Orbit index passed as ``MidPoint(j=...)``.
        progress_bar: Streamlit progress element updated during the run.

    Returns:
        ``(first_segment_s, second_segment_s, total_s, samples)``:
        - ``first_segment_s``: wall time of the loop up to step 300000.
        - ``second_segment_s``: wall time from step 300000 to the end.
        - ``total_s``: total wall time.
        - ``samples``: the recorded states thinned to about
          ``option_time / 100`` rows.
        If the run ends before step 300000, the whole loop is reported as the
        first segment and the second segment is 0.0.
    """
    time.sleep(0.1)
    mid_point_total_start_time = time.time()
    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    last_percentage = None
    # Set when step 300000 is reached; stays None if the run is shorter.
    mid_point_2000_start_time = None
    original_mid_point_data = []
    mid_point_30000_start_time = time.time()
    while t < end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h
        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        if jn == 300000:
            mid_point_30000_end_time = time.time()
            mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
            mid_point_2000_start_time = time.time()
        jn = jn + 1
    progress_bar.progress(100)
    mid_point_2000_end_time = time.time()
    if mid_point_2000_start_time is None:
        # The split point was never reached; without this fallback the
        # segment timings would be undefined (UnboundLocalError).
        mid_point_30000_execute_time = mid_point_2000_end_time - mid_point_30000_start_time
        mid_point_2000_execute_time = 0.0
    else:
        mid_point_2000_execute_time = mid_point_2000_end_time - mid_point_2000_start_time
    mid_point_total_end_time = time.time()
    mid_point_total_execute_time = mid_point_total_end_time - mid_point_total_start_time
    # max(1, ...) keeps the sample count positive when option_time < 100.
    result = uniform_sampling(np.array(original_mid_point_data), n_sample=max(1, int(option_time / 100)))
    return mid_point_30000_execute_time, mid_point_2000_execute_time, mid_point_total_execute_time, result
def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
    """Integrate to t = 30000 with EPS, then extend the trajectory with the LSTM.

    Steps:
    1. Load ``model/lstm.ckpt`` on CPU.
    2. Run the EPS step (three ``LowLevel.symplectic`` sub-steps plus
       ``rejust``) while ``t < 30000``, recording every 10th state
       ``[b, x, y, z, px, py, pz]``.
    3. Thin the record to about 300 rows, fit a ``MinMaxScaler`` on it, and
       apply ``compute_gradient(..., degree=2)`` to each scaled row.
    4. Feed rows 100..299 to the model.
    5. Map the predictions back with the scaler's ``inverse_transform``.

    Args:
        slider_sample_orbit: Orbit index passed as ``LowLevel(j=...)``.
        lstm_progress_bar: Streamlit progress element updated during integration.

    Returns:
        ``(integration_s, lstm_s, total_s, combined_preds)``:
        - ``integration_s``: wall time of the EPS loop.
        - ``lstm_s``: wall time of inference plus post-processing.
        - ``total_s``: total wall time.
        - ``combined_preds``: the thinned EPS states concatenated row-wise
          with the inverse-scaled predictions.
    """
    time.sleep(0.1)
    total_start_time = time.time()
    # NOTE(review): the checkpoint is reloaded on every call; caching the model
    # (e.g. st.cache_resource) would take this out of total_time.
    lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
    lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
    lstm_model.to("cpu")
    lstm_model.eval()
    # Initialize variables for the classical method
    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    # The numerical method covers t < 30000; the network continues from there.
    numeric_end_time = float(30000)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (numeric_end_time - t) / h
    last_percentage = None
    original_low_level_data = []
    low_level_start_time = time.time()
    # Perform classical method prediction for the initial segment
    while t < numeric_end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h
        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                lstm_progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        jn = jn + 1
    original_low_level_data = np.array(original_low_level_data)
    low_level_end_time = time.time()
    low_level_data = original_low_level_data.copy()
    low_level_data = uniform_sampling(low_level_data, n_sample=300)
    # The same scaler is reused below to map predictions back to physical units.
    scaler = MinMaxScaler()
    low_level_data = scaler.fit_transform(low_level_data)
    low_level_data = torch.tensor(np.stack(low_level_data)).float()
    # unsqueeze(0) adds a leading batch axis; the slice below indexes axis 1.
    low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
    lstm_start_time = time.time()
    with torch.no_grad():
        lstm_preds = lstm_model(low_level_data[:, 100:300, :])
    # NOTE(review): squeeze() assumes the output has more than one time step and
    # as many features as the scaler was fitted on — confirm against LSTMModel.
    lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
    original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
    lstm_end_time = time.time()
    lstm_progress_bar.progress(100)
    combined_preds = np.concatenate([original_low_level_data, lstm_innv_preds], axis=0)
    lstm_total_time = lstm_end_time - lstm_start_time
    low_level_total_time = low_level_end_time - low_level_start_time
    total_end_time = time.time()
    total_time = total_end_time - total_start_time
    return low_level_total_time, lstm_total_time, total_time, combined_preds
def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
    """Integrate to t = 30000 with the MidPoint scheme, then extend with the LSTM.

    Steps:
    1. Load ``model/lstm.ckpt`` on CPU.
    2. Run three ``MidPoint.implecitsymplectic`` sub-steps per step while
       ``t < 30000``, recording every 10th state ``[b, x, y, z, px, py, pz]``.
    3. Thin the record to about 300 rows, fit a ``MinMaxScaler`` on it, and
       apply ``compute_gradient(..., degree=2)`` to each scaled row.
    4. Feed rows 100..299 to the model.
    5. Map the predictions back with the scaler's ``inverse_transform``.

    Args:
        slider_sample_orbit: Orbit index passed as ``MidPoint(j=...)``.
        lstm_progress_bar: Streamlit progress element updated during integration.

    Returns:
        ``(integration_s, lstm_s, total_s, combined_preds)``:
        - ``integration_s``: wall time of the midpoint loop.
        - ``lstm_s``: wall time of inference plus post-processing.
        - ``total_s``: total wall time.
        - ``combined_preds``: the thinned midpoint states concatenated
          row-wise with the inverse-scaled predictions.
    """
    time.sleep(0.1)
    total_start_time = time.time()
    # NOTE(review): the checkpoint is reloaded on every call; caching the model
    # (e.g. st.cache_resource) would take this out of total_time.
    lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
    lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
    lstm_model.to("cpu")
    lstm_model.eval()
    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    # The numerical method covers t < 30000; the network continues from there.
    numeric_end_time = float(30000)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (numeric_end_time - t) / h
    last_percentage = None
    original_mid_point_data = []
    mid_point_start_time = time.time()
    while t < numeric_end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h
        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                lstm_progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        jn = jn + 1
    original_mid_point_data = np.array(original_mid_point_data)
    mid_point_end_time = time.time()
    mid_point_data = original_mid_point_data.copy()
    mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
    # The same scaler is reused below to map predictions back to physical units.
    scaler = MinMaxScaler()
    mid_point_data = scaler.fit_transform(mid_point_data)
    mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
    # unsqueeze(0) adds a leading batch axis; the slice below indexes axis 1.
    mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
    lstm_start_time = time.time()
    with torch.no_grad():
        lstm_preds = lstm_model(mid_point_data[:, 100:300, :])
    # NOTE(review): squeeze() assumes the output has more than one time step and
    # as many features as the scaler was fitted on — confirm against LSTMModel.
    lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
    original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
    lstm_end_time = time.time()
    lstm_progress_bar.progress(100)
    combined_preds = np.concatenate([original_mid_point_data, lstm_innv_preds], axis=0)
    lstm_total_time = lstm_end_time - lstm_start_time
    mid_point_total_time = mid_point_end_time - mid_point_start_time
    total_end_time = time.time()
    total_time = total_end_time - total_start_time
    return mid_point_total_time, lstm_total_time, total_time, combined_preds
def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
    """Integrate to t = 30000 with EPS, then extend autoregressively with the TCN.

    Steps:
    1. Load ``model/tcn.ckpt`` and move its custom layers to CPU.
    2. Run the EPS step (three ``LowLevel.symplectic`` sub-steps plus
       ``rejust``) while ``t < 30000``, recording every 10th state.
    3. Thin the record to about 300 rows, min-max scale it, and apply
       ``compute_gradient(..., degree=2)`` to each row.
    4. Run 20 model calls. Call 0 sees rows 0..299. Call ``i`` drops the first
       ``i`` input rows and appends ``compute_gradient`` of the first ``i``
       predictions; each output is concatenated onto the predictions.

    Args:
        slider_sample_orbit: Orbit index passed as ``LowLevel(j=...)``.
        tcn_progress_bar: Streamlit progress element updated during integration.

    Returns:
        ``(integration_s, tcn_s, total_s, combined_preds)``:
        - ``integration_s``: wall time of the EPS loop.
        - ``tcn_s``: wall time of the rollout plus post-processing.
        - ``total_s``: total wall time.
        - ``combined_preds``: the thinned EPS states concatenated row-wise
          with the inverse-scaled predictions.
    """
    time.sleep(0.1)
    total_start_time = time.time()
    # NOTE(review): the checkpoint is reloaded on every call; caching the model
    # (e.g. st.cache_resource) would take this out of total_time.
    tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
    tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
    move_custom_layers_to_device(tcn_model, "cpu")
    tcn_model.eval()
    # Initialize variables for the classical method
    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    # The numerical method covers t < 30000; the network continues from there.
    numeric_end_time = float(30000)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (numeric_end_time - t) / h
    last_percentage = None
    original_low_level_data = []
    low_level_start_time = time.time()
    # Perform classical method prediction for the initial segment
    while t < numeric_end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h
        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                tcn_progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        jn = jn + 1
    original_low_level_data = np.array(original_low_level_data)
    low_level_end_time = time.time()
    low_level_data = original_low_level_data.copy()
    low_level_data = uniform_sampling(low_level_data, n_sample=300)
    # The same scaler is reused below to map predictions back to physical units.
    scaler = MinMaxScaler()
    low_level_data = scaler.fit_transform(low_level_data)
    low_level_data = torch.tensor(np.stack(low_level_data)).float()
    # unsqueeze(0) adds a leading batch axis; slices below index axis 1.
    low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
    tcn_start_time = time.time()
    with torch.no_grad():
        tcn_preds = None
        for i in range(20):
            if i == 0:
                tcn_preds = tcn_model(low_level_data[:, :300, :])
            else:
                # Slide the window: drop i observed rows, append features of the
                # first i predictions (keeps 300 rows if each call yields one
                # time step — TODO confirm against TCNModel).
                gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
                output = tcn_model(torch.cat([low_level_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
                tcn_preds = torch.cat([tcn_preds, output], dim=1)
    tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
    original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
    tcn_end_time = time.time()
    tcn_progress_bar.progress(100)
    combined_preds = np.concatenate([original_low_level_data, tcn_innv_preds], axis=0)
    tcn_total_time = tcn_end_time - tcn_start_time
    low_level_total_time = low_level_end_time - low_level_start_time
    total_end_time = time.time()
    total_time = total_end_time - total_start_time
    return low_level_total_time, tcn_total_time, total_time, combined_preds
def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
    """Integrate to t = 30000 with the MidPoint scheme, then extend with the TCN.

    Steps:
    1. Load ``model/tcn.ckpt`` and move its custom layers to CPU.
    2. Run three ``MidPoint.implecitsymplectic`` sub-steps per step while
       ``t < 30000``, recording every 10th state ``[b, x, y, z, px, py, pz]``.
    3. Thin the record to about 300 rows, min-max scale it, and apply
       ``compute_gradient(..., degree=2)`` to each row.
    4. Run 20 model calls. Call 0 sees rows 0..299. Call ``i`` drops the first
       ``i`` input rows and appends ``compute_gradient`` of the first ``i``
       predictions.

    Args:
        slider_sample_orbit: Orbit index passed as ``MidPoint(j=...)``.
        tcn_progress_bar: Streamlit progress element updated during integration.

    Returns:
        ``(integration_s, tcn_s, total_s, combined_preds)``:
        - ``integration_s``: wall time of the midpoint loop.
        - ``tcn_s``: wall time of the rollout plus post-processing.
        - ``total_s``: total wall time.
        - ``combined_preds``: the thinned midpoint states concatenated
          row-wise with the inverse-scaled predictions.
    """
    time.sleep(0.1)
    total_start_time = time.time()
    # NOTE(review): the checkpoint is reloaded on every call; caching the model
    # (e.g. st.cache_resource) would take this out of total_time.
    tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
    tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
    move_custom_layers_to_device(tcn_model, "cpu")
    tcn_model.eval()
    # Initialize variables for the classical method
    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
    # Triple-jump weights; a1 + a2 + a1 == 1, so one full step advances t by h.
    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    # The numerical method covers t < 30000; the network continues from there.
    numeric_end_time = float(30000)
    jn = 0
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (numeric_end_time - t) / h
    last_percentage = None
    original_mid_point_data = []
    mid_point_start_time = time.time()
    # Perform classical method prediction for the initial segment
    while t < numeric_end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h
        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # progress() is a UI update running inside the timed loop, so only
            # send one when the displayed percentage actually changes.
            progress_percentage = min(100, int(jn / total_iterations * 100))
            if progress_percentage != last_percentage:
                tcn_progress_bar.progress(progress_percentage)
                last_percentage = progress_percentage
        jn = jn + 1
    original_mid_point_data = np.array(original_mid_point_data)
    mid_point_end_time = time.time()
    mid_point_data = original_mid_point_data.copy()
    mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
    # The same scaler is reused below to map predictions back to physical units.
    scaler = MinMaxScaler()
    mid_point_data = scaler.fit_transform(mid_point_data)
    mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
    # unsqueeze(0) adds a leading batch axis; slices below index axis 1.
    mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
    tcn_start_time = time.time()
    with torch.no_grad():
        tcn_preds = None
        for i in range(20):
            if i == 0:
                tcn_preds = tcn_model(mid_point_data[:, :300, :])
            else:
                # Slide the window: drop i observed rows, append features of the
                # first i predictions (keeps 300 rows if each call yields one
                # time step — TODO confirm against TCNModel).
                gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
                output = tcn_model(torch.cat([mid_point_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
                tcn_preds = torch.cat([tcn_preds, output], dim=1)
    tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
    original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
    tcn_end_time = time.time()
    tcn_progress_bar.progress(100)
    combined_preds = np.concatenate([original_mid_point_data, tcn_innv_preds], axis=0)
    tcn_total_time = tcn_end_time - tcn_start_time
    mid_point_total_time = mid_point_end_time - mid_point_start_time
    total_end_time = time.time()
    total_time = total_end_time - total_start_time
    return mid_point_total_time, tcn_total_time, total_time, combined_preds
# Fixed integration horizon, shown in the sidebar and passed to the
# full-length numerical runs.
option_time = 32000
# Method labels offered to the user; the button handler matches them verbatim.
compared_methods = ['EPS', 'Midpoint', 'EPS with LSTM', 'EPS with TCN', 'Midpoint with LSTM', 'Midpoint with TCN']

# Page layout, created top to bottom (creation order fixes on-page order):
# a full-width container, two side-by-side columns, then another container.
# NOTE(review): `container` and `plot_container` are not written to in this
# file as shown.
container = st.container()
container1, container2 = st.columns(2)
plot_container = st.container()

# Sidebar controls: which orbit to integrate and which methods to compare.
with st.sidebar:
    slider_sample_orbit = st.slider('Orbit Sample ID', min_value=1, max_value=10, value=1)
    st.write(f'Total Time Step: {option_time}')
    options_method = st.multiselect('Compared Methods', options=compared_methods, default=['EPS'])
    btn_go = st.button("Go", type="primary", use_container_width=True)
def _show_timings(model_name, segment_30000_time, segment_2000_time, total_time):
    """Append a one-row timing table for ``model_name`` to the right-hand column."""
    timings = {
        'Model': model_name,
        '30000 Time Steps (s)': [segment_30000_time],
        '2000 Time Steps (s)': [segment_2000_time],
        'Total Time (s)': [total_time],
    }
    with container2:
        st.table(pd.DataFrame(timings))


if btn_go:
    # For every selected method, in a fixed order:
    # 1. a caption and a progress bar in the left column;
    # 2. run the method while that bar is live;
    # 3. a timing table in the right column.
    if 'EPS' in options_method:
        with container1:
            st.write('EPS Progress Bar')
            low_level_progress_bar = st.progress(0)
            (low_level_30000_time, low_level_2000_time,
             low_level_total_time, low_level_result) = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
        _show_timings("EPS", low_level_30000_time, low_level_2000_time, low_level_total_time)
    # NOTE(review): 'High-Level' is not one of the multiselect options, so this
    # branch cannot currently be triggered from the UI.
    if 'High-Level' in options_method:
        with container1:
            st.write('High Level Progress Bar')
            high_level_progress_bar = st.progress(0)
            (high_level_30000_time, high_level_2000_time,
             high_level_total_time, high_level_result) = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
        _show_timings("High Level", high_level_30000_time, high_level_2000_time, high_level_total_time)
    if 'Midpoint' in options_method:
        with container1:
            st.write('Midpoint Progress Bar')
            mid_point_progress_bar = st.progress(0)
            (mid_point_30000_time, mid_point_2000_time,
             mid_point_total_time, mid_point_result) = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
        _show_timings("Midpoint", mid_point_30000_time, mid_point_2000_time, mid_point_total_time)
    if 'EPS with LSTM' in options_method:
        with container1:
            st.write('EPS LSTM Progress Bar')
            low_level_lstm_progress_bar = st.progress(0)
            (lstm_30000_time, lstm_2000_time,
             lstm_total_time, lstm_result) = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
        _show_timings("EPS + LSTM", lstm_30000_time, lstm_2000_time, lstm_total_time)
    if 'EPS with TCN' in options_method:
        with container1:
            st.write('EPS TCN Progress Bar')
            low_level_tcn_progress_bar = st.progress(0)
            (tcn_30000_time, tcn_2000_time,
             tcn_total_time, tcn_result) = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
        _show_timings("EPS + TCN", tcn_30000_time, tcn_2000_time, tcn_total_time)
    if 'Midpoint with LSTM' in options_method:
        with container1:
            st.write('Midpoint LSTM Progress Bar')
            mid_point_lstm_progress_bar = st.progress(0)
            (md_lstm_30000_time, md_lstm_2000_time,
             md_lstm_total_time, md_lstm_result) = mid_point_lstm(slider_sample_orbit, mid_point_lstm_progress_bar)
        _show_timings("Midpoint + LSTM", md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time)
    if 'Midpoint with TCN' in options_method:
        with container1:
            st.write('Midpoint TCN Progress Bar')
            mid_point_tcn_progress_bar = st.progress(0)
            (md_tcn_30000_time, md_tcn_2000_time,
             md_tcn_total_time, md_tcn_result) = mid_point_tcn(slider_sample_orbit, mid_point_tcn_progress_bar)
        _show_timings("Midpoint + TCN", md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time)