import os
import time

import streamlit as st
import torch
import torch.nn.functional as F
import random
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd

from model.lstm import LSTMModel
from model.tcn import TCNModel
from model.tcn import move_custom_layers_to_device
from utils.lowlevel import LowLevel
from utils.highlevel import HighLevel
from utils.midpoint import MidPoint

from utils.transform import compute_gradient

# Page-level configuration. It runs once at import time, before any other
# Streamlit element in this script is created.
st.set_page_config(page_title="Inference", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")

def uniform_sampling(data, n_sample):
    """Thin ``data`` along its first axis with a constant stride.

    The stride is ``len(data) // n_sample``. The slice keeps the remainder, so
    the result can hold slightly more than ``n_sample`` items. For example,
    10 items with ``n_sample=3`` use stride 3 and return 4 items.

    Args:
        data: Any sliceable sequence (list, ``np.ndarray``, tensor).
        n_sample: Target number of samples; must be positive.

    Returns:
        ``data[::stride]``. When ``data`` has fewer than ``n_sample`` items,
        the stride is clamped to 1 and every item is returned.

    Raises:
        ValueError: If ``n_sample`` is not positive.
    """
    if n_sample <= 0:
        raise ValueError(f"n_sample must be positive, got {n_sample}")
    # A slice step of 0 is invalid. Without the clamp, any input shorter than
    # n_sample would raise instead of being returned unchanged.
    stride = max(1, len(data) // n_sample)
    return data[::stride]

def low_level(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with the "EPS" method (``LowLevel.symplectic``) and time it.

    Each step composes ``lowlevelhelper.symplectic`` three times with
    sub-step weights ``a1, a2, a1``, where a1 = 1 / (2 - 2**(1/3)) and
    a2 = 1 - 2*a1 (triple-jump weights). It then calls
    ``lowlevelhelper.rejust``. The state ``[b, x, y, z, px, py, pz]`` is
    recorded every 10 steps.

    Args:
        option_time: End time of the integration; time starts at t = 0.1.
        slider_sample_orbit: Orbit index forwarded to ``LowLevel(j=...)``.
        progress_bar: Streamlit progress element, updated every 10 steps.

    Returns:
        A 4-tuple:
        ``(first_segment_s, second_segment_s, total_s, result)``.

        - ``first_segment_s``: wall-clock seconds from the start up to step
          300000.
        - ``second_segment_s``: seconds for the remaining steps.
        - ``total_s``: total seconds.
        - ``result``: the recorded states thinned by ``uniform_sampling`` to
          roughly ``option_time / 100`` rows.

        If the run finishes before step 300000, all time is attributed to the
        first segment and the second is 0.0. Previously this raised
        UnboundLocalError.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()
    segment_start_time = total_start_time
    # Step index that ends the first timing segment. The UI reports the two
    # segments as "30000 Time Steps" and "2000 Time Steps" (presumably h = 0.1).
    split_step = 300000

    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h

    first_segment_time = None
    second_segment_start_time = None
    original_low_level_data = []

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h

        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            progress_bar.progress(min(100, int(jn / total_iterations * 100)))

        if jn == split_step:
            first_segment_time = time.time() - segment_start_time
            second_segment_start_time = time.time()
        jn = jn + 1

    progress_bar.progress(100)

    finish_time = time.time()
    if second_segment_start_time is None:
        # The split step was never reached, so the whole run is the first segment.
        first_segment_time = finish_time - segment_start_time
        second_segment_time = 0.0
    else:
        second_segment_time = finish_time - second_segment_start_time
    total_time = finish_time - total_start_time

    # max(1, ...) keeps uniform_sampling valid for option_time < 100.
    result = uniform_sampling(np.array(original_low_level_data), n_sample=max(1, int(option_time / 100)))

    return first_segment_time, second_segment_time, total_time, result

def high_level(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with ``HighLevel.symplectic`` and time it.

    Each step composes ``highlevelhelper.symplectic`` three times with
    triple-jump weights ``a1, a2, a1``, then calls ``highlevelhelper.rejust``.
    The state ``[b, x, y, z, px, py, pz]`` is recorded every 10 steps.

    Args:
        option_time: End time of the integration; time starts at t = 0.1.
        slider_sample_orbit: Orbit index forwarded to ``HighLevel(j=...)``.
        progress_bar: Streamlit progress element, updated every 10 steps.
            Previously it was updated on every step, which throttled the loop
            and inflated the measured time.

    Returns:
        A 4-tuple:
        ``(first_segment_s, second_segment_s, total_s, result)``.

        - ``first_segment_s``: wall-clock seconds up to step 300000.
        - ``second_segment_s``: seconds after step 300000.
        - ``total_s``: total seconds.
        - ``result``: the recorded states thinned by ``uniform_sampling`` to
          roughly ``option_time / 100`` rows.

        If step 300000 is never reached, all time goes to the first segment
        and the second is 0.0. Previously this raised UnboundLocalError.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()
    segment_start_time = total_start_time
    # Step index ending the first timing segment ("30000 Time Steps" in the UI).
    split_step = 300000

    highlevelhelper = HighLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = highlevelhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h

    first_segment_time = None
    second_segment_start_time = None
    original_high_level_data = []

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h
        # NOTE(review): the values returned by `f` are never used here. The
        # call is kept in case it has side effects or is meant to be part of
        # the timed work. Drop it if `f` is pure.
        highlevelhelper.f(x, y, z, px, py, pz, b)

        if jn % 10 == 0:
            original_high_level_data.append([b, x, y, z, px, py, pz])
            # Throttled like the sibling integrators; clamped against drift in `t`.
            progress_bar.progress(min(100, int(jn / total_iterations * 100)))

        if jn == split_step:
            first_segment_time = time.time() - segment_start_time
            second_segment_start_time = time.time()
        jn = jn + 1

    progress_bar.progress(100)

    finish_time = time.time()
    if second_segment_start_time is None:
        # The split step was never reached, so the whole run is the first segment.
        first_segment_time = finish_time - segment_start_time
        second_segment_time = 0.0
    else:
        second_segment_time = finish_time - second_segment_start_time
    total_time = finish_time - total_start_time

    # max(1, ...) keeps uniform_sampling valid for option_time < 100.
    result = uniform_sampling(np.array(original_high_level_data), n_sample=max(1, int(option_time / 100)))

    return first_segment_time, second_segment_time, total_time, result

def midpoint(option_time, slider_sample_orbit, progress_bar):
    """Integrate one orbit with ``MidPoint.implecitsymplectic`` and time it.

    Each step composes ``midpointhelper.implecitsymplectic`` three times with
    triple-jump weights ``a1, a2, a1``. The state ``[b, x, y, z, px, py, pz]``
    is recorded every 10 steps.

    Args:
        option_time: End time of the integration; time starts at t = 0.1.
        slider_sample_orbit: Orbit index forwarded to ``MidPoint(j=...)``.
        progress_bar: Streamlit progress element, updated every 10 steps.

    Returns:
        A 4-tuple:
        ``(first_segment_s, second_segment_s, total_s, result)``.

        - ``first_segment_s``: seconds from just before the loop up to step
          300000.
        - ``second_segment_s``: seconds after step 300000.
        - ``total_s``: total seconds, including helper setup.
        - ``result``: the recorded states thinned by ``uniform_sampling`` to
          roughly ``option_time / 100`` rows.

        If step 300000 is never reached, all loop time goes to the first
        segment and the second is 0.0. Previously this raised
        UnboundLocalError.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()
    # Step index ending the first timing segment ("30000 Time Steps" in the UI).
    split_step = 300000

    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(option_time)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h

    first_segment_time = None
    second_segment_start_time = None
    original_mid_point_data = []

    # Unlike the other integrators, the first segment excludes helper setup.
    segment_start_time = time.time()

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h

        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            progress_bar.progress(min(100, int(jn / total_iterations * 100)))

        if jn == split_step:
            first_segment_time = time.time() - segment_start_time
            second_segment_start_time = time.time()
        jn = jn + 1

    progress_bar.progress(100)

    finish_time = time.time()
    if second_segment_start_time is None:
        # The split step was never reached, so the whole loop is the first segment.
        first_segment_time = finish_time - segment_start_time
        second_segment_time = 0.0
    else:
        second_segment_time = finish_time - second_segment_start_time
    total_time = finish_time - total_start_time

    # max(1, ...) keeps uniform_sampling valid for option_time < 100.
    result = uniform_sampling(np.array(original_mid_point_data), n_sample=max(1, int(option_time / 100)))

    return first_segment_time, second_segment_time, total_time, result

def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
    """Seed with the "EPS" integrator up to t = 30000, then extend with the LSTM.

    Pipeline:

    1. Integrate the orbit with ``LowLevel``, recording
       ``[b, x, y, z, px, py, pz]`` every 10 steps.
    2. Thin the record to a 300-row context window and min-max scale it.
    3. Expand every row with ``compute_gradient(row, degree=2)``.
    4. Feed context rows 100..299 to the LSTM and inverse-scale the output.

    Args:
        slider_sample_orbit: Orbit index forwarded to ``LowLevel(j=...)``.
        lstm_progress_bar: Streamlit progress element, updated during
            integration and set to 100 at the end.

    Returns:
        A 4-tuple:
        ``(classical_s, lstm_s, total_s, combined_preds)``.

        - ``classical_s``: seconds spent integrating.
        - ``lstm_s``: seconds spent on inference and inverse scaling.
        - ``total_s``: total seconds, including checkpoint loading.
        - ``combined_preds``: the 300 context rows followed by the
          inverse-scaled LSTM predictions.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()

    lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
    lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
    lstm_model.to("cpu")
    lstm_model.eval()

    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(30000)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    # Rows of classical data kept as model context. The LSTM reads rows
    # 100..299 of this window.
    n_context = 300

    original_low_level_data = []
    low_level_start_time = time.time()

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h

        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            lstm_progress_bar.progress(min(100, int(jn / total_iterations * 100)))
        jn = jn + 1

    original_low_level_data = np.array(original_low_level_data)
    low_level_end_time = time.time()

    # The strided slice can keep a few rows past n_context. Truncate so the
    # rows returned below are exactly the rows the model was conditioned on;
    # otherwise the predictions would be misaligned with the classical part.
    context = uniform_sampling(original_low_level_data, n_sample=n_context)[:n_context]
    scaler = MinMaxScaler()
    scaled = torch.as_tensor(scaler.fit_transform(context), dtype=torch.float32)
    features = torch.stack([compute_gradient(row, degree=2) for row in scaled]).unsqueeze(0)

    lstm_start_time = time.time()
    with torch.no_grad():
        lstm_preds = lstm_model(features[:, 100:n_context, :]).cpu().numpy()
        # reshape instead of squeeze: a single-step output must stay 2-D for
        # inverse_transform. Assumes the last axis matches the scaler's
        # feature count (TODO confirm model output shape).
        lstm_innv_preds = scaler.inverse_transform(lstm_preds.reshape(-1, lstm_preds.shape[-1]))
    lstm_end_time = time.time()
    lstm_progress_bar.progress(100)

    combined_preds = np.concatenate([context, lstm_innv_preds], axis=0)

    lstm_total_time = lstm_end_time - lstm_start_time
    low_level_total_time = low_level_end_time - low_level_start_time
    total_time = time.time() - total_start_time

    return low_level_total_time, lstm_total_time, total_time, combined_preds

def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
    """Seed with the Midpoint integrator up to t = 30000, then extend with the LSTM.

    Pipeline:

    1. Integrate the orbit with ``MidPoint``, recording
       ``[b, x, y, z, px, py, pz]`` every 10 steps.
    2. Thin the record to a 300-row context window and min-max scale it.
    3. Expand every row with ``compute_gradient(row, degree=2)``.
    4. Feed context rows 100..299 to the LSTM and inverse-scale the output.

    Args:
        slider_sample_orbit: Orbit index forwarded to ``MidPoint(j=...)``.
        lstm_progress_bar: Streamlit progress element, updated during
            integration and set to 100 at the end.

    Returns:
        A 4-tuple:
        ``(classical_s, lstm_s, total_s, combined_preds)``.

        - ``classical_s``: seconds spent integrating.
        - ``lstm_s``: seconds spent on inference and inverse scaling.
        - ``total_s``: total seconds, including checkpoint loading.
        - ``combined_preds``: the 300 context rows followed by the
          inverse-scaled LSTM predictions.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()

    lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
    lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
    lstm_model.to("cpu")
    lstm_model.eval()

    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(30000)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    # Rows of classical data kept as model context. The LSTM reads rows
    # 100..299 of this window.
    n_context = 300

    original_mid_point_data = []
    mid_point_start_time = time.time()

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h

        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            lstm_progress_bar.progress(min(100, int(jn / total_iterations * 100)))
        jn = jn + 1

    original_mid_point_data = np.array(original_mid_point_data)
    mid_point_end_time = time.time()

    # The strided slice can keep a few rows past n_context. Truncate so the
    # rows returned below are exactly the rows the model was conditioned on.
    context = uniform_sampling(original_mid_point_data, n_sample=n_context)[:n_context]
    scaler = MinMaxScaler()
    scaled = torch.as_tensor(scaler.fit_transform(context), dtype=torch.float32)
    features = torch.stack([compute_gradient(row, degree=2) for row in scaled]).unsqueeze(0)

    lstm_start_time = time.time()
    with torch.no_grad():
        lstm_preds = lstm_model(features[:, 100:n_context, :]).cpu().numpy()
        # reshape instead of squeeze: a single-step output must stay 2-D for
        # inverse_transform. Assumes the last axis matches the scaler's
        # feature count (TODO confirm model output shape).
        lstm_innv_preds = scaler.inverse_transform(lstm_preds.reshape(-1, lstm_preds.shape[-1]))
    lstm_end_time = time.time()
    lstm_progress_bar.progress(100)

    combined_preds = np.concatenate([context, lstm_innv_preds], axis=0)

    lstm_total_time = lstm_end_time - lstm_start_time
    mid_point_total_time = mid_point_end_time - mid_point_start_time
    total_time = time.time() - total_start_time

    return mid_point_total_time, lstm_total_time, total_time, combined_preds

def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
    """Seed with the "EPS" integrator up to t = 30000, then roll out the TCN.

    Pipeline:

    1. Integrate the orbit with ``LowLevel``, recording
       ``[b, x, y, z, px, py, pz]`` every 10 steps.
    2. Thin the record to a 300-row context window and min-max scale it.
    3. Expand every row with ``compute_gradient(row, degree=2)``.
    4. Run the TCN autoregressively for 20 passes. Pass ``i`` drops the first
       ``i`` context rows and appends gradient features of the first ``i``
       predictions.

    Args:
        slider_sample_orbit: Orbit index forwarded to ``LowLevel(j=...)``.
        tcn_progress_bar: Streamlit progress element, updated during
            integration and set to 100 at the end.

    Returns:
        A 4-tuple:
        ``(classical_s, tcn_s, total_s, combined_preds)``.

        - ``classical_s``: seconds spent integrating.
        - ``tcn_s``: seconds spent on the rollout and inverse scaling.
        - ``total_s``: total seconds, including checkpoint loading.
        - ``combined_preds``: the 300 context rows followed by the
          inverse-scaled TCN predictions.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()

    tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
    tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
    move_custom_layers_to_device(tcn_model, "cpu")
    tcn_model.eval()

    lowlevelhelper = LowLevel(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(30000)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    # Rows of classical data kept as model context.
    n_context = 300
    # Number of autoregressive TCN passes.
    n_rollout_steps = 20

    original_low_level_data = []
    low_level_start_time = time.time()

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
        x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
        t = t + h

        if jn % 10 == 0:
            original_low_level_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            tcn_progress_bar.progress(min(100, int(jn / total_iterations * 100)))
        jn = jn + 1

    original_low_level_data = np.array(original_low_level_data)
    low_level_end_time = time.time()

    # The strided slice can keep a few rows past n_context. Truncate so the
    # rows returned below are exactly the rows the model was conditioned on.
    context = uniform_sampling(original_low_level_data, n_sample=n_context)[:n_context]
    scaler = MinMaxScaler()
    scaled = torch.as_tensor(scaler.fit_transform(context), dtype=torch.float32)
    features = torch.stack([compute_gradient(row, degree=2) for row in scaled]).unsqueeze(0)

    tcn_start_time = time.time()
    with torch.no_grad():
        tcn_preds = tcn_model(features[:, :n_context, :])
        for i in range(1, n_rollout_steps):
            # NOTE(review): presumably keeps the input length at n_context,
            # assuming each pass emits one step and compute_gradient preserves
            # the time axis. TODO confirm.
            gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
            output = tcn_model(torch.cat([features[:, i:n_context, :], gd_y_hat], dim=1).to('cpu'))
            tcn_preds = torch.cat([tcn_preds, output], dim=1)
        tcn_preds = tcn_preds.cpu().numpy()
        # reshape instead of squeeze so a single-row result stays 2-D for inverse_transform.
        tcn_innv_preds = scaler.inverse_transform(tcn_preds.reshape(-1, tcn_preds.shape[-1]))
    tcn_end_time = time.time()
    tcn_progress_bar.progress(100)

    combined_preds = np.concatenate([context, tcn_innv_preds], axis=0)

    tcn_total_time = tcn_end_time - tcn_start_time
    low_level_total_time = low_level_end_time - low_level_start_time
    total_time = time.time() - total_start_time

    return low_level_total_time, tcn_total_time, total_time, combined_preds

def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
    """Seed with the Midpoint integrator up to t = 30000, then roll out the TCN.

    Pipeline:

    1. Integrate the orbit with ``MidPoint``, recording
       ``[b, x, y, z, px, py, pz]`` every 10 steps.
    2. Thin the record to a 300-row context window and min-max scale it.
    3. Expand every row with ``compute_gradient(row, degree=2)``.
    4. Run the TCN autoregressively for 20 passes. Pass ``i`` drops the first
       ``i`` context rows and appends gradient features of the first ``i``
       predictions.

    Args:
        slider_sample_orbit: Orbit index forwarded to ``MidPoint(j=...)``.
        tcn_progress_bar: Streamlit progress element, updated during
            integration and set to 100 at the end.

    Returns:
        A 4-tuple:
        ``(classical_s, tcn_s, total_s, combined_preds)``.

        - ``classical_s``: seconds spent integrating.
        - ``tcn_s``: seconds spent on the rollout and inverse scaling.
        - ``total_s``: total seconds, including checkpoint loading.
        - ``combined_preds``: the 300 context rows followed by the
          inverse-scaled TCN predictions.
    """
    # NOTE(review): short pause before timing starts, presumably to let the
    # progress bar render. TODO confirm.
    time.sleep(0.1)
    total_start_time = time.time()

    tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
    tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
    move_custom_layers_to_device(tcn_model, "cpu")
    tcn_model.eval()

    midpointhelper = MidPoint(j=slider_sample_orbit)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()

    a1 = 1 / (2 - 2 ** (1 / 3))
    a2 = 1 - 2 * a1
    end_time = float(30000)
    t = 0.1
    # Expected number of steps, used only to scale the progress bar.
    total_iterations = (end_time - t) / h
    # Rows of classical data kept as model context.
    n_context = 300
    # Number of autoregressive TCN passes.
    n_rollout_steps = 20

    original_mid_point_data = []
    mid_point_start_time = time.time()

    jn = 0
    while t < end_time:
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
        x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
        t = t + h

        if jn % 10 == 0:
            original_mid_point_data.append([b, x, y, z, px, py, pz])
            # Clamp: floating-point drift in `t` can add an extra step.
            tcn_progress_bar.progress(min(100, int(jn / total_iterations * 100)))
        jn = jn + 1

    original_mid_point_data = np.array(original_mid_point_data)
    mid_point_end_time = time.time()

    # The strided slice can keep a few rows past n_context. Truncate so the
    # rows returned below are exactly the rows the model was conditioned on.
    context = uniform_sampling(original_mid_point_data, n_sample=n_context)[:n_context]
    scaler = MinMaxScaler()
    scaled = torch.as_tensor(scaler.fit_transform(context), dtype=torch.float32)
    features = torch.stack([compute_gradient(row, degree=2) for row in scaled]).unsqueeze(0)

    tcn_start_time = time.time()
    with torch.no_grad():
        tcn_preds = tcn_model(features[:, :n_context, :])
        for i in range(1, n_rollout_steps):
            # NOTE(review): presumably keeps the input length at n_context,
            # assuming each pass emits one step and compute_gradient preserves
            # the time axis. TODO confirm.
            gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
            output = tcn_model(torch.cat([features[:, i:n_context, :], gd_y_hat], dim=1).to('cpu'))
            tcn_preds = torch.cat([tcn_preds, output], dim=1)
        tcn_preds = tcn_preds.cpu().numpy()
        # reshape instead of squeeze so a single-row result stays 2-D for inverse_transform.
        tcn_innv_preds = scaler.inverse_transform(tcn_preds.reshape(-1, tcn_preds.shape[-1]))
    tcn_end_time = time.time()
    tcn_progress_bar.progress(100)

    combined_preds = np.concatenate([context, tcn_innv_preds], axis=0)

    tcn_total_time = tcn_end_time - tcn_start_time
    mid_point_total_time = mid_point_end_time - mid_point_start_time
    total_time = time.time() - total_start_time

    return mid_point_total_time, tcn_total_time, total_time, combined_preds

# --- Page layout ---------------------------------------------------------------
# Streamlit renders elements in statement order, so the order below fixes the
# page layout.
# NOTE(review): `container` and `plot_container` are not used in the visible
# code; confirm whether plotting code was meant to draw into them.
container = st.container()
# container1 receives the progress bars and container2 the timing tables (see below).
container1, container2 = st.columns(2)
plot_container = st.container()

# --- Sidebar controls --------------------------------------------------------------
with st.sidebar:
    # Orbit index 1-10, passed as `slider_sample_orbit` to every runner below.
    slider_sample_orbit = st.slider('Orbit Sample ID', 1, 10, 1)
    # Fixed end time for the pure classical runners (not user-editable).
    option_time = 32000
    st.write(f'Total Time Step: {option_time}')
    options_method = st.multiselect(
        'Compared Methods',
        ['EPS', 'Midpoint', 'EPS with LSTM', 'EPS with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
        ['EPS'])
    btn_go = st.button("Go", type="primary", use_container_width=True)

# --- Run the selected methods -----------------------------------------------------
# Each selected method gets a progress bar in container1, runs synchronously,
# and then shows its timings as a one-row table in container2. For the hybrid
# runners, the "30000 Time Steps" column holds the classical-integration time
# and the "2000 Time Steps" column holds the network time, following their
# return order.
# NOTE(review): the *_result arrays are assigned but not used in the visible code.
if btn_go:
    if 'EPS' in options_method:
        with container1:
            st.write('EPS Progress Bar')
            low_level_progress_bar = st.progress(0)
        low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"EPS", '30000 Time Steps (s)': [low_level_30000_time], '2000 Time Steps (s)': [low_level_2000_time], 'Total Time (s)': [low_level_total_time]}))
    # NOTE(review): 'High-Level' is not among the multiselect options above,
    # so this branch is currently unreachable.
    if 'High-Level' in options_method:
        with container1:
            st.write('High Level Progress Bar')
            high_level_progress_bar = st.progress(0)
        high_level_30000_time, high_level_2000_time, high_level_total_time, high_level_result = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"High Level", '30000 Time Steps (s)': [high_level_30000_time], '2000 Time Steps (s)': [high_level_2000_time], 'Total Time (s)': [high_level_total_time]}))
    if 'Midpoint' in options_method:
        with container1:
            st.write('Midpoint Progress Bar')
            mid_point_progress_bar = st.progress(0)
        mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
    # The hybrid runners below always integrate to t = 30000 internally and
    # do not take option_time.
    if 'EPS with LSTM' in options_method:
        with container1:
            st.write('EPS LSTM Progress Bar')
            low_level_lstm_progress_bar = st.progress(0)
        lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"EPS + LSTM", '30000 Time Steps (s)': [lstm_30000_time], '2000 Time Steps (s)': [lstm_2000_time], 'Total Time (s)': [lstm_total_time]}))
    if 'EPS with TCN' in options_method:
        with container1:
            st.write('EPS TCN Progress Bar')
            low_level_tcn_progress_bar = st.progress(0)
        tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"EPS + TCN", '30000 Time Steps (s)': [tcn_30000_time], '2000 Time Steps (s)': [tcn_2000_time], 'Total Time (s)': [tcn_total_time]}))
    if 'Midpoint with LSTM' in options_method:
        with container1:
            st.write('Midpoint LSTM Progress Bar')
            mid_point_lstm_progress_bar = st.progress(0)
        md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time, md_lstm_result = mid_point_lstm(slider_sample_orbit, mid_point_lstm_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"Midpoint + LSTM", '30000 Time Steps (s)': [md_lstm_30000_time], '2000 Time Steps (s)': [md_lstm_2000_time], 'Total Time (s)': [md_lstm_total_time]}))
    if 'Midpoint with TCN' in options_method:
        with container1:
            st.write('Midpoint TCN Progress Bar')
            mid_point_tcn_progress_bar = st.progress(0)
        md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time, md_tcn_result = mid_point_tcn(slider_sample_orbit, mid_point_tcn_progress_bar)
        with container2:
            st.table(pd.DataFrame({'Model':"Midpoint + TCN", '30000 Time Steps (s)': [md_tcn_30000_time], '2000 Time Steps (s)': [md_tcn_2000_time], 'Total Time (s)': [md_tcn_total_time]}))