yan123yan
commited on
Commit
·
47fe089
1
Parent(s):
8cf3e4f
first version
Browse files- data/file1.dat.npz +3 -0
- data/file10.dat.npz +3 -0
- data/file2.dat.npz +3 -0
- data/file3.dat.npz +3 -0
- data/file4.dat.npz +3 -0
- data/file5.dat.npz +3 -0
- data/file6.dat.npz +3 -0
- data/file7.dat.npz +3 -0
- data/file8.dat.npz +3 -0
- data/file9.dat.npz +3 -0
- model/__pycache__/lstm.cpython-310.pyc +0 -0
- model/__pycache__/tcn.cpython-310.pyc +0 -0
- model/__pycache__/tcn_module.cpython-310.pyc +0 -0
- model/lstm.ckpt +3 -0
- model/lstm.py +22 -0
- model/tcn.ckpt +3 -0
- model/tcn.py +40 -0
- model/tcn_module.py +511 -0
- pages/inference.py +547 -0
- prediction.py +196 -0
- requirements.txt +7 -0
- utils/__pycache__/highlevel.cpython-310.pyc +0 -0
- utils/__pycache__/lowlevel.cpython-310.pyc +0 -0
- utils/__pycache__/metrics.cpython-310.pyc +0 -0
- utils/__pycache__/midpoint.cpython-310.pyc +0 -0
- utils/__pycache__/transform.cpython-310.pyc +0 -0
- utils/highlevel.py +160 -0
- utils/lowlevel.py +158 -0
- utils/metrics.py +28 -0
- utils/midpoint.py +164 -0
- utils/transform.py +8 -0
data/file1.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90713e007e25e2b4467711981274dec5f15548666bf7867a30cdf5e189b94b80
|
3 |
+
size 7200262
|
data/file10.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:123c044fc18c7e1b0671261fc081aa7e9a60eab726f35b2758b75f3e70ebe76e
|
3 |
+
size 7200262
|
data/file2.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ebdbdcb3465e49f16073eb828dc18163a12e8b5968cb7f4661e3931c13ac0cea
|
3 |
+
size 7200262
|
data/file3.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b14cf26774c287dd0fb2351a6fc6ce3a2eae135a57c265a93c03f017d479d3a
|
3 |
+
size 7200262
|
data/file4.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7810f50b9c30a19fbfcb3db77bf838b02df4fc54f6caa63e8dd3ca0f2abba6c1
|
3 |
+
size 7200262
|
data/file5.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:34f3485194b57325b0f32554c4f14218389a6d47d5b6edb7e37626fc48d6aae8
|
3 |
+
size 7200262
|
data/file6.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d59084be5aff0eb578c7ff5ee62027ef853d1f5d8d2794d6230d5b706aa0f6aa
|
3 |
+
size 7200262
|
data/file7.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9035c00bb34d17940b033e3bae40097296c493a4852630fbcf963c6f80391f5a
|
3 |
+
size 7200262
|
data/file8.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:955406a5ec5af34373c4ae8573195ca580bfcfbbef13efdcb410fc05d58d66f8
|
3 |
+
size 7200262
|
data/file9.dat.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ea42398e0062a3280981795ea8fd3a49bd5e104de178decfad9e69e41c0de8c
|
3 |
+
size 7200262
|
model/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (1.2 kB). View file
|
|
model/__pycache__/tcn.cpython-310.pyc
ADDED
Binary file (1.59 kB). View file
|
|
model/__pycache__/tcn_module.cpython-310.pyc
ADDED
Binary file (15.6 kB). View file
|
|
model/lstm.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1080dbdc37acdb1e9e6a29c140711908d2426b39735617eafbadd49fb5772ef4
|
3 |
+
size 7286190
|
model/lstm.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
class LSTMModel(pl.LightningModule):
|
6 |
+
def __init__(self, **config):
|
7 |
+
super(LSTMModel, self).__init__()
|
8 |
+
self.save_hyperparameters(config)
|
9 |
+
|
10 |
+
self.lstm = nn.LSTM(input_size=21,hidden_size=512,num_layers=3,proj_size=21,batch_first=True)
|
11 |
+
self.linear = nn.Linear(in_features=21, out_features=7)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
outputs = []
|
15 |
+
hidden, cell = None, None
|
16 |
+
for i in range(20):
|
17 |
+
if i == 0:
|
18 |
+
output, (hidden, cell) = self.lstm(x)
|
19 |
+
else:
|
20 |
+
output, (hidden, cell) = self.lstm(output[:, -1, :].unsqueeze(1), (hidden, cell))
|
21 |
+
outputs.append(self.linear(output[:, -1, :]))
|
22 |
+
return torch.stack(outputs, dim=1)
|
model/tcn.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc975af786670412a2e419978636038452467676e1a9dac59d2ed77033f9f67b
|
3 |
+
size 43742454
|
model/tcn.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["KERAS_BACKEND"] = "torch"
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from keras.layers import Input, Dense
|
8 |
+
from keras.models import Model
|
9 |
+
from model.tcn_module import TCN
|
10 |
+
|
11 |
+
|
12 |
+
class TCNModel(pl.LightningModule):
|
13 |
+
def __init__(self, **config):
|
14 |
+
super(TCNModel, self).__init__()
|
15 |
+
self.save_hyperparameters(config)
|
16 |
+
|
17 |
+
input_layer = Input(shape=(self.hparams.windows_size, self.hparams.input_size))
|
18 |
+
self.tcn = TCN(input_shape=(self.hparams.windows_size, self.hparams.input_size))(input_layer)
|
19 |
+
self.linear = Dense(7)(self.tcn)
|
20 |
+
self.model = Model(inputs=input_layer, outputs=self.linear)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
output = self.model(x)
|
24 |
+
return torch.stack([output], dim=1)
|
25 |
+
|
26 |
+
def move_custom_layers_to_device(model, device):
|
27 |
+
for name, module in model.named_children():
|
28 |
+
# 如果是标准层,named_children已经处理了
|
29 |
+
if isinstance(module, nn.Module):
|
30 |
+
continue
|
31 |
+
|
32 |
+
# 对于非标准层,例如包含在列表或字典中的层
|
33 |
+
if isinstance(module, list):
|
34 |
+
for sub_module in module:
|
35 |
+
if isinstance(sub_module, nn.Module):
|
36 |
+
sub_module.to(device)
|
37 |
+
elif isinstance(module, dict):
|
38 |
+
for sub_module in module.values():
|
39 |
+
if isinstance(sub_module, nn.Module):
|
40 |
+
sub_module.to(device)
|
model/tcn_module.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import os
|
5 |
+
os.environ["KERAS_BACKEND"] = "torch"
|
6 |
+
import keras
|
7 |
+
|
8 |
+
# from keras_core import backend as K, Model, Input, optimizers
|
9 |
+
# from keras_core import backend as Model, Input, optimizers
|
10 |
+
# from keras_core import backend as K
|
11 |
+
|
12 |
+
from keras import Model
|
13 |
+
from keras import optimizers
|
14 |
+
from keras import ops as K
|
15 |
+
from keras import config as KK
|
16 |
+
|
17 |
+
from keras import layers
|
18 |
+
from keras.layers import Input, Layer, Conv1D, Dense, BatchNormalization, LayerNormalization, Activation, SpatialDropout1D, Lambda
|
19 |
+
|
20 |
+
|
21 |
+
def is_power_of_two(num: int):
|
22 |
+
return num != 0 and ((num & (num - 1)) == 0)
|
23 |
+
|
24 |
+
|
25 |
+
def adjust_dilations(dilations: list):
|
26 |
+
if all([is_power_of_two(i) for i in dilations]):
|
27 |
+
return dilations
|
28 |
+
else:
|
29 |
+
new_dilations = [2 ** i for i in dilations]
|
30 |
+
return new_dilations
|
31 |
+
|
32 |
+
|
33 |
+
class ResidualBlock(Layer):
|
34 |
+
|
35 |
+
def __init__(self,
|
36 |
+
dilation_rate: int,
|
37 |
+
nb_filters: int,
|
38 |
+
kernel_size: int,
|
39 |
+
padding: str,
|
40 |
+
activation: str = 'relu',
|
41 |
+
dropout_rate: float = 0,
|
42 |
+
kernel_initializer: str = 'he_normal',
|
43 |
+
use_batch_norm: bool = False,
|
44 |
+
use_layer_norm: bool = False,
|
45 |
+
use_weight_norm: bool = False,
|
46 |
+
**kwargs):
|
47 |
+
"""Defines the residual block for the WaveNet TCN
|
48 |
+
Args:
|
49 |
+
x: The previous layer in the model
|
50 |
+
training: boolean indicating whether the layer should behave in training mode or in inference mode
|
51 |
+
dilation_rate: The dilation power of 2 we are using for this residual block
|
52 |
+
nb_filters: The number of convolutional filters to use in this block
|
53 |
+
kernel_size: The size of the convolutional kernel
|
54 |
+
padding: The padding used in the convolutional layers, 'same' or 'causal'.
|
55 |
+
activation: The final activation used in o = Activation(x + F(x))
|
56 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
57 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
58 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
59 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
60 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
61 |
+
kwargs: Any initializers for Layer class.
|
62 |
+
"""
|
63 |
+
|
64 |
+
self.dilation_rate = dilation_rate
|
65 |
+
self.nb_filters = nb_filters
|
66 |
+
self.kernel_size = kernel_size
|
67 |
+
self.padding = padding
|
68 |
+
self.activation = activation
|
69 |
+
self.dropout_rate = dropout_rate
|
70 |
+
self.use_batch_norm = use_batch_norm
|
71 |
+
self.use_layer_norm = use_layer_norm
|
72 |
+
self.use_weight_norm = use_weight_norm
|
73 |
+
self.kernel_initializer = kernel_initializer
|
74 |
+
self.layers = []
|
75 |
+
self.shape_match_conv = None
|
76 |
+
self.res_output_shape = None
|
77 |
+
self.final_activation = None
|
78 |
+
|
79 |
+
super(ResidualBlock, self).__init__(**kwargs)
|
80 |
+
|
81 |
+
def _build_layer(self, layer):
|
82 |
+
"""Helper function for building layer
|
83 |
+
Args:
|
84 |
+
layer: Appends layer to internal layer list and builds it based on the current output
|
85 |
+
shape of ResidualBlocK. Updates current output shape.
|
86 |
+
"""
|
87 |
+
self.layers.append(layer)
|
88 |
+
self.layers[-1].build(self.res_output_shape)
|
89 |
+
self.res_output_shape = self.layers[-1].compute_output_shape(self.res_output_shape)
|
90 |
+
|
91 |
+
def build(self, input_shape):
|
92 |
+
|
93 |
+
#with K.name_scope(self.name): # name scope used to make sure weights get unique names
|
94 |
+
self.layers = []
|
95 |
+
self.res_output_shape = input_shape
|
96 |
+
|
97 |
+
for k in range(2): # dilated conv block.
|
98 |
+
name = 'conv1D_{}'.format(k)
|
99 |
+
# with K.name_scope(name): # name scope used to make sure weights get unique names
|
100 |
+
conv = Conv1D(
|
101 |
+
filters=self.nb_filters,
|
102 |
+
kernel_size=self.kernel_size,
|
103 |
+
dilation_rate=self.dilation_rate,
|
104 |
+
padding=self.padding,
|
105 |
+
name=name,
|
106 |
+
kernel_initializer=self.kernel_initializer
|
107 |
+
)
|
108 |
+
if self.use_weight_norm:
|
109 |
+
from tensorflow_addons.layers import WeightNormalization
|
110 |
+
# wrap it. WeightNormalization API is different than BatchNormalization or LayerNormalization.
|
111 |
+
#with K.name_scope('norm_{}'.format(k)):
|
112 |
+
conv = WeightNormalization(conv)
|
113 |
+
self._build_layer(conv)
|
114 |
+
|
115 |
+
#with K.name_scope('norm_{}'.format(k)):
|
116 |
+
if self.use_batch_norm:
|
117 |
+
self._build_layer(BatchNormalization())
|
118 |
+
elif self.use_layer_norm:
|
119 |
+
self._build_layer(LayerNormalization())
|
120 |
+
elif self.use_weight_norm:
|
121 |
+
pass # done above.
|
122 |
+
|
123 |
+
# with K.name_scope('act_and_dropout_{}'.format(k)):
|
124 |
+
self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
|
125 |
+
self._build_layer(SpatialDropout1D(rate=self.dropout_rate, name='SDropout_{}'.format(k)))
|
126 |
+
|
127 |
+
if self.nb_filters != input_shape[-1]:
|
128 |
+
# 1x1 conv to match the shapes (channel dimension).
|
129 |
+
name = 'matching_conv1D'
|
130 |
+
#with K.name_scope(name):
|
131 |
+
# make and build this layer separately because it directly uses input_shape.
|
132 |
+
# 1x1 conv.
|
133 |
+
self.shape_match_conv = Conv1D(
|
134 |
+
filters=self.nb_filters,
|
135 |
+
kernel_size=1,
|
136 |
+
padding='same',
|
137 |
+
name=name,
|
138 |
+
kernel_initializer=self.kernel_initializer
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
name = 'matching_identity'
|
142 |
+
self.shape_match_conv = Lambda(lambda x: x, name=name)
|
143 |
+
|
144 |
+
#with K.name_scope(name):
|
145 |
+
self.shape_match_conv.build(input_shape)
|
146 |
+
self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
|
147 |
+
|
148 |
+
self._build_layer(Activation(self.activation, name='Act_Conv_Blocks'))
|
149 |
+
self.final_activation = Activation(self.activation, name='Act_Res_Block')
|
150 |
+
self.final_activation.build(self.res_output_shape) # probably isn't necessary
|
151 |
+
|
152 |
+
# this is done to force Keras to add the layers in the list to self._layers
|
153 |
+
for layer in self.layers:
|
154 |
+
self.__setattr__(layer.name, layer)
|
155 |
+
self.__setattr__(self.shape_match_conv.name, self.shape_match_conv)
|
156 |
+
self.__setattr__(self.final_activation.name, self.final_activation)
|
157 |
+
|
158 |
+
super(ResidualBlock, self).build(input_shape) # done to make sure self.built is set True
|
159 |
+
|
160 |
+
def call(self, inputs, training=None, **kwargs):
|
161 |
+
"""
|
162 |
+
Returns: A tuple where the first element is the residual model tensor, and the second
|
163 |
+
is the skip connection tensor.
|
164 |
+
"""
|
165 |
+
# https://arxiv.org/pdf/1803.01271.pdf page 4, Figure 1 (b).
|
166 |
+
# x1: Dilated Conv -> Norm -> Dropout (x2).
|
167 |
+
# x2: Residual (1x1 matching conv - optional).
|
168 |
+
# Output: x1 + x2.
|
169 |
+
# x1 -> connected to skip connections.
|
170 |
+
# x1 + x2 -> connected to the next block.
|
171 |
+
# input
|
172 |
+
# x1 x2
|
173 |
+
# conv1D 1x1 Conv1D (optional)
|
174 |
+
# ...
|
175 |
+
# conv1D
|
176 |
+
# ...
|
177 |
+
# x1 + x2
|
178 |
+
x1 = inputs
|
179 |
+
for layer in self.layers:
|
180 |
+
training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
|
181 |
+
x1 = layer(x1, training=training) if training_flag else layer(x1)
|
182 |
+
x2 = self.shape_match_conv(inputs)
|
183 |
+
x1_x2 = self.final_activation(layers.add([x2, x1], name='Add_Res'))
|
184 |
+
return [x1_x2, x1]
|
185 |
+
|
186 |
+
def compute_output_shape(self, input_shape):
|
187 |
+
return [self.res_output_shape, self.res_output_shape]
|
188 |
+
|
189 |
+
|
190 |
+
class TCN(Layer):
|
191 |
+
"""Creates a TCN layer.
|
192 |
+
Input shape:
|
193 |
+
A tensor of shape (batch_size, timesteps, input_dim).
|
194 |
+
Args:
|
195 |
+
nb_filters: The number of filters to use in the convolutional layers. Can be a list.
|
196 |
+
kernel_size: The size of the kernel to use in each convolutional layer.
|
197 |
+
dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
|
198 |
+
nb_stacks : The number of stacks of residual blocks to use.
|
199 |
+
padding: The padding to use in the convolutional layers, 'causal' or 'same'.
|
200 |
+
use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
|
201 |
+
return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
|
202 |
+
activation: The activation used in the residual blocks o = Activation(x + F(x)).
|
203 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
204 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
205 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
206 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
207 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
208 |
+
kwargs: Any other arguments for configuring parent class Layer. For example "name=str", Name of the model.
|
209 |
+
Use unique names when using multiple TCN.
|
210 |
+
Returns:
|
211 |
+
A TCN layer.
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(self,
|
215 |
+
nb_filters=256,
|
216 |
+
kernel_size=5,
|
217 |
+
nb_stacks=1,
|
218 |
+
dilations=(1, 2, 4, 8, 16, 32),
|
219 |
+
padding='causal',
|
220 |
+
use_skip_connections=True,
|
221 |
+
dropout_rate=0.0,
|
222 |
+
return_sequences=False,
|
223 |
+
activation='relu',
|
224 |
+
kernel_initializer='he_normal',
|
225 |
+
use_batch_norm=False,
|
226 |
+
use_layer_norm=False,
|
227 |
+
use_weight_norm=False,
|
228 |
+
**kwargs):
|
229 |
+
print("nb_filters:", nb_filters, "kernel_size", kernel_size)
|
230 |
+
self.return_sequences = return_sequences
|
231 |
+
self.dropout_rate = dropout_rate
|
232 |
+
self.use_skip_connections = use_skip_connections
|
233 |
+
self.dilations = dilations
|
234 |
+
self.nb_stacks = nb_stacks
|
235 |
+
self.kernel_size = kernel_size
|
236 |
+
self.nb_filters = nb_filters
|
237 |
+
self.activation_name = activation
|
238 |
+
self.padding = padding
|
239 |
+
self.kernel_initializer = kernel_initializer
|
240 |
+
self.use_batch_norm = use_batch_norm
|
241 |
+
self.use_layer_norm = use_layer_norm
|
242 |
+
self.use_weight_norm = use_weight_norm
|
243 |
+
self.skip_connections = []
|
244 |
+
self.residual_blocks = []
|
245 |
+
self.layers_outputs = []
|
246 |
+
self.build_output_shape = None
|
247 |
+
self.slicer_layer = None # in case return_sequence=False
|
248 |
+
self.output_slice_index = None # in case return_sequence=False
|
249 |
+
self.padding_same_and_time_dim_unknown = False # edge case if padding='same' and time_dim = None
|
250 |
+
|
251 |
+
if self.use_batch_norm + self.use_layer_norm + self.use_weight_norm > 1:
|
252 |
+
raise ValueError('Only one normalization can be specified at once.')
|
253 |
+
|
254 |
+
if isinstance(self.nb_filters, list):
|
255 |
+
assert len(self.nb_filters) == len(self.dilations)
|
256 |
+
if len(set(self.nb_filters)) > 1 and self.use_skip_connections:
|
257 |
+
raise ValueError('Skip connections are not compatible '
|
258 |
+
'with a list of filters, unless they are all equal.')
|
259 |
+
|
260 |
+
if padding != 'causal' and padding != 'same':
|
261 |
+
raise ValueError("Only 'causal' or 'same' padding are compatible for this layer.")
|
262 |
+
|
263 |
+
# initialize parent class
|
264 |
+
super(TCN, self).__init__(**kwargs)
|
265 |
+
|
266 |
+
@property
|
267 |
+
def receptive_field(self):
|
268 |
+
return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
|
269 |
+
|
270 |
+
def build(self, input_shape):
|
271 |
+
|
272 |
+
# member to hold current output shape of the layer for building purposes
|
273 |
+
self.build_output_shape = input_shape
|
274 |
+
|
275 |
+
# list to hold all the member ResidualBlocks
|
276 |
+
self.residual_blocks = []
|
277 |
+
total_num_blocks = self.nb_stacks * len(self.dilations)
|
278 |
+
if not self.use_skip_connections:
|
279 |
+
total_num_blocks += 1 # cheap way to do a false case for below
|
280 |
+
|
281 |
+
for s in range(self.nb_stacks):
|
282 |
+
for i, d in enumerate(self.dilations):
|
283 |
+
res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
|
284 |
+
self.residual_blocks.append(ResidualBlock(dilation_rate=d,
|
285 |
+
nb_filters=res_block_filters,
|
286 |
+
kernel_size=self.kernel_size,
|
287 |
+
padding=self.padding,
|
288 |
+
activation=self.activation_name,
|
289 |
+
dropout_rate=self.dropout_rate,
|
290 |
+
use_batch_norm=self.use_batch_norm,
|
291 |
+
use_layer_norm=self.use_layer_norm,
|
292 |
+
use_weight_norm=self.use_weight_norm,
|
293 |
+
kernel_initializer=self.kernel_initializer,
|
294 |
+
name='residual_block_{}'.format(len(self.residual_blocks))))
|
295 |
+
# build newest residual block
|
296 |
+
self.residual_blocks[-1].build(self.build_output_shape)
|
297 |
+
self.build_output_shape = self.residual_blocks[-1].res_output_shape
|
298 |
+
|
299 |
+
# this is done to force keras to add the layers in the list to self._layers
|
300 |
+
for layer in self.residual_blocks:
|
301 |
+
self.__setattr__(layer.name, layer)
|
302 |
+
|
303 |
+
self.output_slice_index = None
|
304 |
+
if self.padding == 'same':
|
305 |
+
time = self.build_output_shape.as_list()[1]
|
306 |
+
if time is not None: # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
|
307 |
+
self.output_slice_index = int(self.build_output_shape.as_list()[1] / 2)
|
308 |
+
else:
|
309 |
+
# It will known at call time. c.f. self.call.
|
310 |
+
self.padding_same_and_time_dim_unknown = True
|
311 |
+
|
312 |
+
else:
|
313 |
+
self.output_slice_index = -1 # causal case.
|
314 |
+
self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
|
315 |
+
|
316 |
+
if type(self.build_output_shape) == tuple:
|
317 |
+
static = list(self.build_output_shape)
|
318 |
+
else:
|
319 |
+
static = self.build_output_shape.as_list()
|
320 |
+
self.slicer_layer.build(static)
|
321 |
+
|
322 |
+
def compute_output_shape(self, input_shape):
|
323 |
+
"""
|
324 |
+
Overridden in case keras uses it somewhere... no idea. Just trying to avoid future errors.
|
325 |
+
"""
|
326 |
+
if not self.built:
|
327 |
+
self.build(input_shape)
|
328 |
+
if not self.return_sequences:
|
329 |
+
batch_size = self.build_output_shape[0]
|
330 |
+
batch_size = batch_size.value if hasattr(batch_size, 'value') else batch_size
|
331 |
+
nb_filters = self.build_output_shape[-1]
|
332 |
+
return [batch_size, nb_filters]
|
333 |
+
else:
|
334 |
+
# Compatibility tensorflow 1.x
|
335 |
+
return [v.value if hasattr(v, 'value') else v for v in self.build_output_shape]
|
336 |
+
|
337 |
+
def call(self, inputs, training=None, **kwargs):
|
338 |
+
x = inputs
|
339 |
+
self.layers_outputs = [x]
|
340 |
+
self.skip_connections = []
|
341 |
+
for res_block in self.residual_blocks:
|
342 |
+
# try:
|
343 |
+
# x, skip_out = res_block(x, training=training)
|
344 |
+
# except TypeError: # compatibility with tensorflow 1.x
|
345 |
+
# x, skip_out = res_block(K.cast(x, 'float32'), training=training)
|
346 |
+
x, skip_out = res_block(x, training=training)
|
347 |
+
|
348 |
+
self.skip_connections.append(skip_out)
|
349 |
+
self.layers_outputs.append(x)
|
350 |
+
|
351 |
+
if self.use_skip_connections:
|
352 |
+
x = layers.add(self.skip_connections, name='Add_Skip_Connections')
|
353 |
+
self.layers_outputs.append(x)
|
354 |
+
|
355 |
+
if not self.return_sequences:
|
356 |
+
# case: time dimension is unknown. e.g. (bs, None, input_dim).
|
357 |
+
if self.padding_same_and_time_dim_unknown:
|
358 |
+
self.output_slice_index = K.shape(self.layers_outputs[-1])[1] // 2
|
359 |
+
x = self.slicer_layer(x)
|
360 |
+
self.layers_outputs.append(x)
|
361 |
+
return x
|
362 |
+
|
363 |
+
def get_config(self):
|
364 |
+
"""
|
365 |
+
Returns the config of a the layer. This is used for saving and loading from a model
|
366 |
+
:return: python dictionary with specs to rebuild layer
|
367 |
+
"""
|
368 |
+
config = super(TCN, self).get_config()
|
369 |
+
config['nb_filters'] = self.nb_filters
|
370 |
+
config['kernel_size'] = self.kernel_size
|
371 |
+
config['nb_stacks'] = self.nb_stacks
|
372 |
+
config['dilations'] = self.dilations
|
373 |
+
config['padding'] = self.padding
|
374 |
+
config['use_skip_connections'] = self.use_skip_connections
|
375 |
+
config['dropout_rate'] = self.dropout_rate
|
376 |
+
config['return_sequences'] = self.return_sequences
|
377 |
+
config['activation'] = self.activation_name
|
378 |
+
config['use_batch_norm'] = self.use_batch_norm
|
379 |
+
config['use_layer_norm'] = self.use_layer_norm
|
380 |
+
config['use_weight_norm'] = self.use_weight_norm
|
381 |
+
config['kernel_initializer'] = self.kernel_initializer
|
382 |
+
return config
|
383 |
+
|
384 |
+
|
385 |
+
def compiled_tcn(num_feat, # type: int
|
386 |
+
num_classes, # type: int
|
387 |
+
nb_filters, # type: int
|
388 |
+
kernel_size, # type: int
|
389 |
+
dilations, # type: List[int]
|
390 |
+
nb_stacks, # type: int
|
391 |
+
max_len, # type: int
|
392 |
+
output_len=1, # type: int
|
393 |
+
padding='causal', # type: str
|
394 |
+
use_skip_connections=False, # type: bool
|
395 |
+
return_sequences=True,
|
396 |
+
regression=False, # type: bool
|
397 |
+
dropout_rate=0.05, # type: float
|
398 |
+
name='tcn', # type: str,
|
399 |
+
kernel_initializer='he_normal', # type: str,
|
400 |
+
activation='relu', # type:str,
|
401 |
+
opt='adam',
|
402 |
+
lr=0.002,
|
403 |
+
use_batch_norm=False,
|
404 |
+
use_layer_norm=False,
|
405 |
+
use_weight_norm=False):
|
406 |
+
# type: (...) -> Model
|
407 |
+
"""Creates a compiled TCN model for a given task (i.e. regression or classification).
|
408 |
+
Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
|
409 |
+
Args:
|
410 |
+
num_feat: The number of features of your input, i.e. the last dimension of: (batch_size, timesteps, input_dim).
|
411 |
+
num_classes: The size of the final dense layer, how many classes we are predicting.
|
412 |
+
nb_filters: The number of filters to use in the convolutional layers.
|
413 |
+
kernel_size: The size of the kernel to use in each convolutional layer.
|
414 |
+
dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
|
415 |
+
nb_stacks : The number of stacks of residual blocks to use.
|
416 |
+
max_len: The maximum sequence length, use None if the sequence length is dynamic.
|
417 |
+
padding: The padding to use in the convolutional layers.
|
418 |
+
use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
|
419 |
+
return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
|
420 |
+
regression: Whether the output should be continuous or discrete.
|
421 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
422 |
+
activation: The activation used in the residual blocks o = Activation(x + F(x)).
|
423 |
+
name: Name of the model. Useful when having multiple TCN.
|
424 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
425 |
+
opt: Optimizer name.
|
426 |
+
lr: Learning rate.
|
427 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
428 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
429 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
430 |
+
Returns:
|
431 |
+
A compiled keras TCN.
|
432 |
+
"""
|
433 |
+
|
434 |
+
dilations = adjust_dilations(dilations)
|
435 |
+
|
436 |
+
input_layer = Input(shape=(max_len, num_feat))
|
437 |
+
|
438 |
+
x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
|
439 |
+
use_skip_connections, dropout_rate, return_sequences,
|
440 |
+
activation, kernel_initializer, use_batch_norm, use_layer_norm,
|
441 |
+
use_weight_norm, name=name)(input_layer)
|
442 |
+
|
443 |
+
print('x.shape=', x.shape)
|
444 |
+
|
445 |
+
def get_opt():
|
446 |
+
if opt == 'adam':
|
447 |
+
return optimizers.Adam(lr=lr, clipnorm=1.)
|
448 |
+
elif opt == 'rmsprop':
|
449 |
+
return optimizers.RMSprop(lr=lr, clipnorm=1.)
|
450 |
+
else:
|
451 |
+
raise Exception('Only Adam and RMSProp are available here')
|
452 |
+
|
453 |
+
if not regression:
|
454 |
+
# classification
|
455 |
+
print('asdasfdasfa')
|
456 |
+
x = Dense(num_classes)(x)
|
457 |
+
x = Activation('softmax')(x)
|
458 |
+
output_layer = x
|
459 |
+
model = Model(input_layer, output_layer)
|
460 |
+
|
461 |
+
# https://github.com/keras-team/keras/pull/11373
|
462 |
+
# It's now in Keras@master but still not available with pip.
|
463 |
+
# TODO remove later.
|
464 |
+
def accuracy(y_true, y_pred):
|
465 |
+
# reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
|
466 |
+
if K.ndim(y_true) == K.ndim(y_pred):
|
467 |
+
y_true = K.squeeze(y_true, -1)
|
468 |
+
# convert dense predictions to labels
|
469 |
+
y_pred_labels = K.argmax(y_pred, axis=-1)
|
470 |
+
y_pred_labels = K.cast(y_pred_labels, KK.floatx())
|
471 |
+
return K.cast(K.equal(y_true, y_pred_labels), KK.floatx())
|
472 |
+
|
473 |
+
model.compile(get_opt(), loss='sparse_categorical_crossentropy', metrics=[accuracy])
|
474 |
+
else:
|
475 |
+
# regression
|
476 |
+
x = Dense(output_len)(x)
|
477 |
+
x = Activation('linear')(x)
|
478 |
+
output_layer = x
|
479 |
+
model = Model(input_layer, output_layer)
|
480 |
+
model.compile(get_opt(), loss='mean_squared_error')
|
481 |
+
print('model.x = {}'.format(input_layer.shape))
|
482 |
+
print('model.y = {}'.format(output_layer.shape))
|
483 |
+
return model
|
484 |
+
|
485 |
+
|
486 |
+
def tcn_full_summary(model: Model, expand_residual_blocks=True):
|
487 |
+
|
488 |
+
layers = model._layers.copy() # store existing layers
|
489 |
+
model._layers.clear() # clear layers
|
490 |
+
|
491 |
+
for i in range(len(layers)):
|
492 |
+
if isinstance(layers[i], TCN):
|
493 |
+
for layer in layers[i]._layers:
|
494 |
+
if not isinstance(layer, ResidualBlock):
|
495 |
+
if not hasattr(layer, '__iter__'):
|
496 |
+
model._layers.append(layer)
|
497 |
+
else:
|
498 |
+
if expand_residual_blocks:
|
499 |
+
for lyr in layer._layers:
|
500 |
+
if not hasattr(lyr, '__iter__'):
|
501 |
+
model._layers.append(lyr)
|
502 |
+
else:
|
503 |
+
model._layers.append(layer)
|
504 |
+
else:
|
505 |
+
model._layers.append(layers[i])
|
506 |
+
|
507 |
+
model.summary() # print summary
|
508 |
+
|
509 |
+
# restore original layers
|
510 |
+
model._layers.clear()
|
511 |
+
[model._layers.append(lyr) for lyr in layers]
|
pages/inference.py
ADDED
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import random
|
8 |
+
from sklearn.preprocessing import MinMaxScaler
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
from model.lstm import LSTMModel
|
13 |
+
from model.tcn import TCNModel
|
14 |
+
from model.tcn import move_custom_layers_to_device
|
15 |
+
from utils.lowlevel import LowLevel
|
16 |
+
from utils.highlevel import HighLevel
|
17 |
+
from utils.midpoint import MidPoint
|
18 |
+
|
19 |
+
from utils.transform import compute_gradient
|
20 |
+
|
21 |
+
st.set_page_config(page_title="Inference", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
|
22 |
+
|
23 |
+
def uniform_sampling(data, n_sample):
|
24 |
+
k = len(data) // n_sample
|
25 |
+
return data[::k]
|
26 |
+
|
27 |
+
def low_level(option_time, slider_sample_orbit, progress_bar):
|
28 |
+
time.sleep(0.1)
|
29 |
+
low_level_total_start_time = time.time()
|
30 |
+
low_level_30000_start_time = time.time()
|
31 |
+
|
32 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
33 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
34 |
+
|
35 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
36 |
+
a2 = 1 - 2 * a1
|
37 |
+
jn = 0
|
38 |
+
t = 0.1
|
39 |
+
|
40 |
+
# Calculate the total number of iterations for the progress bar update
|
41 |
+
total_iterations = (float(option_time) - t) / h
|
42 |
+
current_iteration = 0
|
43 |
+
|
44 |
+
original_low_level_data = []
|
45 |
+
|
46 |
+
while t < float(option_time):
|
47 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
48 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
49 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
50 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
51 |
+
|
52 |
+
t = t + h
|
53 |
+
|
54 |
+
if jn % 10 == 0:
|
55 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
56 |
+
# Update progress bar
|
57 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
58 |
+
progress_bar.progress(progress_percentage)
|
59 |
+
|
60 |
+
if jn == 300000:
|
61 |
+
low_level_30000_end_time = time.time()
|
62 |
+
low_level_30000_execute_time = low_level_30000_end_time - low_level_30000_start_time
|
63 |
+
low_level_2000_start_time = time.time()
|
64 |
+
jn = jn + 1
|
65 |
+
current_iteration += 1
|
66 |
+
|
67 |
+
progress_bar.progress(100)
|
68 |
+
|
69 |
+
low_level_2000_end_time = time.time()
|
70 |
+
low_level_2000_execute_time = low_level_2000_end_time - low_level_2000_start_time
|
71 |
+
low_level_total_end_time = time.time()
|
72 |
+
low_level_total_execute_time = low_level_total_end_time - low_level_total_start_time
|
73 |
+
|
74 |
+
result = uniform_sampling(np.array(original_low_level_data), n_sample=int(option_time/100))
|
75 |
+
|
76 |
+
return low_level_30000_execute_time, low_level_2000_execute_time, low_level_total_execute_time, result
|
77 |
+
|
78 |
+
def high_level(option_time, slider_sample_orbit, progress_bar):
|
79 |
+
time.sleep(0.1)
|
80 |
+
high_level_total_start_time = time.time()
|
81 |
+
high_level_30000_start_time = time.time()
|
82 |
+
|
83 |
+
highlevelhelper = HighLevel(j=slider_sample_orbit)
|
84 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = highlevelhelper.initial()
|
85 |
+
|
86 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
87 |
+
a2 = 1 - 2 * a1
|
88 |
+
jn = 0
|
89 |
+
t = 0.1
|
90 |
+
|
91 |
+
# Calculate the total number of iterations for the progress bar update
|
92 |
+
total_iterations = (float(option_time) - t) / h
|
93 |
+
current_iteration = 0
|
94 |
+
|
95 |
+
original_high_level_data = []
|
96 |
+
|
97 |
+
while t < float(option_time):
|
98 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
99 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
100 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
101 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
102 |
+
|
103 |
+
t = t + h
|
104 |
+
vx, vy, vz, vpx, vpy, vpz, e = highlevelhelper.f(x, y, z, px, py, pz, b)
|
105 |
+
en = np.asarray(e).astype(np.float64)
|
106 |
+
|
107 |
+
if jn % 10 == 0:
|
108 |
+
original_high_level_data.append([b, x, y, z, px, py, pz])
|
109 |
+
if jn == 300000:
|
110 |
+
high_level_30000_end_time = time.time()
|
111 |
+
high_level_30000_execute_time = high_level_30000_end_time - high_level_30000_start_time
|
112 |
+
high_level_2000_start_time = time.time()
|
113 |
+
jn = jn + 1
|
114 |
+
|
115 |
+
# Update progress bar
|
116 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
117 |
+
progress_bar.progress(progress_percentage)
|
118 |
+
current_iteration += 1
|
119 |
+
|
120 |
+
progress_bar.progress(100)
|
121 |
+
high_level_2000_end_time = time.time()
|
122 |
+
high_level_2000_execute_time = high_level_2000_end_time - high_level_2000_start_time
|
123 |
+
high_level_total_end_time = time.time()
|
124 |
+
high_level_total_execute_time = high_level_total_end_time - high_level_total_start_time
|
125 |
+
|
126 |
+
result = uniform_sampling(np.array(original_high_level_data), n_sample=int(option_time / 100))
|
127 |
+
|
128 |
+
return high_level_30000_execute_time, high_level_2000_execute_time, high_level_total_execute_time, result
|
129 |
+
|
130 |
+
def midpoint(option_time, slider_sample_orbit, progress_bar):
|
131 |
+
time.sleep(0.1)
|
132 |
+
mid_point_total_start_time = time.time()
|
133 |
+
mid_point_30000_start_time = time.time()
|
134 |
+
|
135 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
136 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
137 |
+
|
138 |
+
#en0 = np.asarray(e0).astype(np.float64)
|
139 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
140 |
+
a2 = 1 - 2 * a1
|
141 |
+
jn = 0
|
142 |
+
t = 0.1
|
143 |
+
|
144 |
+
# Calculate the total number of iterations for the progress bar update
|
145 |
+
total_iterations = (float(option_time) - t) / h
|
146 |
+
current_iteration = 0
|
147 |
+
|
148 |
+
original_mid_point_data = []
|
149 |
+
|
150 |
+
while t < float(option_time):
|
151 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
152 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
153 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
154 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
155 |
+
|
156 |
+
t = t + h
|
157 |
+
|
158 |
+
if jn % 10 == 0:
|
159 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
160 |
+
if jn == 300000:
|
161 |
+
mid_point_30000_end_time = time.time()
|
162 |
+
mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
|
163 |
+
mid_point_2000_start_time = time.time()
|
164 |
+
jn = jn + 1
|
165 |
+
|
166 |
+
# Update progress bar
|
167 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
168 |
+
progress_bar.progress(progress_percentage)
|
169 |
+
current_iteration += 1
|
170 |
+
|
171 |
+
#mid_point_df.to_excel('mid_point_df_output.xlsx', index=False)
|
172 |
+
progress_bar.progress(100)
|
173 |
+
mid_point_2000_end_time = time.time()
|
174 |
+
mid_point_2000_execute_time = mid_point_2000_end_time - mid_point_2000_start_time
|
175 |
+
mid_point_total_end_time = time.time()
|
176 |
+
mid_point_total_execute_time = mid_point_total_end_time - mid_point_total_start_time
|
177 |
+
|
178 |
+
result = uniform_sampling(np.array(original_mid_point_data), n_sample=int(option_time / 100))
|
179 |
+
|
180 |
+
return mid_point_30000_execute_time, mid_point_2000_execute_time, mid_point_total_execute_time, result
|
181 |
+
|
182 |
+
def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
|
183 |
+
time.sleep(0.1)
|
184 |
+
total_start_time = time.time()
|
185 |
+
|
186 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
187 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
188 |
+
lstm_model.to("cpu")
|
189 |
+
lstm_model.eval()
|
190 |
+
|
191 |
+
# Initialize variables for the classical method
|
192 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
193 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
194 |
+
|
195 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
196 |
+
a2 = 1 - 2 * a1
|
197 |
+
jn = 0
|
198 |
+
t = 0.1
|
199 |
+
|
200 |
+
# Calculate the total number of iterations for the progress bar update
|
201 |
+
total_iterations = (float(30000) - t) / h
|
202 |
+
current_iteration = 0
|
203 |
+
|
204 |
+
original_low_level_data = []
|
205 |
+
|
206 |
+
low_level_start_time = time.time()
|
207 |
+
|
208 |
+
# Perform classical method prediction for the initial segment
|
209 |
+
while t < float(30000):
|
210 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
211 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
212 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
213 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
214 |
+
t = t + h
|
215 |
+
|
216 |
+
if jn % 10 == 0:
|
217 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
218 |
+
# Update progress bar
|
219 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
220 |
+
lstm_progress_bar.progress(progress_percentage)
|
221 |
+
|
222 |
+
jn = jn + 1
|
223 |
+
current_iteration += 1
|
224 |
+
|
225 |
+
original_low_level_data = np.array(original_low_level_data)
|
226 |
+
low_level_end_time = time.time()
|
227 |
+
low_level_data = original_low_level_data.copy()
|
228 |
+
low_level_data = uniform_sampling(low_level_data, n_sample=300)
|
229 |
+
scaler = MinMaxScaler()
|
230 |
+
low_level_data = scaler.fit_transform(low_level_data)
|
231 |
+
low_level_data = torch.tensor(np.stack(low_level_data)).float()
|
232 |
+
low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
|
233 |
+
|
234 |
+
lstm_start_time = time.time()
|
235 |
+
with torch.no_grad():
|
236 |
+
lstm_preds = lstm_model(low_level_data[:, 100:300, :])
|
237 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
238 |
+
|
239 |
+
original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
|
240 |
+
|
241 |
+
lstm_end_time = time.time()
|
242 |
+
lstm_progress_bar.progress(100)
|
243 |
+
|
244 |
+
combined_preds = np.concatenate([original_low_level_data, lstm_innv_preds], axis=0)
|
245 |
+
|
246 |
+
lstm_total_time = lstm_end_time - lstm_start_time
|
247 |
+
low_level_total_time = low_level_end_time - low_level_start_time
|
248 |
+
|
249 |
+
total_end_time = time.time()
|
250 |
+
total_time = total_end_time - total_start_time
|
251 |
+
|
252 |
+
return low_level_total_time, lstm_total_time, total_time, combined_preds
|
253 |
+
|
254 |
+
def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
|
255 |
+
time.sleep(0.1)
|
256 |
+
total_start_time = time.time()
|
257 |
+
|
258 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
259 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
260 |
+
lstm_model.to("cpu")
|
261 |
+
lstm_model.eval()
|
262 |
+
|
263 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
264 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
265 |
+
|
266 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
267 |
+
a2 = 1 - 2 * a1
|
268 |
+
jn = 0
|
269 |
+
t = 0.1
|
270 |
+
|
271 |
+
# Calculate the total number of iterations for the progress bar update
|
272 |
+
total_iterations = (float(30000) - t) / h
|
273 |
+
current_iteration = 0
|
274 |
+
|
275 |
+
original_mid_point_data = []
|
276 |
+
|
277 |
+
mid_point_start_time = time.time()
|
278 |
+
|
279 |
+
while t < float(30000):
|
280 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
281 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
282 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
283 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
284 |
+
|
285 |
+
t = t + h
|
286 |
+
|
287 |
+
if jn % 10 == 0:
|
288 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
289 |
+
# Update progress bar
|
290 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
291 |
+
lstm_progress_bar.progress(progress_percentage)
|
292 |
+
jn = jn + 1
|
293 |
+
current_iteration += 1
|
294 |
+
|
295 |
+
original_mid_point_data = np.array(original_mid_point_data)
|
296 |
+
mid_point_end_time = time.time()
|
297 |
+
mid_point_data = original_mid_point_data.copy()
|
298 |
+
mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
|
299 |
+
scaler = MinMaxScaler()
|
300 |
+
mid_point_data = scaler.fit_transform(mid_point_data)
|
301 |
+
mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
|
302 |
+
mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
|
303 |
+
|
304 |
+
lstm_start_time = time.time()
|
305 |
+
with torch.no_grad():
|
306 |
+
lstm_preds = lstm_model(mid_point_data[:, 100:300, :])
|
307 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
308 |
+
|
309 |
+
original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
|
310 |
+
|
311 |
+
lstm_end_time = time.time()
|
312 |
+
lstm_progress_bar.progress(100)
|
313 |
+
|
314 |
+
combined_preds = np.concatenate([original_mid_point_data, lstm_innv_preds], axis=0)
|
315 |
+
|
316 |
+
lstm_total_time = lstm_end_time - lstm_start_time
|
317 |
+
mid_point_total_time = mid_point_end_time - mid_point_start_time
|
318 |
+
|
319 |
+
total_end_time = time.time()
|
320 |
+
total_time = total_end_time - total_start_time
|
321 |
+
|
322 |
+
return mid_point_total_time, lstm_total_time, total_time, combined_preds
|
323 |
+
|
324 |
+
def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
|
325 |
+
time.sleep(0.1)
|
326 |
+
total_start_time = time.time()
|
327 |
+
|
328 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
329 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
330 |
+
move_custom_layers_to_device(tcn_model, "cpu")
|
331 |
+
tcn_model.eval()
|
332 |
+
|
333 |
+
# Initialize variables for the classical method
|
334 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
335 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
336 |
+
|
337 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
338 |
+
a2 = 1 - 2 * a1
|
339 |
+
jn = 0
|
340 |
+
t = 0.1
|
341 |
+
|
342 |
+
# Calculate the total number of iterations for the progress bar update
|
343 |
+
total_iterations = (float(30000) - t) / h
|
344 |
+
current_iteration = 0
|
345 |
+
|
346 |
+
original_low_level_data = []
|
347 |
+
|
348 |
+
low_level_start_time = time.time()
|
349 |
+
|
350 |
+
# Perform classical method prediction for the initial segment
|
351 |
+
while t < float(30000):
|
352 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
353 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
354 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
355 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
|
356 |
+
|
357 |
+
t = t + h
|
358 |
+
|
359 |
+
if jn % 10 == 0:
|
360 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
361 |
+
# Update progress bar
|
362 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
363 |
+
tcn_progress_bar.progress(progress_percentage)
|
364 |
+
|
365 |
+
jn = jn + 1
|
366 |
+
current_iteration += 1
|
367 |
+
|
368 |
+
original_low_level_data = np.array(original_low_level_data)
|
369 |
+
low_level_end_time = time.time()
|
370 |
+
low_level_data = original_low_level_data.copy()
|
371 |
+
low_level_data = uniform_sampling(low_level_data, n_sample=300)
|
372 |
+
scaler = MinMaxScaler()
|
373 |
+
low_level_data = scaler.fit_transform(low_level_data)
|
374 |
+
low_level_data = torch.tensor(np.stack(low_level_data)).float()
|
375 |
+
low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
|
376 |
+
|
377 |
+
tcn_start_time = time.time()
|
378 |
+
with torch.no_grad():
|
379 |
+
tcn_preds = None
|
380 |
+
for i in range(20):
|
381 |
+
if i == 0:
|
382 |
+
tcn_preds = tcn_model(low_level_data[:, :300, :])
|
383 |
+
else:
|
384 |
+
gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
|
385 |
+
output = tcn_model(torch.cat([low_level_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
386 |
+
tcn_preds = torch.cat([tcn_preds, output], dim=1)
|
387 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
388 |
+
|
389 |
+
original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
|
390 |
+
|
391 |
+
tcn_end_time = time.time()
|
392 |
+
tcn_progress_bar.progress(100)
|
393 |
+
|
394 |
+
combined_preds = np.concatenate([original_low_level_data, tcn_innv_preds], axis=0)
|
395 |
+
|
396 |
+
tcn_total_time = tcn_end_time - tcn_start_time
|
397 |
+
low_level_total_time = low_level_end_time - low_level_start_time
|
398 |
+
|
399 |
+
total_end_time = time.time()
|
400 |
+
total_time = total_end_time - total_start_time
|
401 |
+
|
402 |
+
return low_level_total_time, tcn_total_time, total_time, combined_preds
|
403 |
+
|
404 |
+
def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
|
405 |
+
time.sleep(0.1)
|
406 |
+
total_start_time = time.time()
|
407 |
+
|
408 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
409 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
410 |
+
move_custom_layers_to_device(tcn_model, "cpu")
|
411 |
+
tcn_model.eval()
|
412 |
+
|
413 |
+
# Initialize variables for the classical method
|
414 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
415 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
416 |
+
|
417 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
418 |
+
a2 = 1 - 2 * a1
|
419 |
+
jn = 0
|
420 |
+
t = 0.1
|
421 |
+
|
422 |
+
# Calculate the total number of iterations for the progress bar update
|
423 |
+
total_iterations = (float(30000) - t) / h
|
424 |
+
current_iteration = 0
|
425 |
+
|
426 |
+
original_mid_point_data = []
|
427 |
+
|
428 |
+
mid_point_start_time = time.time()
|
429 |
+
|
430 |
+
# Perform classical method prediction for the initial segment
|
431 |
+
while t < float(30000):
|
432 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
433 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
434 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
435 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
|
436 |
+
|
437 |
+
t = t + h
|
438 |
+
|
439 |
+
if jn % 10 == 0:
|
440 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
441 |
+
# Update progress bar
|
442 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
443 |
+
tcn_progress_bar.progress(progress_percentage)
|
444 |
+
jn = jn + 1
|
445 |
+
current_iteration += 1
|
446 |
+
|
447 |
+
original_mid_point_data = np.array(original_mid_point_data)
|
448 |
+
mid_point_end_time = time.time()
|
449 |
+
mid_point_data = original_mid_point_data.copy()
|
450 |
+
mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
|
451 |
+
scaler = MinMaxScaler()
|
452 |
+
mid_point_data = scaler.fit_transform(mid_point_data)
|
453 |
+
mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
|
454 |
+
mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
|
455 |
+
|
456 |
+
tcn_start_time = time.time()
|
457 |
+
with torch.no_grad():
|
458 |
+
tcn_preds = None
|
459 |
+
for i in range(20):
|
460 |
+
if i == 0:
|
461 |
+
tcn_preds = tcn_model(mid_point_data[:, :300, :])
|
462 |
+
else:
|
463 |
+
gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
|
464 |
+
output = tcn_model(torch.cat([mid_point_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
465 |
+
tcn_preds = torch.cat([tcn_preds, output], dim=1)
|
466 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
467 |
+
|
468 |
+
original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
|
469 |
+
|
470 |
+
tcn_end_time = time.time()
|
471 |
+
tcn_progress_bar.progress(100)
|
472 |
+
|
473 |
+
combined_preds = np.concatenate([original_mid_point_data, tcn_innv_preds], axis=0)
|
474 |
+
|
475 |
+
tcn_total_time = tcn_end_time - tcn_start_time
|
476 |
+
mid_point_total_time = mid_point_end_time - mid_point_start_time
|
477 |
+
|
478 |
+
total_end_time = time.time()
|
479 |
+
total_time = total_end_time - total_start_time
|
480 |
+
|
481 |
+
return mid_point_total_time, tcn_total_time, total_time, combined_preds
|
482 |
+
|
483 |
+
container = st.container()
|
484 |
+
container1, container2 = st.columns(2)
|
485 |
+
plot_container = st.container()
|
486 |
+
|
487 |
+
with st.sidebar:
|
488 |
+
slider_sample_orbit = st.slider('Orbit Sample ID', 1, 10, 1)
|
489 |
+
option_time = 32000
|
490 |
+
st.write(f'Total Time Step: {option_time}')
|
491 |
+
options_method = st.multiselect(
|
492 |
+
'Compared Methods',
|
493 |
+
['Low-Level', 'High-Level', 'Midpoint', 'Low-Level with LSTM', 'Low-Level with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
|
494 |
+
['Low-Level'])
|
495 |
+
btn_go = st.button("Go", type="primary", use_container_width=True)
|
496 |
+
|
497 |
+
if btn_go:
|
498 |
+
if 'Low-Level' in options_method:
|
499 |
+
with container1:
|
500 |
+
st.write('Low Level Progress Bar')
|
501 |
+
low_level_progress_bar = st.progress(0)
|
502 |
+
low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
|
503 |
+
with container2:
|
504 |
+
st.table(pd.DataFrame({'Model':"Low Level", '30000 Time Steps (s)': [low_level_30000_time], '2000 Time Steps (s)': [low_level_2000_time], 'Total Time (s)': [low_level_total_time]}))
|
505 |
+
if 'High-Level' in options_method:
|
506 |
+
with container1:
|
507 |
+
st.write('High Level Progress Bar')
|
508 |
+
high_level_progress_bar = st.progress(0)
|
509 |
+
high_level_30000_time, high_level_2000_time, high_level_total_time, high_level_result = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
|
510 |
+
with container2:
|
511 |
+
st.table(pd.DataFrame({'Model':"High Level", '30000 Time Steps (s)': [high_level_30000_time], '2000 Time Steps (s)': [high_level_2000_time], 'Total Time (s)': [high_level_total_time]}))
|
512 |
+
if 'Midpoint' in options_method:
|
513 |
+
with container1:
|
514 |
+
st.write('Midpoint Progress Bar')
|
515 |
+
mid_point_progress_bar = st.progress(0)
|
516 |
+
mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
|
517 |
+
with container2:
|
518 |
+
st.table(pd.DataFrame({'Model':"Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
|
519 |
+
if 'Low-Level with LSTM' in options_method:
|
520 |
+
with container1:
|
521 |
+
st.write('Low Level LSTM Progress Bar')
|
522 |
+
low_level_lstm_progress_bar = st.progress(0)
|
523 |
+
lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
|
524 |
+
with container2:
|
525 |
+
st.table(pd.DataFrame({'Model':"Low Level + LSTM", '30000 Time Steps (s)': [lstm_30000_time], '2000 Time Steps (s)': [lstm_2000_time], 'Total Time (s)': [lstm_total_time]}))
|
526 |
+
if 'Low-Level with TCN' in options_method:
|
527 |
+
with container1:
|
528 |
+
st.write('Low Level TCN Progress Bar')
|
529 |
+
low_level_tcn_progress_bar = st.progress(0)
|
530 |
+
tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
|
531 |
+
with container2:
|
532 |
+
st.table(pd.DataFrame({'Model':"Low Level + TCN", '30000 Time Steps (s)': [tcn_30000_time], '2000 Time Steps (s)': [tcn_2000_time], 'Total Time (s)': [tcn_total_time]}))
|
533 |
+
if 'Midpoint with LSTM' in options_method:
|
534 |
+
with container1:
|
535 |
+
st.write('Midpoint LSTM Progress Bar')
|
536 |
+
mid_point_lstm_progress_bar = st.progress(0)
|
537 |
+
md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time, md_lstm_result = mid_point_lstm(slider_sample_orbit, mid_point_lstm_progress_bar)
|
538 |
+
with container2:
|
539 |
+
st.table(pd.DataFrame({'Model':"Midpoint + LSTM", '30000 Time Steps (s)': [md_lstm_30000_time], '2000 Time Steps (s)': [md_lstm_2000_time], 'Total Time (s)': [md_lstm_total_time]}))
|
540 |
+
if 'Midpoint with TCN' in options_method:
|
541 |
+
with container1:
|
542 |
+
st.write('Midpoint TCN Progress Bar')
|
543 |
+
mid_point_tcn_progress_bar = st.progress(0)
|
544 |
+
md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time, md_tcn_result = mid_point_tcn(slider_sample_orbit, mid_point_tcn_progress_bar)
|
545 |
+
with container2:
|
546 |
+
st.table(pd.DataFrame({'Model':"Midpoint + TCN", '30000 Time Steps (s)': [md_tcn_30000_time], '2000 Time Steps (s)': [md_tcn_2000_time], 'Total Time (s)': [md_tcn_total_time]}))
|
547 |
+
|
prediction.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from sklearn.preprocessing import MinMaxScaler
|
6 |
+
import torch
|
7 |
+
import time
|
8 |
+
|
9 |
+
from utils.transform import compute_gradient
|
10 |
+
from model.lstm import LSTMModel
|
11 |
+
from model.tcn import TCNModel
|
12 |
+
from model.tcn import move_custom_layers_to_device
|
13 |
+
from utils.metrics import calculate_metrics
|
14 |
+
|
15 |
+
each_feature_name = ["q_1","q_2","q_3","p_1","p_2","p_3"]
|
16 |
+
|
17 |
+
def uniform_sampling(data, n_sample):
|
18 |
+
k = len(data) // n_sample
|
19 |
+
return data[::k]
|
20 |
+
|
21 |
+
st.set_page_config(page_title="Prediction", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
|
22 |
+
|
23 |
+
#st.title("Prediction")
|
24 |
+
|
25 |
+
with st.sidebar:
|
26 |
+
slider_predict_step = st.slider('Predicted Step', 0, 20, 20)
|
27 |
+
|
28 |
+
number_input_sample_id = st.number_input("Select Sample ID 1~10", value=1, placeholder="Type a number...", min_value=1, max_value=10, step=1)
|
29 |
+
|
30 |
+
squences_start_idx = st.slider('Squences Start Index', 0, 700 - slider_predict_step, 0)
|
31 |
+
|
32 |
+
st.subheader("Model Configuration")
|
33 |
+
st.write("LSTM Window Size: ", 200)
|
34 |
+
st.write("TCN Window Size: ", 300)
|
35 |
+
st.write("Predicted Step: ", slider_predict_step)
|
36 |
+
st.write("Feature Augmentation: ", "Second-order derivative")
|
37 |
+
|
38 |
+
file_path = os.path.join("data", "file"+str(number_input_sample_id)+".dat.npz")
|
39 |
+
data = pd.DataFrame(np.load(file_path)['data'])
|
40 |
+
scaler = MinMaxScaler()
|
41 |
+
uniform_data = uniform_sampling(data, n_sample=1000).sort_index().values[:, 1:8]
|
42 |
+
normal_uniform_data = scaler.fit_transform(uniform_data)
|
43 |
+
data_sequences = torch.tensor(np.stack(normal_uniform_data)).float()
|
44 |
+
original_data_sequences = torch.tensor(np.stack(uniform_data)).float()
|
45 |
+
selected_data = data_sequences[squences_start_idx:squences_start_idx+300+slider_predict_step]
|
46 |
+
original_selected_data = original_data_sequences[squences_start_idx:squences_start_idx+300+slider_predict_step]
|
47 |
+
input_data = torch.stack([compute_gradient(i, degree=2) for i in selected_data]).unsqueeze(0)
|
48 |
+
|
49 |
+
with st.sidebar:
|
50 |
+
st.subheader("Data Configuration")
|
51 |
+
st.write("Sample ID: ", number_input_sample_id)
|
52 |
+
#st.write("Origianl Shape: ", data_sequences.shape)
|
53 |
+
st.write("Squences Start Index: ", squences_start_idx)
|
54 |
+
#st.write("Selected Shape: ", selected_data.shape)
|
55 |
+
#st.write("Input Shape: ", input_data.shape)
|
56 |
+
|
57 |
+
# st.write(selected_data[0])
|
58 |
+
# st.write(input_data[0][0])
|
59 |
+
|
60 |
+
#################################################
|
61 |
+
## LSTM GPU Inference
|
62 |
+
#################################################
|
63 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
64 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
65 |
+
lstm_model.eval()
|
66 |
+
lstm_start_time = time.time()
|
67 |
+
with torch.no_grad():
|
68 |
+
lstm_preds = lstm_model(input_data[:, 100:300, :].cuda())
|
69 |
+
lstm_end_time = time.time()
|
70 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
71 |
+
lstm_normal_preds = lstm_preds.squeeze().cpu().numpy()
|
72 |
+
|
73 |
+
#lstm_model.to_onnx("model/lstm.onnx", torch.randn((1, 200, 21)), export_params=True)
|
74 |
+
|
75 |
+
del lstm_model
|
76 |
+
|
77 |
+
#################################################
|
78 |
+
## LSTM CPU Inference
|
79 |
+
#################################################
|
80 |
+
lstm_cpu_ckpt_file = os.path.join("model", "lstm.ckpt")
|
81 |
+
lstm_cpu_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
82 |
+
lstm_cpu_model.to("cpu")
|
83 |
+
lstm_cpu_model.eval()
|
84 |
+
lstm_cpu_start_time = time.time()
|
85 |
+
with torch.no_grad():
|
86 |
+
lstm_cpu_preds = lstm_cpu_model(input_data[:, 100:300, :])
|
87 |
+
lstm_cpu_end_time = time.time()
|
88 |
+
|
89 |
+
del lstm_cpu_model
|
90 |
+
|
91 |
+
#################################################
|
92 |
+
## TCN GPU Inference
|
93 |
+
#################################################
|
94 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
95 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
96 |
+
tcn_model.eval()
|
97 |
+
tcn_start_time = time.time()
|
98 |
+
with torch.no_grad():
|
99 |
+
input_data_cuda = input_data[:,:300,:].cuda()
|
100 |
+
y_hat = tcn_model(input_data_cuda)
|
101 |
+
for i in range(1, slider_predict_step):
|
102 |
+
gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).cuda()
|
103 |
+
output = tcn_model(torch.cat([input_data[:, i:300, :].cuda(), gd_y_hat], dim=1)).cuda()
|
104 |
+
y_hat = torch.cat([y_hat, output], dim=1)
|
105 |
+
tcn_end_time = time.time()
|
106 |
+
tcn_preds = y_hat
|
107 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
108 |
+
tcn_normal_preds = tcn_preds.squeeze().cpu().numpy()
|
109 |
+
|
110 |
+
#tcn_model.to_onnx("model/tcn.onnx", torch.randn((1, 300, 21)), export_params=True)
|
111 |
+
|
112 |
+
del tcn_model
|
113 |
+
del y_hat, gd_y_hat, output
|
114 |
+
|
115 |
+
#################################################
|
116 |
+
## TCN CPU Inference
|
117 |
+
#################################################
|
118 |
+
input_data_cpu = input_data.to("cpu")
|
119 |
+
tcn_cpu_ckpt_file = os.path.join("model", "tcn.ckpt")
|
120 |
+
tcn_cpu_model = TCNModel.load_from_checkpoint(tcn_cpu_ckpt_file)
|
121 |
+
move_custom_layers_to_device(tcn_cpu_model, "cpu")
|
122 |
+
tcn_cpu_model.eval()
|
123 |
+
tcn_cpu_start_time = time.time()
|
124 |
+
with torch.no_grad():
|
125 |
+
y_hat = None
|
126 |
+
for i in range(slider_predict_step):
|
127 |
+
if i == 0:
|
128 |
+
y_hat = tcn_cpu_model(input_data_cpu[:,:300,:])
|
129 |
+
else:
|
130 |
+
gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).to('cpu')
|
131 |
+
output = tcn_cpu_model(torch.concatenate([input_data_cpu[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
132 |
+
y_hat = torch.concatenate([y_hat, output], dim=1)
|
133 |
+
tcn_cpu_preds = y_hat
|
134 |
+
tcn_cpu_end_time = time.time()
|
135 |
+
|
136 |
+
del tcn_cpu_model
|
137 |
+
|
138 |
+
st.subheader("Normalized Prediction")
|
139 |
+
|
140 |
+
i = 1
|
141 |
+
for each_col in st.columns(6):
|
142 |
+
with each_col:
|
143 |
+
raw_data = selected_data[:, i]
|
144 |
+
lstm_data = [np.nan] * 300 + lstm_normal_preds[:slider_predict_step, :][:, i].tolist()
|
145 |
+
tcn_data = [np.nan] * 300 + tcn_normal_preds[:, i].tolist()
|
146 |
+
st.markdown(f"<div style='text-align: center'>{each_feature_name[i-1]}</div>", unsafe_allow_html=True)
|
147 |
+
#st.write(np.array(raw_data).shape, np.array(lstm_data).shape, np.array(tcn_data).shape)
|
148 |
+
st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
|
149 |
+
color=["#EE4035", "#0077BB", "#7BC043"])
|
150 |
+
i += 1
|
151 |
+
|
152 |
+
# with st.sidebar:
|
153 |
+
# st.write("Predicted Shape: ", lstm_preds.shape)
|
154 |
+
|
155 |
+
st.subheader("Inverse Normalized Prediction")
|
156 |
+
|
157 |
+
i = 1
|
158 |
+
for each_col in st.columns(6):
|
159 |
+
with each_col:
|
160 |
+
raw_data = original_selected_data[:, i]
|
161 |
+
lstm_data = [np.nan] * 300 + lstm_innv_preds[:slider_predict_step, :][:, i].tolist()
|
162 |
+
tcn_data = [np.nan] * 300 + tcn_innv_preds[:, i].tolist()
|
163 |
+
st.markdown(f"<div style='text-align: center'>{each_feature_name[i - 1]}</div>", unsafe_allow_html=True)
|
164 |
+
st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
|
165 |
+
color=["#EE4035", "#0077BB", "#7BC043"])
|
166 |
+
i += 1
|
167 |
+
|
168 |
+
LSTM_SMAPE, LSTM_MSE, LSTM_RMSE, LSTM_MAE, LSTM_R2, LSTM_PSD = calculate_metrics(selected_data[300:300+slider_predict_step, :].cpu().numpy(), lstm_normal_preds[:slider_predict_step, :])
|
169 |
+
TCN_SMAPE, TCN_MSE, TCN_RMSE, TCN_MAE, TCN_R2, TCN_PSD = calculate_metrics(selected_data[300:300+slider_predict_step, :].cpu().numpy(), tcn_normal_preds)
|
170 |
+
|
171 |
+
results_df = pd.DataFrame({
|
172 |
+
"Model": ["LSTM", "TCN"],
|
173 |
+
"SMAPE": [LSTM_SMAPE, TCN_SMAPE],
|
174 |
+
"MSE": [LSTM_MSE, TCN_MSE],
|
175 |
+
"RMSE": [LSTM_RMSE, TCN_RMSE],
|
176 |
+
"MAE": [LSTM_MAE, TCN_MAE],
|
177 |
+
"R2": [LSTM_R2, TCN_R2],
|
178 |
+
"PSD": [LSTM_PSD, TCN_PSD]
|
179 |
+
})
|
180 |
+
|
181 |
+
time_df = pd.DataFrame({
|
182 |
+
"Model": ["LSTM-GPU", "TCN-GPU", "LSTM-CPU", "TCN-CPU"],
|
183 |
+
"Time(ms)": [(lstm_end_time - lstm_start_time)*1000,
|
184 |
+
(tcn_end_time - tcn_start_time)*1000,
|
185 |
+
(lstm_cpu_end_time - lstm_cpu_start_time)*1000,
|
186 |
+
(tcn_cpu_end_time - tcn_cpu_start_time)*1000]
|
187 |
+
})
|
188 |
+
|
189 |
+
col1, col2 = st.columns(2)
|
190 |
+
with col1:
|
191 |
+
st.subheader("Evaluation Metrics")
|
192 |
+
st.write(results_df)
|
193 |
+
|
194 |
+
with col2:
|
195 |
+
st.subheader("Prediction Time")
|
196 |
+
st.write(time_df)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas==1.5.3
|
2 |
+
scikit-learn==1.3.2
|
3 |
+
torch==2.1.1
|
4 |
+
streamlit==1.32.2
|
5 |
+
keras==3.0.0
|
6 |
+
torchvision==0.16.1
|
7 |
+
pytorch-lightning==2.1.2
|
utils/__pycache__/highlevel.cpython-310.pyc
ADDED
Binary file (4.08 kB). View file
|
|
utils/__pycache__/lowlevel.cpython-310.pyc
ADDED
Binary file (6.04 kB). View file
|
|
utils/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (2.22 kB). View file
|
|
utils/__pycache__/midpoint.cpython-310.pyc
ADDED
Binary file (6.02 kB). View file
|
|
utils/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (448 Bytes). View file
|
|
utils/highlevel.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
|
3 |
+
class HighLevel():
|
4 |
+
|
5 |
+
def __init__(self, j):
|
6 |
+
self.j = j
|
7 |
+
|
8 |
+
def initial(self):
|
9 |
+
j = self.j
|
10 |
+
|
11 |
+
#init parameters
|
12 |
+
h, b, n, u = 0.1, None, None, None
|
13 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
14 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
15 |
+
|
16 |
+
if j == 1:
|
17 |
+
b = 5.0 / 4
|
18 |
+
n = b / (1 + b) ** 2
|
19 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
20 |
+
x = 10.0
|
21 |
+
py = 0.5
|
22 |
+
elif j == 2:
|
23 |
+
b = 3.0 / 4
|
24 |
+
n = b / (1 + b) ** 2
|
25 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
26 |
+
x = 8.3
|
27 |
+
py = 0.6
|
28 |
+
elif j == 3:
|
29 |
+
b = 3.0 / 2
|
30 |
+
x = 12.0
|
31 |
+
py = 0.4
|
32 |
+
elif j == 4:
|
33 |
+
b = 7.0 / 4
|
34 |
+
n = b / (1 + b) ** 2
|
35 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
36 |
+
x = 15.0
|
37 |
+
py = 0.35
|
38 |
+
elif j == 5:
|
39 |
+
b = 1.0
|
40 |
+
n = b / (1 + b) ** 2
|
41 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
42 |
+
x = 18.0
|
43 |
+
py = 0.3
|
44 |
+
elif j == 6:
|
45 |
+
b = 3.0 / 5
|
46 |
+
n = b / (1 + b) ** 2
|
47 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
48 |
+
x = 20.0
|
49 |
+
py = 0.25
|
50 |
+
elif j == 7:
|
51 |
+
b = 5.0 / 7
|
52 |
+
n = b / (1 + b) ** 2
|
53 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
54 |
+
x = 22.0
|
55 |
+
py = 0.22
|
56 |
+
elif j == 8:
|
57 |
+
b = 2.0
|
58 |
+
x = 26.0
|
59 |
+
py = 0.2
|
60 |
+
elif j == 9:
|
61 |
+
b = 0.5
|
62 |
+
n = b / (1 + b) ** 2
|
63 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
64 |
+
x = 30.0
|
65 |
+
y = 0.5
|
66 |
+
z = 0.1
|
67 |
+
pz = 0.01
|
68 |
+
elif j == 10:
|
69 |
+
b = 5.0
|
70 |
+
n = b / (1 + b) ** 2
|
71 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
72 |
+
x = 35.0
|
73 |
+
y = 2.0
|
74 |
+
z = 0.1
|
75 |
+
pz = 0.03
|
76 |
+
py = 0.15
|
77 |
+
|
78 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
79 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
80 |
+
|
81 |
+
def f(self, x, y, z, px, py, pz, b):
|
82 |
+
x_val, y_val, z_val, px_val, py_val, pz_val, b_val = x, y, z, px, py, pz, b
|
83 |
+
x, y, z, px, py, pz, b = sp.symbols('x y z px py pz b')
|
84 |
+
|
85 |
+
c = 1.0
|
86 |
+
|
87 |
+
u = 1 / (1 / b + b + 2)
|
88 |
+
ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
|
89 |
+
hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
90 |
+
h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
|
91 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
92 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
|
93 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
|
94 |
+
px ** 2 + py ** 2 + pz ** 2) ** 2
|
95 |
+
|
96 |
+
e = ht + hv + h1pn
|
97 |
+
|
98 |
+
de_dx = sp.diff(e, x)
|
99 |
+
de_dy = sp.diff(e, y)
|
100 |
+
de_dz = sp.diff(e, z)
|
101 |
+
de_dpx = sp.diff(e, px)
|
102 |
+
de_dpy = sp.diff(e, py)
|
103 |
+
de_dpz = sp.diff(e, pz)
|
104 |
+
|
105 |
+
de_dx_val = de_dx.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
106 |
+
de_dy_val = de_dy.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
107 |
+
de_dz_val = de_dz.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
108 |
+
de_dpx_val = de_dpx.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
109 |
+
de_dpy_val = de_dpy.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
110 |
+
de_dpz_val = de_dpz.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
111 |
+
|
112 |
+
e_val = e.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
113 |
+
|
114 |
+
return de_dx_val, de_dy_val, de_dz_val, de_dpx_val, de_dpy_val, de_dpz_val, e_val
|
115 |
+
|
116 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
117 |
+
|
118 |
+
x = (x + xa) / 2
|
119 |
+
y = (y + ya) / 2
|
120 |
+
z = (z + za) / 2
|
121 |
+
|
122 |
+
px = (px + pxa) / 2
|
123 |
+
py = (py + pya) / 2
|
124 |
+
pz = (pz + pza) / 2
|
125 |
+
xa = x
|
126 |
+
ya = y
|
127 |
+
za = z
|
128 |
+
pxa = px
|
129 |
+
pya = py
|
130 |
+
pza = pz
|
131 |
+
|
132 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
133 |
+
|
134 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
135 |
+
|
136 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
137 |
+
x = x + h / 2 * vpx
|
138 |
+
y = y + h / 2 * vpy
|
139 |
+
z = z + h / 2 * vpz
|
140 |
+
pxa = pxa - h / 2 * vxa
|
141 |
+
pya = pya - h / 2 * vya
|
142 |
+
pza = pza - h / 2 * vza
|
143 |
+
|
144 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
145 |
+
xa = xa + h * vpxa
|
146 |
+
ya = ya + h * vpya
|
147 |
+
za = za + h * vpza
|
148 |
+
px = px - h * vx
|
149 |
+
py = py - h * vy
|
150 |
+
pz = pz - h * vz
|
151 |
+
|
152 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
153 |
+
x = x + h / 2 * vpx
|
154 |
+
y = y + h / 2 * vpy
|
155 |
+
z = z + h / 2 * vpz
|
156 |
+
pxa = pxa - h / 2 * vxa
|
157 |
+
pya = pya - h / 2 * vya
|
158 |
+
pza = pza - h / 2 * vza
|
159 |
+
|
160 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/lowlevel.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class LowLevel():
|
5 |
+
|
6 |
+
def __init__(self, j):
|
7 |
+
self.j = j
|
8 |
+
|
9 |
+
def initial(self):
|
10 |
+
j = self.j
|
11 |
+
|
12 |
+
#init parameters
|
13 |
+
h, b, n, u = 0.1, None, None, None
|
14 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
15 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
16 |
+
|
17 |
+
if j == 1:
|
18 |
+
b = 5.0 / 4
|
19 |
+
n = b / (1 + b) ** 2
|
20 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
21 |
+
x = 10.0
|
22 |
+
py = 0.5
|
23 |
+
elif j == 2:
|
24 |
+
b = 3.0 / 4
|
25 |
+
n = b / (1 + b) ** 2
|
26 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
27 |
+
x = 8.3
|
28 |
+
py = 0.6
|
29 |
+
elif j == 3:
|
30 |
+
b = 3.0 / 2
|
31 |
+
x = 12.0
|
32 |
+
py = 0.4
|
33 |
+
elif j == 4:
|
34 |
+
b = 7.0 / 4
|
35 |
+
n = b / (1 + b) ** 2
|
36 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
37 |
+
x = 15.0
|
38 |
+
py = 0.35
|
39 |
+
elif j == 5:
|
40 |
+
b = 1.0
|
41 |
+
n = b / (1 + b) ** 2
|
42 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
43 |
+
x = 18.0
|
44 |
+
py = 0.3
|
45 |
+
elif j == 6:
|
46 |
+
b = 3.0 / 5
|
47 |
+
n = b / (1 + b) ** 2
|
48 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
49 |
+
x = 20.0
|
50 |
+
py = 0.25
|
51 |
+
elif j == 7:
|
52 |
+
b = 5.0 / 7
|
53 |
+
n = b / (1 + b) ** 2
|
54 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
55 |
+
x = 22.0
|
56 |
+
py = 0.22
|
57 |
+
elif j == 8:
|
58 |
+
b = 2.0
|
59 |
+
x = 26.0
|
60 |
+
py = 0.2
|
61 |
+
elif j == 9:
|
62 |
+
b = 0.5
|
63 |
+
n = b / (1 + b) ** 2
|
64 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
65 |
+
x = 30.0
|
66 |
+
y = 0.5
|
67 |
+
z = 0.1
|
68 |
+
pz = 0.01
|
69 |
+
elif j == 10:
|
70 |
+
b = 5.0
|
71 |
+
n = b / (1 + b) ** 2
|
72 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
73 |
+
x = 35.0
|
74 |
+
y = 2.0
|
75 |
+
z = 0.1
|
76 |
+
pz = 0.03
|
77 |
+
py = 0.15
|
78 |
+
|
79 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
80 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
81 |
+
|
82 |
+
def f(self, x, y, z, px, py, pz, b):
|
83 |
+
n = b / (1 + b)**2
|
84 |
+
u = 1 / (1 / b + b + 2)
|
85 |
+
ht = px**2 / 2 + py**2 / 2 + pz**2 / 2
|
86 |
+
hv = -1 / (x**2 + y**2 + z**2)**(1/2)
|
87 |
+
h1pn = 1/(2*x**2 + 2*y**2 + 2*z**2) - (((u + 3)*(px**2 + py**2 +pz**2))/2 + (u*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/ (x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2)/(x**2 + y**2 + z**2)**(1/2) + ((3*u)/8 -1/8)*(px**2 + py**2 + pz**2)**2
|
88 |
+
|
89 |
+
e = ht + hv + h1pn
|
90 |
+
|
91 |
+
vnpx=px
|
92 |
+
v1pnpx=4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
93 |
+
vpx=vnpx+v1pnpx
|
94 |
+
|
95 |
+
vnpy=py
|
96 |
+
v1pnpy=4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
97 |
+
vpy=vnpy+v1pnpy
|
98 |
+
|
99 |
+
vnpz=pz
|
100 |
+
v1pnpz=4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
101 |
+
vpz=vnpz+v1pnpz
|
102 |
+
|
103 |
+
vnx=x/(x**2 + y**2 + z**2)**(3/2)
|
104 |
+
v1pnx=(x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) -px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
105 |
+
vx=vnx+v1pnx
|
106 |
+
|
107 |
+
vny=y/(x**2 + y**2 + z**2)**(3/2)
|
108 |
+
v1pny=(y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) -py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
109 |
+
vy=vny+v1pny
|
110 |
+
|
111 |
+
vnz=z/(x**2 + y**2 + z**2)**(3/2)
|
112 |
+
v1pnz=(z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) -pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
113 |
+
vz=vnz+v1pnz
|
114 |
+
|
115 |
+
return vx,vy,vz,vpx,vpy,vpz,e
|
116 |
+
|
117 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
118 |
+
x = (x + xa) / 2
|
119 |
+
y = (y + ya) / 2
|
120 |
+
z = (z + za) / 2
|
121 |
+
|
122 |
+
px = (px + pxa) / 2
|
123 |
+
py = (py + pya) / 2
|
124 |
+
pz = (pz + pza) / 2
|
125 |
+
xa = x
|
126 |
+
ya = y
|
127 |
+
za = z
|
128 |
+
pxa = px
|
129 |
+
pya = py
|
130 |
+
pza = pz
|
131 |
+
return x,y,z,px,py,pz,xa,ya,za,pxa,pya,pza
|
132 |
+
|
133 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
134 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
135 |
+
x = x + h / 2 * vpx
|
136 |
+
y = y + h / 2 * vpy
|
137 |
+
z = z + h / 2 * vpz
|
138 |
+
pxa = pxa - h / 2 * vxa
|
139 |
+
pya = pya - h / 2 * vya
|
140 |
+
pza = pza - h / 2 * vza
|
141 |
+
|
142 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
143 |
+
xa = xa + h * vpxa
|
144 |
+
ya = ya + h * vpya
|
145 |
+
za = za + h * vpza
|
146 |
+
px = px - h * vx
|
147 |
+
py = py - h * vy
|
148 |
+
pz = pz - h * vz
|
149 |
+
|
150 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
151 |
+
x = x + h / 2 * vpx
|
152 |
+
y = y + h / 2 * vpy
|
153 |
+
z = z + h / 2 * vpz
|
154 |
+
pxa = pxa - h / 2 * vxa
|
155 |
+
pya = pya - h / 2 * vya
|
156 |
+
pza = pza - h / 2 * vza
|
157 |
+
|
158 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/metrics.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
3 |
+
|
4 |
+
def calculate_metrics(y, y_hat, y_train=None):
|
5 |
+
def smape(a, f):
|
6 |
+
return 1/len(a) * np.sum(2 * np.abs(f - a) / (np.abs(a) + np.abs(f) + np.finfo(float).eps))
|
7 |
+
|
8 |
+
def mase(y_actual, y_pred, y_train):
|
9 |
+
n = y_train.shape[1]
|
10 |
+
d = np.abs(np.diff(y_train)).sum() / (n - 1)
|
11 |
+
errors = np.abs(y_actual - y_pred)
|
12 |
+
return errors.mean() / d
|
13 |
+
|
14 |
+
def phase_space_distance(y_actual, y_pred):
|
15 |
+
return np.sqrt(np.sum(np.square(y_actual - y_pred)))
|
16 |
+
|
17 |
+
SMAPE = np.mean([smape(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
18 |
+
MSE = np.mean([mean_squared_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
19 |
+
RMSE = np.mean([np.sqrt(mean_squared_error(yi.reshape(-1), y_hati.reshape(-1))) for yi, y_hati in zip(y, y_hat)])
|
20 |
+
MAE = np.mean([mean_absolute_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
21 |
+
R2 = np.mean([r2_score(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
22 |
+
PSD = np.mean([phase_space_distance(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
23 |
+
|
24 |
+
if y_train is None:
|
25 |
+
return SMAPE, MSE, RMSE, MAE, R2, PSD
|
26 |
+
else:
|
27 |
+
MASE = np.mean([mase(yi, y_hati, yt) for yi, y_hati, yt in zip(y, y_hat, y_train)])
|
28 |
+
return SMAPE, MSE, RMSE, MAE, R2, MASE, PSD
|
utils/midpoint.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class MidPoint():
|
3 |
+
|
4 |
+
def __init__(self, j):
|
5 |
+
self.j = j
|
6 |
+
|
7 |
+
def initial(self):
|
8 |
+
j = self.j
|
9 |
+
|
10 |
+
#init parameters
|
11 |
+
h, b, n, u = 0.1, None, None, None
|
12 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
13 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
14 |
+
|
15 |
+
if j == 1:
|
16 |
+
b = 5.0 / 4
|
17 |
+
n = b / (1 + b) ** 2
|
18 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
19 |
+
x = 10.0
|
20 |
+
py = 0.5
|
21 |
+
elif j == 2:
|
22 |
+
b = 3.0 / 4
|
23 |
+
n = b / (1 + b) ** 2
|
24 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
25 |
+
x = 8.3
|
26 |
+
py = 0.6
|
27 |
+
elif j == 3:
|
28 |
+
b = 3.0 / 2
|
29 |
+
x = 12.0
|
30 |
+
py = 0.4
|
31 |
+
elif j == 4:
|
32 |
+
b = 7.0 / 4
|
33 |
+
n = b / (1 + b) ** 2
|
34 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
35 |
+
x = 15.0
|
36 |
+
py = 0.35
|
37 |
+
elif j == 5:
|
38 |
+
b = 1.0
|
39 |
+
n = b / (1 + b) ** 2
|
40 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
41 |
+
x = 18.0
|
42 |
+
py = 0.3
|
43 |
+
elif j == 6:
|
44 |
+
b = 3.0 / 5
|
45 |
+
n = b / (1 + b) ** 2
|
46 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
47 |
+
x = 20.0
|
48 |
+
py = 0.25
|
49 |
+
elif j == 7:
|
50 |
+
b = 5.0 / 7
|
51 |
+
n = b / (1 + b) ** 2
|
52 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
53 |
+
x = 22.0
|
54 |
+
py = 0.22
|
55 |
+
elif j == 8:
|
56 |
+
b = 2.0
|
57 |
+
x = 26.0
|
58 |
+
py = 0.2
|
59 |
+
elif j == 9:
|
60 |
+
b = 0.5
|
61 |
+
n = b / (1 + b) ** 2
|
62 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
63 |
+
x = 30.0
|
64 |
+
y = 0.5
|
65 |
+
z = 0.1
|
66 |
+
pz = 0.01
|
67 |
+
elif j == 10:
|
68 |
+
b = 5.0
|
69 |
+
n = b / (1 + b) ** 2
|
70 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
71 |
+
x = 35.0
|
72 |
+
y = 2.0
|
73 |
+
z = 0.1
|
74 |
+
pz = 0.03
|
75 |
+
py = 0.15
|
76 |
+
|
77 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
78 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
79 |
+
|
80 |
+
def f(self, x, y, z, px, py, pz, b):
|
81 |
+
|
82 |
+
n = b / (1 + b) ** 2
|
83 |
+
u = 1 / (1 / b + b + 2)
|
84 |
+
ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
|
85 |
+
hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
86 |
+
h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
|
87 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
88 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
|
89 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
|
90 |
+
px ** 2 + py ** 2 + pz ** 2) ** 2
|
91 |
+
|
92 |
+
e = ht + hv + h1pn
|
93 |
+
|
94 |
+
vnpx=px
|
95 |
+
v1pnpx=4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
96 |
+
vpx=vnpx+v1pnpx
|
97 |
+
|
98 |
+
vnpy=py
|
99 |
+
v1pnpy=4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
100 |
+
vpy=vnpy+v1pnpy
|
101 |
+
|
102 |
+
vnpz=pz
|
103 |
+
v1pnpz=4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
104 |
+
vpz=vnpz+v1pnpz
|
105 |
+
|
106 |
+
vnx=x/(x**2 + y**2 + z**2)**(3/2)
|
107 |
+
v1pnx=(x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) -px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
108 |
+
vx=vnx+v1pnx
|
109 |
+
|
110 |
+
vny=y/(x**2 + y**2 + z**2)**(3/2)
|
111 |
+
v1pny=(y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) -py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
112 |
+
vy=vny+v1pny
|
113 |
+
|
114 |
+
vnz=z/(x**2 + y**2 + z**2)**(3/2)
|
115 |
+
v1pnz=(z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) -pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
116 |
+
vz=vnz+v1pnz
|
117 |
+
|
118 |
+
return vx,vy,vz,vpx,vpy,vpz,e
|
119 |
+
|
120 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
121 |
+
|
122 |
+
x = (x + xa) / 2
|
123 |
+
y = (y + ya) / 2
|
124 |
+
z = (z + za) / 2
|
125 |
+
|
126 |
+
px = (px + pxa) / 2
|
127 |
+
py = (py + pya) / 2
|
128 |
+
pz = (pz + pza) / 2
|
129 |
+
xa = x
|
130 |
+
ya = y
|
131 |
+
za = z
|
132 |
+
pxa = px
|
133 |
+
pya = py
|
134 |
+
pza = pz
|
135 |
+
|
136 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
137 |
+
|
138 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
139 |
+
|
140 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
141 |
+
x = x + h / 2 * vpx
|
142 |
+
y = y + h / 2 * vpy
|
143 |
+
z = z + h / 2 * vpz
|
144 |
+
pxa = pxa - h / 2 * vxa
|
145 |
+
pya = pya - h / 2 * vya
|
146 |
+
pza = pza - h / 2 * vza
|
147 |
+
|
148 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
149 |
+
xa = xa + h * vpxa
|
150 |
+
ya = ya + h * vpya
|
151 |
+
za = za + h * vpza
|
152 |
+
px = px - h * vx
|
153 |
+
py = py - h * vy
|
154 |
+
pz = pz - h * vz
|
155 |
+
|
156 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
157 |
+
x = x + h / 2 * vpx
|
158 |
+
y = y + h / 2 * vpy
|
159 |
+
z = z + h / 2 * vpz
|
160 |
+
pxa = pxa - h / 2 * vxa
|
161 |
+
pya = pya - h / 2 * vya
|
162 |
+
pza = pza - h / 2 * vza
|
163 |
+
|
164 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/transform.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def compute_gradient(x, degree):
|
4 |
+
gradients = [x]
|
5 |
+
for i in range(degree):
|
6 |
+
x = torch.diff(x, dim=-1, prepend=x[..., 0:1])
|
7 |
+
gradients.append(x)
|
8 |
+
return torch.concatenate(gradients, dim=-1)
|