yan123yan
commited on
Commit
·
cc9df09
1
Parent(s):
af53761
update midpoint function
Browse files- pages/inference.py +25 -27
- utils/midpoint.py +112 -71
pages/inference.py
CHANGED
@@ -130,7 +130,6 @@ def high_level(option_time, slider_sample_orbit, progress_bar):
|
|
130 |
def midpoint(option_time, slider_sample_orbit, progress_bar):
|
131 |
time.sleep(0.1)
|
132 |
mid_point_total_start_time = time.time()
|
133 |
-
mid_point_30000_start_time = time.time()
|
134 |
|
135 |
midpointhelper = MidPoint(j=slider_sample_orbit)
|
136 |
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
@@ -147,25 +146,26 @@ def midpoint(option_time, slider_sample_orbit, progress_bar):
|
|
147 |
|
148 |
original_mid_point_data = []
|
149 |
|
|
|
|
|
150 |
while t < float(option_time):
|
151 |
-
x, y, z, px, py, pz
|
152 |
-
x, y, z, px, py, pz
|
153 |
-
x, y, z, px, py, pz
|
154 |
-
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
155 |
|
156 |
t = t + h
|
157 |
|
158 |
if jn % 10 == 0:
|
159 |
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
|
|
|
|
|
|
160 |
if jn == 300000:
|
161 |
mid_point_30000_end_time = time.time()
|
162 |
mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
|
163 |
mid_point_2000_start_time = time.time()
|
164 |
jn = jn + 1
|
165 |
|
166 |
-
# Update progress bar
|
167 |
-
progress_percentage = int((current_iteration / total_iterations) * 100)
|
168 |
-
progress_bar.progress(progress_percentage)
|
169 |
current_iteration += 1
|
170 |
|
171 |
#mid_point_df.to_excel('mid_point_df_output.xlsx', index=False)
|
@@ -277,10 +277,9 @@ def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
|
|
277 |
mid_point_start_time = time.time()
|
278 |
|
279 |
while t < float(30000):
|
280 |
-
x, y, z, px, py, pz
|
281 |
-
x, y, z, px, py, pz
|
282 |
-
x, y, z, px, py, pz
|
283 |
-
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
284 |
|
285 |
t = t + h
|
286 |
|
@@ -429,10 +428,9 @@ def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
|
|
429 |
|
430 |
# Perform classical method prediction for the initial segment
|
431 |
while t < float(30000):
|
432 |
-
x, y, z, px, py, pz
|
433 |
-
x, y, z, px, py, pz
|
434 |
-
x, y, z, px, py, pz
|
435 |
-
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
|
436 |
|
437 |
t = t + h
|
438 |
|
@@ -490,18 +488,18 @@ with st.sidebar:
|
|
490 |
st.write(f'Total Time Step: {option_time}')
|
491 |
options_method = st.multiselect(
|
492 |
'Compared Methods',
|
493 |
-
['
|
494 |
-
['
|
495 |
btn_go = st.button("Go", type="primary", use_container_width=True)
|
496 |
|
497 |
if btn_go:
|
498 |
-
if '
|
499 |
with container1:
|
500 |
-
st.write('
|
501 |
low_level_progress_bar = st.progress(0)
|
502 |
low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
|
503 |
with container2:
|
504 |
-
st.table(pd.DataFrame({'Model':"
|
505 |
if 'High-Level' in options_method:
|
506 |
with container1:
|
507 |
st.write('High Level Progress Bar')
|
@@ -516,20 +514,20 @@ if btn_go:
|
|
516 |
mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
|
517 |
with container2:
|
518 |
st.table(pd.DataFrame({'Model':"Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
|
519 |
-
if '
|
520 |
with container1:
|
521 |
-
st.write('
|
522 |
low_level_lstm_progress_bar = st.progress(0)
|
523 |
lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
|
524 |
with container2:
|
525 |
-
st.table(pd.DataFrame({'Model':"
|
526 |
-
if '
|
527 |
with container1:
|
528 |
-
st.write('
|
529 |
low_level_tcn_progress_bar = st.progress(0)
|
530 |
tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
|
531 |
with container2:
|
532 |
-
st.table(pd.DataFrame({'Model':"
|
533 |
if 'Midpoint with LSTM' in options_method:
|
534 |
with container1:
|
535 |
st.write('Midpoint LSTM Progress Bar')
|
|
|
130 |
def midpoint(option_time, slider_sample_orbit, progress_bar):
|
131 |
time.sleep(0.1)
|
132 |
mid_point_total_start_time = time.time()
|
|
|
133 |
|
134 |
midpointhelper = MidPoint(j=slider_sample_orbit)
|
135 |
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
|
|
146 |
|
147 |
original_mid_point_data = []
|
148 |
|
149 |
+
mid_point_30000_start_time = time.time()
|
150 |
+
|
151 |
while t < float(option_time):
|
152 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
153 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
|
154 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
|
|
155 |
|
156 |
t = t + h
|
157 |
|
158 |
if jn % 10 == 0:
|
159 |
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
160 |
+
# Update progress bar
|
161 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
162 |
+
progress_bar.progress(progress_percentage)
|
163 |
if jn == 300000:
|
164 |
mid_point_30000_end_time = time.time()
|
165 |
mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
|
166 |
mid_point_2000_start_time = time.time()
|
167 |
jn = jn + 1
|
168 |
|
|
|
|
|
|
|
169 |
current_iteration += 1
|
170 |
|
171 |
#mid_point_df.to_excel('mid_point_df_output.xlsx', index=False)
|
|
|
277 |
mid_point_start_time = time.time()
|
278 |
|
279 |
while t < float(30000):
|
280 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
281 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
|
282 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
|
|
283 |
|
284 |
t = t + h
|
285 |
|
|
|
428 |
|
429 |
# Perform classical method prediction for the initial segment
|
430 |
while t < float(30000):
|
431 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
432 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a2 * h, x, y, z, px, py, pz, b)
|
433 |
+
x, y, z, px, py, pz = midpointhelper.implecitsymplectic(a1 * h, x, y, z, px, py, pz, b)
|
|
|
434 |
|
435 |
t = t + h
|
436 |
|
|
|
488 |
st.write(f'Total Time Step: {option_time}')
|
489 |
options_method = st.multiselect(
|
490 |
'Compared Methods',
|
491 |
+
['EPS', 'Midpoint', 'EPS with LSTM', 'EPS with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
|
492 |
+
['EPS'])
|
493 |
btn_go = st.button("Go", type="primary", use_container_width=True)
|
494 |
|
495 |
if btn_go:
|
496 |
+
if 'EPS' in options_method:
|
497 |
with container1:
|
498 |
+
st.write('EPS Progress Bar')
|
499 |
low_level_progress_bar = st.progress(0)
|
500 |
low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
|
501 |
with container2:
|
502 |
+
st.table(pd.DataFrame({'Model':"EPS", '30000 Time Steps (s)': [low_level_30000_time], '2000 Time Steps (s)': [low_level_2000_time], 'Total Time (s)': [low_level_total_time]}))
|
503 |
if 'High-Level' in options_method:
|
504 |
with container1:
|
505 |
st.write('High Level Progress Bar')
|
|
|
514 |
mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
|
515 |
with container2:
|
516 |
st.table(pd.DataFrame({'Model':"Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
|
517 |
+
if 'EPS with LSTM' in options_method:
|
518 |
with container1:
|
519 |
+
st.write('EPS LSTM Progress Bar')
|
520 |
low_level_lstm_progress_bar = st.progress(0)
|
521 |
lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
|
522 |
with container2:
|
523 |
+
st.table(pd.DataFrame({'Model':"EPS + LSTM", '30000 Time Steps (s)': [lstm_30000_time], '2000 Time Steps (s)': [lstm_2000_time], 'Total Time (s)': [lstm_total_time]}))
|
524 |
+
if 'EPS with TCN' in options_method:
|
525 |
with container1:
|
526 |
+
st.write('EPS TCN Progress Bar')
|
527 |
low_level_tcn_progress_bar = st.progress(0)
|
528 |
tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
|
529 |
with container2:
|
530 |
+
st.table(pd.DataFrame({'Model':"EPS + TCN", '30000 Time Steps (s)': [tcn_30000_time], '2000 Time Steps (s)': [tcn_2000_time], 'Total Time (s)': [tcn_total_time]}))
|
531 |
if 'Midpoint with LSTM' in options_method:
|
532 |
with container1:
|
533 |
st.write('Midpoint LSTM Progress Bar')
|
utils/midpoint.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
|
2 |
class MidPoint():
|
3 |
|
@@ -91,74 +92,114 @@ class MidPoint():
|
|
91 |
|
92 |
e = ht + hv + h1pn
|
93 |
|
94 |
-
vnpx=px
|
95 |
-
v1pnpx=4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
|
3 |
class MidPoint():
|
4 |
|
|
|
92 |
|
93 |
e = ht + hv + h1pn
|
94 |
|
95 |
+
vnpx = px
|
96 |
+
v1pnpx = 4 * px * ((3 * n) / 8 - 1 / 8) * (px ** 2 + py ** 2 + pz ** 2) - (px * (n + 3) + (n * x * (
|
97 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
98 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2))) / (x ** 2 + y ** 2 + z ** 2) ** (
|
99 |
+
1 / 2)) / (
|
100 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
101 |
+
vpx = vnpx + v1pnpx
|
102 |
+
|
103 |
+
vnpy = py
|
104 |
+
v1pnpy = 4 * py * ((3 * n) / 8 - 1 / 8) * (px ** 2 + py ** 2 + pz ** 2) - (py * (n + 3) + (n * y * (
|
105 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
106 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2))) / (x ** 2 + y ** 2 + z ** 2) ** (
|
107 |
+
1 / 2)) / (
|
108 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
109 |
+
vpy = vnpy + v1pnpy
|
110 |
+
|
111 |
+
vnpz = pz
|
112 |
+
v1pnpz = 4 * pz * ((3 * n) / 8 - 1 / 8) * (px ** 2 + py ** 2 + pz ** 2) - (pz * (n + 3) + (n * z * (
|
113 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
114 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2))) / (x ** 2 + y ** 2 + z ** 2) ** (
|
115 |
+
1 / 2)) / (
|
116 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
117 |
+
vpz = vnpz + v1pnpz
|
118 |
+
|
119 |
+
vnx = x / (x ** 2 + y ** 2 + z ** 2) ** (3 / 2)
|
120 |
+
v1pnx = (x * ((n * (
|
121 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
122 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2 + (
|
123 |
+
(n + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2)) / (x ** 2 + y ** 2 + z ** 2) ** (
|
124 |
+
3 / 2) - (4 * x) / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) ** 2 + (n * (
|
125 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
126 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) * ((px * x ** 2) / (
|
127 |
+
x ** 2 + y ** 2 + z ** 2) ** (3 / 2) - px / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * x * y) / (
|
128 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
129 |
+
3 / 2) + (
|
130 |
+
pz * x * z) / (
|
131 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
132 |
+
3 / 2))) / (
|
133 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
134 |
+
vx = vnx + v1pnx
|
135 |
+
|
136 |
+
vny = y / (x ** 2 + y ** 2 + z ** 2) ** (3 / 2)
|
137 |
+
v1pny = (y * ((n * (
|
138 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
139 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2 + (
|
140 |
+
(n + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2)) / (x ** 2 + y ** 2 + z ** 2) ** (
|
141 |
+
3 / 2) - (4 * y) / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) ** 2 + (n * (
|
142 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
143 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) * ((py * y ** 2) / (
|
144 |
+
x ** 2 + y ** 2 + z ** 2) ** (3 / 2) - py / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (px * x * y) / (
|
145 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
146 |
+
3 / 2) + (
|
147 |
+
pz * y * z) / (
|
148 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
149 |
+
3 / 2))) / (
|
150 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
151 |
+
vy = vny + v1pny
|
152 |
+
|
153 |
+
vnz = z / (x ** 2 + y ** 2 + z ** 2) ** (3 / 2)
|
154 |
+
v1pnz = (z * ((n * (
|
155 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
156 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2 + (
|
157 |
+
(n + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2)) / (x ** 2 + y ** 2 + z ** 2) ** (
|
158 |
+
3 / 2) - (4 * z) / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) ** 2 + (n * (
|
159 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
160 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) * ((pz * z ** 2) / (
|
161 |
+
x ** 2 + y ** 2 + z ** 2) ** (3 / 2) - pz / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (px * x * z) / (
|
162 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
163 |
+
3 / 2) + (
|
164 |
+
py * y * z) / (
|
165 |
+
x ** 2 + y ** 2 + z ** 2) ** (
|
166 |
+
3 / 2))) / (
|
167 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
168 |
+
vz = vnz + v1pnz
|
169 |
+
return vx, vy, vz, vpx, vpy, vpz, e
|
170 |
+
|
171 |
+
def implecitsymplectic(self, h, x, y, z, px, py, pz, b):
|
172 |
+
|
173 |
+
x1 = x
|
174 |
+
y1 = y
|
175 |
+
z1 = z
|
176 |
+
px1 = px
|
177 |
+
py1 = py
|
178 |
+
pz1 = pz
|
179 |
+
num = 0
|
180 |
+
d = 1.1
|
181 |
+
while (d > 1e-16 and num < 1000):
|
182 |
+
px2 = px
|
183 |
+
py2 = py
|
184 |
+
pz2 = pz
|
185 |
+
x2 = x
|
186 |
+
y2 = y
|
187 |
+
z2 = z
|
188 |
+
x = (x1 + x) / 2
|
189 |
+
y = (y1 + y) / 2
|
190 |
+
z = (z1 + z) / 2
|
191 |
+
px = (px1 + px) / 2
|
192 |
+
py = (py1 + py) / 2
|
193 |
+
pz = (pz1 + pz) / 2
|
194 |
+
vx, vy, vz, vpx, vpy, vpz, e = self.f(x, y, z, px, py, pz, b)
|
195 |
+
x = x1 + h * vpx
|
196 |
+
y = y1 + h * vpy
|
197 |
+
z = z1 + h * vpz
|
198 |
+
px = px1 - h * vx
|
199 |
+
py = py1 - h * vy
|
200 |
+
pz = pz1 - h * vz
|
201 |
+
d = np.sqrt(np.abs(x - x2) + np.abs(y - y2) + np.abs(z - z2) + np.abs(px - px2) + np.abs(py - py2) + np.abs(
|
202 |
+
pz - pz2))
|
203 |
+
num += 1
|
204 |
+
#print(num, d)
|
205 |
+
return x, y, z, px, py, pz
|