yan123yan committed
Commit 47fe089 · 1 Parent(s): 8cf3e4f

first version
data/file1.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90713e007e25e2b4467711981274dec5f15548666bf7867a30cdf5e189b94b80
+ size 7200262
data/file10.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:123c044fc18c7e1b0671261fc081aa7e9a60eab726f35b2758b75f3e70ebe76e
+ size 7200262
data/file2.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ebdbdcb3465e49f16073eb828dc18163a12e8b5968cb7f4661e3931c13ac0cea
+ size 7200262
data/file3.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b14cf26774c287dd0fb2351a6fc6ce3a2eae135a57c265a93c03f017d479d3a
+ size 7200262
data/file4.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7810f50b9c30a19fbfcb3db77bf838b02df4fc54f6caa63e8dd3ca0f2abba6c1
+ size 7200262
data/file5.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34f3485194b57325b0f32554c4f14218389a6d47d5b6edb7e37626fc48d6aae8
+ size 7200262
data/file6.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d59084be5aff0eb578c7ff5ee62027ef853d1f5d8d2794d6230d5b706aa0f6aa
+ size 7200262
data/file7.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9035c00bb34d17940b033e3bae40097296c493a4852630fbcf963c6f80391f5a
+ size 7200262
data/file8.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:955406a5ec5af34373c4ae8573195ca580bfcfbbef13efdcb410fc05d58d66f8
+ size 7200262
data/file9.dat.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ea42398e0062a3280981795ea8fd3a49bd5e104de178decfad9e69e41c0de8c
+ size 7200262
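Note: the ten data files above are Git LFS pointer files — only the version/oid/size triplet is versioned here, and the ~7.2 MB .npz payloads live in LFS storage. A minimal sketch of how one of them is read once the real payload is checked out; the 'data' key matches the access pattern used in prediction.py below:

    import numpy as np

    # Load a resolved LFS data file; prediction.py reads the 'data' array from each .npz.
    arr = np.load("data/file1.dat.npz")["data"]
    print(arr.shape)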
model/__pycache__/lstm.cpython-310.pyc ADDED
Binary file (1.2 kB).
model/__pycache__/tcn.cpython-310.pyc ADDED
Binary file (1.59 kB).
model/__pycache__/tcn_module.cpython-310.pyc ADDED
Binary file (15.6 kB).
model/lstm.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1080dbdc37acdb1e9e6a29c140711908d2426b39735617eafbadd49fb5772ef4
+ size 7286190
model/lstm.py ADDED
@@ -0,0 +1,22 @@
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+
+ class LSTMModel(pl.LightningModule):
+     def __init__(self, **config):
+         super(LSTMModel, self).__init__()
+         self.save_hyperparameters(config)
+
+         self.lstm = nn.LSTM(input_size=21, hidden_size=512, num_layers=3, proj_size=21, batch_first=True)
+         self.linear = nn.Linear(in_features=21, out_features=7)
+
+     def forward(self, x):
+         outputs = []
+         hidden, cell = None, None
+         for i in range(20):
+             if i == 0:
+                 output, (hidden, cell) = self.lstm(x)
+             else:
+                 output, (hidden, cell) = self.lstm(output[:, -1, :].unsqueeze(1), (hidden, cell))
+             outputs.append(self.linear(output[:, -1, :]))
+         return torch.stack(outputs, dim=1)
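For orientation, a minimal usage sketch of LSTMModel as defined above. The shapes follow directly from the layer configuration (input_size=21, a 20-step autoregressive rollout, a 7-unit output head); the 200-step window matches how pages/inference.py slices its input:

    import torch

    model = LSTMModel()            # hyperparameters arrive via **config in the real app
    model.eval()
    x = torch.randn(1, 200, 21)    # (batch, window, features)
    with torch.no_grad():
        y = model(x)               # 20 autoregressive prediction steps
    print(y.shape)                 # torch.Size([1, 20, 7])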
model/tcn.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc975af786670412a2e419978636038452467676e1a9dac59d2ed77033f9f67b
+ size 43742454
model/tcn.py ADDED
@@ -0,0 +1,40 @@
+ import os
+ os.environ["KERAS_BACKEND"] = "torch"
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+
+ from keras.layers import Input, Dense
+ from keras.models import Model
+ from model.tcn_module import TCN
+
+
+ class TCNModel(pl.LightningModule):
+     def __init__(self, **config):
+         super(TCNModel, self).__init__()
+         self.save_hyperparameters(config)
+
+         input_layer = Input(shape=(self.hparams.windows_size, self.hparams.input_size))
+         self.tcn = TCN(input_shape=(self.hparams.windows_size, self.hparams.input_size))(input_layer)
+         self.linear = Dense(7)(self.tcn)
+         self.model = Model(inputs=input_layer, outputs=self.linear)
+
+     def forward(self, x):
+         output = self.model(x)
+         return torch.stack([output], dim=1)
+
+ def move_custom_layers_to_device(model, device):
+     for name, module in model.named_children():
+         # Standard layers are already handled by named_children
+         if isinstance(module, nn.Module):
+             continue
+
+         # Non-standard layers, e.g. layers held inside a list or dict
+         if isinstance(module, list):
+             for sub_module in module:
+                 if isinstance(sub_module, nn.Module):
+                     sub_module.to(device)
+         elif isinstance(module, dict):
+             for sub_module in module.values():
+                 if isinstance(sub_module, nn.Module):
+                     sub_module.to(device)
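A similar sketch for TCNModel, assuming the hyperparameter values used elsewhere in this commit (window size 300 and 21 input features, as in prediction.py). With the TCN defaults (return_sequences=False), one forward pass yields a single 7-dimensional prediction step:

    import torch

    model = TCNModel(windows_size=300, input_size=21)  # assumed config values
    move_custom_layers_to_device(model, "cpu")         # also move Keras-held submodules
    model.eval()
    x = torch.randn(1, 300, 21)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 1, 7])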
model/tcn_module.py ADDED
@@ -0,0 +1,511 @@
+ import inspect
+ from typing import List
+
+ import os
+ os.environ["KERAS_BACKEND"] = "torch"
+ import keras
+
+ # from keras_core import backend as K, Model, Input, optimizers
+ # from keras_core import backend as Model, Input, optimizers
+ # from keras_core import backend as K
+
+ from keras import Model
+ from keras import optimizers
+ from keras import ops as K
+ from keras import config as KK
+
+ from keras import layers
+ from keras.layers import Input, Layer, Conv1D, Dense, BatchNormalization, LayerNormalization, Activation, SpatialDropout1D, Lambda
+
+
+ def is_power_of_two(num: int):
+     return num != 0 and ((num & (num - 1)) == 0)
+
+
+ def adjust_dilations(dilations: list):
+     if all([is_power_of_two(i) for i in dilations]):
+         return dilations
+     else:
+         new_dilations = [2 ** i for i in dilations]
+         return new_dilations
+
+
+ class ResidualBlock(Layer):
+
+     def __init__(self,
+                  dilation_rate: int,
+                  nb_filters: int,
+                  kernel_size: int,
+                  padding: str,
+                  activation: str = 'relu',
+                  dropout_rate: float = 0,
+                  kernel_initializer: str = 'he_normal',
+                  use_batch_norm: bool = False,
+                  use_layer_norm: bool = False,
+                  use_weight_norm: bool = False,
+                  **kwargs):
+         """Defines the residual block for the WaveNet TCN
+         Args:
+             x: The previous layer in the model
+             training: boolean indicating whether the layer should behave in training mode or in inference mode
+             dilation_rate: The dilation power of 2 we are using for this residual block
+             nb_filters: The number of convolutional filters to use in this block
+             kernel_size: The size of the convolutional kernel
+             padding: The padding used in the convolutional layers, 'same' or 'causal'.
+             activation: The final activation used in o = Activation(x + F(x))
+             dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
+             kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
+             use_batch_norm: Whether to use batch normalization in the residual layers or not.
+             use_layer_norm: Whether to use layer normalization in the residual layers or not.
+             use_weight_norm: Whether to use weight normalization in the residual layers or not.
+             kwargs: Any initializers for Layer class.
+         """
+
+         self.dilation_rate = dilation_rate
+         self.nb_filters = nb_filters
+         self.kernel_size = kernel_size
+         self.padding = padding
+         self.activation = activation
+         self.dropout_rate = dropout_rate
+         self.use_batch_norm = use_batch_norm
+         self.use_layer_norm = use_layer_norm
+         self.use_weight_norm = use_weight_norm
+         self.kernel_initializer = kernel_initializer
+         self.layers = []
+         self.shape_match_conv = None
+         self.res_output_shape = None
+         self.final_activation = None
+
+         super(ResidualBlock, self).__init__(**kwargs)
+
+     def _build_layer(self, layer):
+         """Helper function for building a layer
+         Args:
+             layer: Appends layer to internal layer list and builds it based on the current output
+                    shape of ResidualBlock. Updates current output shape.
+         """
+         self.layers.append(layer)
+         self.layers[-1].build(self.res_output_shape)
+         self.res_output_shape = self.layers[-1].compute_output_shape(self.res_output_shape)
+
+     def build(self, input_shape):
+
+         # with K.name_scope(self.name):  # name scope used to make sure weights get unique names
+         self.layers = []
+         self.res_output_shape = input_shape
+
+         for k in range(2):  # dilated conv block.
+             name = 'conv1D_{}'.format(k)
+             # with K.name_scope(name):  # name scope used to make sure weights get unique names
+             conv = Conv1D(
+                 filters=self.nb_filters,
+                 kernel_size=self.kernel_size,
+                 dilation_rate=self.dilation_rate,
+                 padding=self.padding,
+                 name=name,
+                 kernel_initializer=self.kernel_initializer
+             )
+             if self.use_weight_norm:
+                 from tensorflow_addons.layers import WeightNormalization
+                 # wrap it. WeightNormalization API is different than BatchNormalization or LayerNormalization.
+                 # with K.name_scope('norm_{}'.format(k)):
+                 conv = WeightNormalization(conv)
+             self._build_layer(conv)
+
+             # with K.name_scope('norm_{}'.format(k)):
+             if self.use_batch_norm:
+                 self._build_layer(BatchNormalization())
+             elif self.use_layer_norm:
+                 self._build_layer(LayerNormalization())
+             elif self.use_weight_norm:
+                 pass  # done above.
+
+             # with K.name_scope('act_and_dropout_{}'.format(k)):
+             self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
+             self._build_layer(SpatialDropout1D(rate=self.dropout_rate, name='SDropout_{}'.format(k)))
+
+         if self.nb_filters != input_shape[-1]:
+             # 1x1 conv to match the shapes (channel dimension).
+             name = 'matching_conv1D'
+             # with K.name_scope(name):
+             # make and build this layer separately because it directly uses input_shape.
+             # 1x1 conv.
+             self.shape_match_conv = Conv1D(
+                 filters=self.nb_filters,
+                 kernel_size=1,
+                 padding='same',
+                 name=name,
+                 kernel_initializer=self.kernel_initializer
+             )
+         else:
+             name = 'matching_identity'
+             self.shape_match_conv = Lambda(lambda x: x, name=name)
+
+         # with K.name_scope(name):
+         self.shape_match_conv.build(input_shape)
+         self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
+
+         self._build_layer(Activation(self.activation, name='Act_Conv_Blocks'))
+         self.final_activation = Activation(self.activation, name='Act_Res_Block')
+         self.final_activation.build(self.res_output_shape)  # probably isn't necessary
+
+         # this is done to force Keras to add the layers in the list to self._layers
+         for layer in self.layers:
+             self.__setattr__(layer.name, layer)
+         self.__setattr__(self.shape_match_conv.name, self.shape_match_conv)
+         self.__setattr__(self.final_activation.name, self.final_activation)
+
+         super(ResidualBlock, self).build(input_shape)  # done to make sure self.built is set True
+
+     def call(self, inputs, training=None, **kwargs):
+         """
+         Returns: A tuple where the first element is the residual model tensor, and the second
+                  is the skip connection tensor.
+         """
+         # https://arxiv.org/pdf/1803.01271.pdf page 4, Figure 1 (b).
+         # x1: Dilated Conv -> Norm -> Dropout (x2).
+         # x2: Residual (1x1 matching conv - optional).
+         # Output: x1 + x2.
+         # x1 -> connected to skip connections.
+         # x1 + x2 -> connected to the next block.
+         #        input
+         #     x1       x2
+         #   conv1D    1x1 Conv1D (optional)
+         #    ...
+         #   conv1D
+         #    ...
+         #        x1 + x2
+         x1 = inputs
+         for layer in self.layers:
+             training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
+             x1 = layer(x1, training=training) if training_flag else layer(x1)
+         x2 = self.shape_match_conv(inputs)
+         x1_x2 = self.final_activation(layers.add([x2, x1], name='Add_Res'))
+         return [x1_x2, x1]
+
+     def compute_output_shape(self, input_shape):
+         return [self.res_output_shape, self.res_output_shape]
+
+
+ class TCN(Layer):
+     """Creates a TCN layer.
+     Input shape:
+         A tensor of shape (batch_size, timesteps, input_dim).
+     Args:
+         nb_filters: The number of filters to use in the convolutional layers. Can be a list.
+         kernel_size: The size of the kernel to use in each convolutional layer.
+         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
+         nb_stacks: The number of stacks of residual blocks to use.
+         padding: The padding to use in the convolutional layers, 'causal' or 'same'.
+         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
+         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
+         activation: The activation used in the residual blocks o = Activation(x + F(x)).
+         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
+         kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
+         use_batch_norm: Whether to use batch normalization in the residual layers or not.
+         use_layer_norm: Whether to use layer normalization in the residual layers or not.
+         use_weight_norm: Whether to use weight normalization in the residual layers or not.
+         kwargs: Any other arguments for configuring parent class Layer. For example "name=str", Name of the model.
+                 Use unique names when using multiple TCN.
+     Returns:
+         A TCN layer.
+     """
+
+     def __init__(self,
+                  nb_filters=256,
+                  kernel_size=5,
+                  nb_stacks=1,
+                  dilations=(1, 2, 4, 8, 16, 32),
+                  padding='causal',
+                  use_skip_connections=True,
+                  dropout_rate=0.0,
+                  return_sequences=False,
+                  activation='relu',
+                  kernel_initializer='he_normal',
+                  use_batch_norm=False,
+                  use_layer_norm=False,
+                  use_weight_norm=False,
+                  **kwargs):
+         print("nb_filters:", nb_filters, "kernel_size:", kernel_size)
+         self.return_sequences = return_sequences
+         self.dropout_rate = dropout_rate
+         self.use_skip_connections = use_skip_connections
+         self.dilations = dilations
+         self.nb_stacks = nb_stacks
+         self.kernel_size = kernel_size
+         self.nb_filters = nb_filters
+         self.activation_name = activation
+         self.padding = padding
+         self.kernel_initializer = kernel_initializer
+         self.use_batch_norm = use_batch_norm
+         self.use_layer_norm = use_layer_norm
+         self.use_weight_norm = use_weight_norm
+         self.skip_connections = []
+         self.residual_blocks = []
+         self.layers_outputs = []
+         self.build_output_shape = None
+         self.slicer_layer = None  # in case return_sequences=False
+         self.output_slice_index = None  # in case return_sequences=False
+         self.padding_same_and_time_dim_unknown = False  # edge case if padding='same' and time_dim = None
+
+         if self.use_batch_norm + self.use_layer_norm + self.use_weight_norm > 1:
+             raise ValueError('Only one normalization can be specified at once.')
+
+         if isinstance(self.nb_filters, list):
+             assert len(self.nb_filters) == len(self.dilations)
+             if len(set(self.nb_filters)) > 1 and self.use_skip_connections:
+                 raise ValueError('Skip connections are not compatible '
+                                  'with a list of filters, unless they are all equal.')
+
+         if padding != 'causal' and padding != 'same':
+             raise ValueError("Only 'causal' or 'same' padding is compatible with this layer.")
+
+         # initialize parent class
+         super(TCN, self).__init__(**kwargs)
+
+     @property
+     def receptive_field(self):
+         return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
+
+     def build(self, input_shape):
+
+         # member to hold current output shape of the layer for building purposes
+         self.build_output_shape = input_shape
+
+         # list to hold all the member ResidualBlocks
+         self.residual_blocks = []
+         total_num_blocks = self.nb_stacks * len(self.dilations)
+         if not self.use_skip_connections:
+             total_num_blocks += 1  # cheap way to do a false case for below
+
+         for s in range(self.nb_stacks):
+             for i, d in enumerate(self.dilations):
+                 res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
+                 self.residual_blocks.append(ResidualBlock(dilation_rate=d,
+                                                           nb_filters=res_block_filters,
+                                                           kernel_size=self.kernel_size,
+                                                           padding=self.padding,
+                                                           activation=self.activation_name,
+                                                           dropout_rate=self.dropout_rate,
+                                                           use_batch_norm=self.use_batch_norm,
+                                                           use_layer_norm=self.use_layer_norm,
+                                                           use_weight_norm=self.use_weight_norm,
+                                                           kernel_initializer=self.kernel_initializer,
+                                                           name='residual_block_{}'.format(len(self.residual_blocks))))
+                 # build newest residual block
+                 self.residual_blocks[-1].build(self.build_output_shape)
+                 self.build_output_shape = self.residual_blocks[-1].res_output_shape
+
+         # this is done to force keras to add the layers in the list to self._layers
+         for layer in self.residual_blocks:
+             self.__setattr__(layer.name, layer)
+
+         self.output_slice_index = None
+         if self.padding == 'same':
+             time = self.build_output_shape.as_list()[1]
+             if time is not None:  # if the time dimension is defined, e.g. shape = (bs, 500, input_dim).
+                 self.output_slice_index = int(self.build_output_shape.as_list()[1] / 2)
+             else:
+                 # It will be known at call time, cf. self.call.
+                 self.padding_same_and_time_dim_unknown = True
+         else:
+             self.output_slice_index = -1  # causal case.
+         self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
+
+         if isinstance(self.build_output_shape, tuple):
+             static = list(self.build_output_shape)
+         else:
+             static = self.build_output_shape.as_list()
+         self.slicer_layer.build(static)
+
+     def compute_output_shape(self, input_shape):
+         """
+         Overridden in case keras uses it somewhere... no idea. Just trying to avoid future errors.
+         """
+         if not self.built:
+             self.build(input_shape)
+         if not self.return_sequences:
+             batch_size = self.build_output_shape[0]
+             batch_size = batch_size.value if hasattr(batch_size, 'value') else batch_size
+             nb_filters = self.build_output_shape[-1]
+             return [batch_size, nb_filters]
+         else:
+             # Compatibility with tensorflow 1.x
+             return [v.value if hasattr(v, 'value') else v for v in self.build_output_shape]
+
+     def call(self, inputs, training=None, **kwargs):
+         x = inputs
+         self.layers_outputs = [x]
+         self.skip_connections = []
+         for res_block in self.residual_blocks:
+             # try:
+             #     x, skip_out = res_block(x, training=training)
+             # except TypeError:  # compatibility with tensorflow 1.x
+             #     x, skip_out = res_block(K.cast(x, 'float32'), training=training)
+             x, skip_out = res_block(x, training=training)
+             self.skip_connections.append(skip_out)
+             self.layers_outputs.append(x)
+
+         if self.use_skip_connections:
+             x = layers.add(self.skip_connections, name='Add_Skip_Connections')
+             self.layers_outputs.append(x)
+
+         if not self.return_sequences:
+             # case: time dimension is unknown, e.g. (bs, None, input_dim).
+             if self.padding_same_and_time_dim_unknown:
+                 self.output_slice_index = K.shape(self.layers_outputs[-1])[1] // 2
+             x = self.slicer_layer(x)
+             self.layers_outputs.append(x)
+         return x
+
+     def get_config(self):
+         """
+         Returns the config of the layer. This is used for saving and loading from a model.
+         :return: python dictionary with specs to rebuild layer
+         """
+         config = super(TCN, self).get_config()
+         config['nb_filters'] = self.nb_filters
+         config['kernel_size'] = self.kernel_size
+         config['nb_stacks'] = self.nb_stacks
+         config['dilations'] = self.dilations
+         config['padding'] = self.padding
+         config['use_skip_connections'] = self.use_skip_connections
+         config['dropout_rate'] = self.dropout_rate
+         config['return_sequences'] = self.return_sequences
+         config['activation'] = self.activation_name
+         config['use_batch_norm'] = self.use_batch_norm
+         config['use_layer_norm'] = self.use_layer_norm
+         config['use_weight_norm'] = self.use_weight_norm
+         config['kernel_initializer'] = self.kernel_initializer
+         return config
+
+
+ def compiled_tcn(num_feat,  # type: int
+                  num_classes,  # type: int
+                  nb_filters,  # type: int
+                  kernel_size,  # type: int
+                  dilations,  # type: List[int]
+                  nb_stacks,  # type: int
+                  max_len,  # type: int
+                  output_len=1,  # type: int
+                  padding='causal',  # type: str
+                  use_skip_connections=False,  # type: bool
+                  return_sequences=True,
+                  regression=False,  # type: bool
+                  dropout_rate=0.05,  # type: float
+                  name='tcn',  # type: str
+                  kernel_initializer='he_normal',  # type: str
+                  activation='relu',  # type: str
+                  opt='adam',
+                  lr=0.002,
+                  use_batch_norm=False,
+                  use_layer_norm=False,
+                  use_weight_norm=False):
+     # type: (...) -> Model
+     """Creates a compiled TCN model for a given task (i.e. regression or classification).
+     Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
+     Args:
+         num_feat: The number of features of your input, i.e. the last dimension of: (batch_size, timesteps, input_dim).
+         num_classes: The size of the final dense layer, how many classes we are predicting.
+         nb_filters: The number of filters to use in the convolutional layers.
+         kernel_size: The size of the kernel to use in each convolutional layer.
+         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
+         nb_stacks: The number of stacks of residual blocks to use.
+         max_len: The maximum sequence length; use None if the sequence length is dynamic.
+         padding: The padding to use in the convolutional layers.
+         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
+         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
+         regression: Whether the output should be continuous or discrete.
+         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
+         activation: The activation used in the residual blocks o = Activation(x + F(x)).
+         name: Name of the model. Useful when having multiple TCN.
+         kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
+         opt: Optimizer name.
+         lr: Learning rate.
+         use_batch_norm: Whether to use batch normalization in the residual layers or not.
+         use_layer_norm: Whether to use layer normalization in the residual layers or not.
+         use_weight_norm: Whether to use weight normalization in the residual layers or not.
+     Returns:
+         A compiled keras TCN.
+     """
+
+     dilations = adjust_dilations(dilations)
+
+     input_layer = Input(shape=(max_len, num_feat))
+
+     x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
+             use_skip_connections, dropout_rate, return_sequences,
+             activation, kernel_initializer, use_batch_norm, use_layer_norm,
+             use_weight_norm, name=name)(input_layer)
+
+     print('x.shape=', x.shape)
+
+     def get_opt():
+         if opt == 'adam':
+             return optimizers.Adam(learning_rate=lr, clipnorm=1.)
+         elif opt == 'rmsprop':
+             return optimizers.RMSprop(learning_rate=lr, clipnorm=1.)
+         else:
+             raise Exception('Only Adam and RMSProp are available here')
+
+     if not regression:
+         # classification
+         x = Dense(num_classes)(x)
+         x = Activation('softmax')(x)
+         output_layer = x
+         model = Model(input_layer, output_layer)
+
+         # https://github.com/keras-team/keras/pull/11373
+         # It's now in Keras@master but still not available with pip.
+         # TODO remove later.
+         def accuracy(y_true, y_pred):
+             # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
+             if K.ndim(y_true) == K.ndim(y_pred):
+                 y_true = K.squeeze(y_true, -1)
+             # convert dense predictions to labels
+             y_pred_labels = K.argmax(y_pred, axis=-1)
+             y_pred_labels = K.cast(y_pred_labels, KK.floatx())
+             return K.cast(K.equal(y_true, y_pred_labels), KK.floatx())
+
+         model.compile(get_opt(), loss='sparse_categorical_crossentropy', metrics=[accuracy])
+     else:
+         # regression
+         x = Dense(output_len)(x)
+         x = Activation('linear')(x)
+         output_layer = x
+         model = Model(input_layer, output_layer)
+         model.compile(get_opt(), loss='mean_squared_error')
+     print('model.x = {}'.format(input_layer.shape))
+     print('model.y = {}'.format(output_layer.shape))
+     return model
+
+
+ def tcn_full_summary(model: Model, expand_residual_blocks=True):
+
+     layers = model._layers.copy()  # store existing layers
+     model._layers.clear()  # clear layers
+
+     for i in range(len(layers)):
+         if isinstance(layers[i], TCN):
+             for layer in layers[i]._layers:
+                 if not isinstance(layer, ResidualBlock):
+                     if not hasattr(layer, '__iter__'):
+                         model._layers.append(layer)
+                 else:
+                     if expand_residual_blocks:
+                         for lyr in layer._layers:
+                             if not hasattr(lyr, '__iter__'):
+                                 model._layers.append(lyr)
+                     else:
+                         model._layers.append(layer)
+         else:
+             model._layers.append(layers[i])
+
+     model.summary()  # print summary
+
+     # restore original layers
+     model._layers.clear()
+     for lyr in layers:
+         model._layers.append(lyr)
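As a quick sanity check on the receptive_field property above, the TCN defaults (kernel_size=5, nb_stacks=1, dilations=(1, 2, 4, 8, 16, 32)) give:

    # receptive_field = 1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations)
    #                 = 1 + 2 * (5 - 1) * 1 * (1 + 2 + 4 + 8 + 16 + 32)
    #                 = 1 + 8 * 63
    #                 = 505 timesteps
    tcn = TCN()
    print(tcn.receptive_field)  # 505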
pages/inference.py ADDED
@@ -0,0 +1,547 @@
+ import os
+ import time
+
+ import streamlit as st
+ import torch
+ import torch.nn.functional as F
+ import random
+ from sklearn.preprocessing import MinMaxScaler
+ import numpy as np
+ import pandas as pd
+
+ from model.lstm import LSTMModel
+ from model.tcn import TCNModel
+ from model.tcn import move_custom_layers_to_device
+ from utils.lowlevel import LowLevel
+ from utils.highlevel import HighLevel
+ from utils.midpoint import MidPoint
+
+ from utils.transform import compute_gradient
+
+ st.set_page_config(page_title="Inference", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
+
+ def uniform_sampling(data, n_sample):
+     k = len(data) // n_sample
+     return data[::k]
+
+ def low_level(option_time, slider_sample_orbit, progress_bar):
+     time.sleep(0.1)
+     low_level_total_start_time = time.time()
+     low_level_30000_start_time = time.time()
+
+     lowlevelhelper = LowLevel(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(option_time) - t) / h
+     current_iteration = 0
+
+     original_low_level_data = []
+
+     while t < float(option_time):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+
+         if jn % 10 == 0:
+             original_low_level_data.append([b, x, y, z, px, py, pz])
+             # Update progress bar
+             progress_percentage = int((current_iteration / total_iterations) * 100)
+             progress_bar.progress(progress_percentage)
+
+         if jn == 300000:
+             low_level_30000_end_time = time.time()
+             low_level_30000_execute_time = low_level_30000_end_time - low_level_30000_start_time
+             low_level_2000_start_time = time.time()
+         jn = jn + 1
+         current_iteration += 1
+
+     progress_bar.progress(100)
+
+     low_level_2000_end_time = time.time()
+     low_level_2000_execute_time = low_level_2000_end_time - low_level_2000_start_time
+     low_level_total_end_time = time.time()
+     low_level_total_execute_time = low_level_total_end_time - low_level_total_start_time
+
+     result = uniform_sampling(np.array(original_low_level_data), n_sample=int(option_time / 100))
+
+     return low_level_30000_execute_time, low_level_2000_execute_time, low_level_total_execute_time, result
+
+ def high_level(option_time, slider_sample_orbit, progress_bar):
+     time.sleep(0.1)
+     high_level_total_start_time = time.time()
+     high_level_30000_start_time = time.time()
+
+     highlevelhelper = HighLevel(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = highlevelhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(option_time) - t) / h
+     current_iteration = 0
+
+     original_high_level_data = []
+
+     while t < float(option_time):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+         vx, vy, vz, vpx, vpy, vpz, e = highlevelhelper.f(x, y, z, px, py, pz, b)
+         en = np.asarray(e).astype(np.float64)
+
+         if jn % 10 == 0:
+             original_high_level_data.append([b, x, y, z, px, py, pz])
+         if jn == 300000:
+             high_level_30000_end_time = time.time()
+             high_level_30000_execute_time = high_level_30000_end_time - high_level_30000_start_time
+             high_level_2000_start_time = time.time()
+         jn = jn + 1
+
+         # Update progress bar
+         progress_percentage = int((current_iteration / total_iterations) * 100)
+         progress_bar.progress(progress_percentage)
+         current_iteration += 1
+
+     progress_bar.progress(100)
+     high_level_2000_end_time = time.time()
+     high_level_2000_execute_time = high_level_2000_end_time - high_level_2000_start_time
+     high_level_total_end_time = time.time()
+     high_level_total_execute_time = high_level_total_end_time - high_level_total_start_time
+
+     result = uniform_sampling(np.array(original_high_level_data), n_sample=int(option_time / 100))
+
+     return high_level_30000_execute_time, high_level_2000_execute_time, high_level_total_execute_time, result
+
+ def midpoint(option_time, slider_sample_orbit, progress_bar):
+     time.sleep(0.1)
+     mid_point_total_start_time = time.time()
+     mid_point_30000_start_time = time.time()
+
+     midpointhelper = MidPoint(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
+
+     # en0 = np.asarray(e0).astype(np.float64)
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(option_time) - t) / h
+     current_iteration = 0
+
+     original_mid_point_data = []
+
+     while t < float(option_time):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+
+         if jn % 10 == 0:
+             original_mid_point_data.append([b, x, y, z, px, py, pz])
+         if jn == 300000:
+             mid_point_30000_end_time = time.time()
+             mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
+             mid_point_2000_start_time = time.time()
+         jn = jn + 1
+
+         # Update progress bar
+         progress_percentage = int((current_iteration / total_iterations) * 100)
+         progress_bar.progress(progress_percentage)
+         current_iteration += 1
+
+     # mid_point_df.to_excel('mid_point_df_output.xlsx', index=False)
+     progress_bar.progress(100)
+     mid_point_2000_end_time = time.time()
+     mid_point_2000_execute_time = mid_point_2000_end_time - mid_point_2000_start_time
+     mid_point_total_end_time = time.time()
+     mid_point_total_execute_time = mid_point_total_end_time - mid_point_total_start_time
+
+     result = uniform_sampling(np.array(original_mid_point_data), n_sample=int(option_time / 100))
+
+     return mid_point_30000_execute_time, mid_point_2000_execute_time, mid_point_total_execute_time, result
+
+ def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
+     time.sleep(0.1)
+     total_start_time = time.time()
+
+     lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
+     lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
+     lstm_model.to("cpu")
+     lstm_model.eval()
+
+     # Initialize variables for the classical method
+     lowlevelhelper = LowLevel(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(30000) - t) / h
+     current_iteration = 0
+
+     original_low_level_data = []
+
+     low_level_start_time = time.time()
+
+     # Perform classical method prediction for the initial segment
+     while t < float(30000):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+         t = t + h
+
+         if jn % 10 == 0:
+             original_low_level_data.append([b, x, y, z, px, py, pz])
+             # Update progress bar
+             progress_percentage = int((current_iteration / total_iterations) * 100)
+             lstm_progress_bar.progress(progress_percentage)
+
+         jn = jn + 1
+         current_iteration += 1
+
+     original_low_level_data = np.array(original_low_level_data)
+     low_level_end_time = time.time()
+     low_level_data = original_low_level_data.copy()
+     low_level_data = uniform_sampling(low_level_data, n_sample=300)
+     scaler = MinMaxScaler()
+     low_level_data = scaler.fit_transform(low_level_data)
+     low_level_data = torch.tensor(np.stack(low_level_data)).float()
+     low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
+
+     lstm_start_time = time.time()
+     with torch.no_grad():
+         lstm_preds = lstm_model(low_level_data[:, 100:300, :])
+         lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
+
+     original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
+
+     lstm_end_time = time.time()
+     lstm_progress_bar.progress(100)
+
+     combined_preds = np.concatenate([original_low_level_data, lstm_innv_preds], axis=0)
+
+     lstm_total_time = lstm_end_time - lstm_start_time
+     low_level_total_time = low_level_end_time - low_level_start_time
+
+     total_end_time = time.time()
+     total_time = total_end_time - total_start_time
+
+     return low_level_total_time, lstm_total_time, total_time, combined_preds
+
+ def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
+     time.sleep(0.1)
+     total_start_time = time.time()
+
+     lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
+     lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
+     lstm_model.to("cpu")
+     lstm_model.eval()
+
+     midpointhelper = MidPoint(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(30000) - t) / h
+     current_iteration = 0
+
+     original_mid_point_data = []
+
+     mid_point_start_time = time.time()
+
+     while t < float(30000):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+
+         if jn % 10 == 0:
+             original_mid_point_data.append([b, x, y, z, px, py, pz])
+             # Update progress bar
+             progress_percentage = int((current_iteration / total_iterations) * 100)
+             lstm_progress_bar.progress(progress_percentage)
+         jn = jn + 1
+         current_iteration += 1
+
+     original_mid_point_data = np.array(original_mid_point_data)
+     mid_point_end_time = time.time()
+     mid_point_data = original_mid_point_data.copy()
+     mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
+     scaler = MinMaxScaler()
+     mid_point_data = scaler.fit_transform(mid_point_data)
+     mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
+     mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
+
+     lstm_start_time = time.time()
+     with torch.no_grad():
+         lstm_preds = lstm_model(mid_point_data[:, 100:300, :])
+         lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
+
+     original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
+
+     lstm_end_time = time.time()
+     lstm_progress_bar.progress(100)
+
+     combined_preds = np.concatenate([original_mid_point_data, lstm_innv_preds], axis=0)
+
+     lstm_total_time = lstm_end_time - lstm_start_time
+     mid_point_total_time = mid_point_end_time - mid_point_start_time
+
+     total_end_time = time.time()
+     total_time = total_end_time - total_start_time
+
+     return mid_point_total_time, lstm_total_time, total_time, combined_preds
+
+ def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
+     time.sleep(0.1)
+     total_start_time = time.time()
+
+     tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
+     tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
+     move_custom_layers_to_device(tcn_model, "cpu")
+     tcn_model.eval()
+
+     # Initialize variables for the classical method
+     lowlevelhelper = LowLevel(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(30000) - t) / h
+     current_iteration = 0
+
+     original_low_level_data = []
+
+     low_level_start_time = time.time()
+
+     # Perform classical method prediction for the initial segment
+     while t < float(30000):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+
+         if jn % 10 == 0:
+             original_low_level_data.append([b, x, y, z, px, py, pz])
+             # Update progress bar
+             progress_percentage = int((current_iteration / total_iterations) * 100)
+             tcn_progress_bar.progress(progress_percentage)
+
+         jn = jn + 1
+         current_iteration += 1
+
+     original_low_level_data = np.array(original_low_level_data)
+     low_level_end_time = time.time()
+     low_level_data = original_low_level_data.copy()
+     low_level_data = uniform_sampling(low_level_data, n_sample=300)
+     scaler = MinMaxScaler()
+     low_level_data = scaler.fit_transform(low_level_data)
+     low_level_data = torch.tensor(np.stack(low_level_data)).float()
+     low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
+
+     tcn_start_time = time.time()
+     with torch.no_grad():
+         tcn_preds = None
+         for i in range(20):
+             if i == 0:
+                 tcn_preds = tcn_model(low_level_data[:, :300, :])
+             else:
+                 gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
+                 output = tcn_model(torch.cat([low_level_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
+                 tcn_preds = torch.cat([tcn_preds, output], dim=1)
+         tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
+
+     original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
+
+     tcn_end_time = time.time()
+     tcn_progress_bar.progress(100)
+
+     combined_preds = np.concatenate([original_low_level_data, tcn_innv_preds], axis=0)
+
+     tcn_total_time = tcn_end_time - tcn_start_time
+     low_level_total_time = low_level_end_time - low_level_start_time
+
+     total_end_time = time.time()
+     total_time = total_end_time - total_start_time
+
+     return low_level_total_time, tcn_total_time, total_time, combined_preds
+
+ def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
+     time.sleep(0.1)
+     total_start_time = time.time()
+
+     tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
+     tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
+     move_custom_layers_to_device(tcn_model, "cpu")
+     tcn_model.eval()
+
+     # Initialize variables for the classical method
+     midpointhelper = MidPoint(j=slider_sample_orbit)
+     j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
+
+     a1 = 1 / (2 - 2 ** (1 / 3))
+     a2 = 1 - 2 * a1
+     jn = 0
+     t = 0.1
+
+     # Calculate the total number of iterations for the progress bar update
+     total_iterations = (float(30000) - t) / h
+     current_iteration = 0
+
+     original_mid_point_data = []
+
+     mid_point_start_time = time.time()
+
+     # Perform classical method prediction for the initial segment
+     while t < float(30000):
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b)
+         x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
+
+         t = t + h
+
+         if jn % 10 == 0:
+             original_mid_point_data.append([b, x, y, z, px, py, pz])
+             # Update progress bar
+             progress_percentage = int((current_iteration / total_iterations) * 100)
+             tcn_progress_bar.progress(progress_percentage)
+         jn = jn + 1
+         current_iteration += 1
+
+     original_mid_point_data = np.array(original_mid_point_data)
+     mid_point_end_time = time.time()
+     mid_point_data = original_mid_point_data.copy()
+     mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
+     scaler = MinMaxScaler()
+     mid_point_data = scaler.fit_transform(mid_point_data)
+     mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
+     mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
+
+     tcn_start_time = time.time()
+     with torch.no_grad():
+         tcn_preds = None
+         for i in range(20):
+             if i == 0:
+                 tcn_preds = tcn_model(mid_point_data[:, :300, :])
+             else:
+                 gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
+                 output = tcn_model(torch.cat([mid_point_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
+                 tcn_preds = torch.cat([tcn_preds, output], dim=1)
+         tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
+
+     original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
+
+     tcn_end_time = time.time()
+     tcn_progress_bar.progress(100)
+
+     combined_preds = np.concatenate([original_mid_point_data, tcn_innv_preds], axis=0)
+
+     tcn_total_time = tcn_end_time - tcn_start_time
+     mid_point_total_time = mid_point_end_time - mid_point_start_time
+
+     total_end_time = time.time()
+     total_time = total_end_time - total_start_time
+
+     return mid_point_total_time, tcn_total_time, total_time, combined_preds
+
+ container = st.container()
+ container1, container2 = st.columns(2)
+ plot_container = st.container()
+
+ with st.sidebar:
+     slider_sample_orbit = st.slider('Orbit Sample ID', 1, 10, 1)
+     option_time = 32000
+     st.write(f'Total Time Step: {option_time}')
+     options_method = st.multiselect(
+         'Compared Methods',
+         ['Low-Level', 'High-Level', 'Midpoint', 'Low-Level with LSTM', 'Low-Level with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
+         ['Low-Level'])
+     btn_go = st.button("Go", type="primary", use_container_width=True)
+
+ if btn_go:
+     if 'Low-Level' in options_method:
+         with container1:
+             st.write('Low Level Progress Bar')
+             low_level_progress_bar = st.progress(0)
+             low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Low Level", '30000 Time Steps (s)': [low_level_30000_time], '2000 Time Steps (s)': [low_level_2000_time], 'Total Time (s)': [low_level_total_time]}))
+     if 'High-Level' in options_method:
+         with container1:
+             st.write('High Level Progress Bar')
+             high_level_progress_bar = st.progress(0)
+             high_level_30000_time, high_level_2000_time, high_level_total_time, high_level_result = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "High Level", '30000 Time Steps (s)': [high_level_30000_time], '2000 Time Steps (s)': [high_level_2000_time], 'Total Time (s)': [high_level_total_time]}))
+     if 'Midpoint' in options_method:
+         with container1:
+             st.write('Midpoint Progress Bar')
+             mid_point_progress_bar = st.progress(0)
+             mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
+     if 'Low-Level with LSTM' in options_method:
+         with container1:
+             st.write('Low Level LSTM Progress Bar')
+             low_level_lstm_progress_bar = st.progress(0)
+             lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Low Level + LSTM", '30000 Time Steps (s)': [lstm_30000_time], '2000 Time Steps (s)': [lstm_2000_time], 'Total Time (s)': [lstm_total_time]}))
+     if 'Low-Level with TCN' in options_method:
+         with container1:
+             st.write('Low Level TCN Progress Bar')
+             low_level_tcn_progress_bar = st.progress(0)
+             tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Low Level + TCN", '30000 Time Steps (s)': [tcn_30000_time], '2000 Time Steps (s)': [tcn_2000_time], 'Total Time (s)': [tcn_total_time]}))
+     if 'Midpoint with LSTM' in options_method:
+         with container1:
+             st.write('Midpoint LSTM Progress Bar')
+             mid_point_lstm_progress_bar = st.progress(0)
+             md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time, md_lstm_result = mid_point_lstm(slider_sample_orbit, mid_point_lstm_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Midpoint + LSTM", '30000 Time Steps (s)': [md_lstm_30000_time], '2000 Time Steps (s)': [md_lstm_2000_time], 'Total Time (s)': [md_lstm_total_time]}))
+     if 'Midpoint with TCN' in options_method:
+         with container1:
+             st.write('Midpoint TCN Progress Bar')
+             mid_point_tcn_progress_bar = st.progress(0)
+             md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time, md_tcn_result = mid_point_tcn(slider_sample_orbit, mid_point_tcn_progress_bar)
+         with container2:
+             st.table(pd.DataFrame({'Model': "Midpoint + TCN", '30000 Time Steps (s)': [md_tcn_30000_time], '2000 Time Steps (s)': [md_tcn_2000_time], 'Total Time (s)': [md_tcn_total_time]}))
+
prediction.py ADDED
@@ -0,0 +1,196 @@
+ import streamlit as st
+ import os
+ import numpy as np
+ import pandas as pd
+ from sklearn.preprocessing import MinMaxScaler
+ import torch
+ import time
+
+ from utils.transform import compute_gradient
+ from model.lstm import LSTMModel
+ from model.tcn import TCNModel
+ from model.tcn import move_custom_layers_to_device
+ from utils.metrics import calculate_metrics
+
+ each_feature_name = ["q_1", "q_2", "q_3", "p_1", "p_2", "p_3"]
+
+ def uniform_sampling(data, n_sample):
+     k = max(1, len(data) // n_sample)  # guard the stride against len(data) < n_sample
+     return data[::k]
+
+ st.set_page_config(page_title="Prediction", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
+
+ #st.title("Prediction")
+
+ with st.sidebar:
+     slider_predict_step = st.slider('Predicted Step', 1, 20, 20)  # at least one step; the metrics below fail on an empty slice
+
+     number_input_sample_id = st.number_input("Select Sample ID 1~10", value=1, placeholder="Type a number...", min_value=1, max_value=10, step=1)
+
+     sequences_start_idx = st.slider('Sequences Start Index', 0, 700 - slider_predict_step, 0)
+
+     st.subheader("Model Configuration")
+     st.write("LSTM Window Size: ", 200)
+     st.write("TCN Window Size: ", 300)
+     st.write("Predicted Step: ", slider_predict_step)
+     st.write("Feature Augmentation: ", "Second-order derivative")
+
+ file_path = os.path.join("data", "file" + str(number_input_sample_id) + ".dat.npz")
+ data = pd.DataFrame(np.load(file_path)['data'])
+ scaler = MinMaxScaler()
+ uniform_data = uniform_sampling(data, n_sample=1000).sort_index().values[:, 1:8]
+ normal_uniform_data = scaler.fit_transform(uniform_data)
+ data_sequences = torch.tensor(np.stack(normal_uniform_data)).float()
+ original_data_sequences = torch.tensor(np.stack(uniform_data)).float()
+ selected_data = data_sequences[sequences_start_idx:sequences_start_idx + 300 + slider_predict_step]
+ original_selected_data = original_data_sequences[sequences_start_idx:sequences_start_idx + 300 + slider_predict_step]
+ input_data = torch.stack([compute_gradient(i, degree=2) for i in selected_data]).unsqueeze(0)
+
+ with st.sidebar:
+     st.subheader("Data Configuration")
+     st.write("Sample ID: ", number_input_sample_id)
+     #st.write("Original Shape: ", data_sequences.shape)
+     st.write("Sequences Start Index: ", sequences_start_idx)
+     #st.write("Selected Shape: ", selected_data.shape)
+     #st.write("Input Shape: ", input_data.shape)
+
+ # st.write(selected_data[0])
+ # st.write(input_data[0][0])
+
+ #################################################
+ ## LSTM GPU Inference
+ #################################################
+ lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
+ lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
+ lstm_model.eval()
+ lstm_start_time = time.time()
+ with torch.no_grad():
+     lstm_preds = lstm_model(input_data[:, 100:300, :].cuda())
+ lstm_end_time = time.time()
+ lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
+ lstm_normal_preds = lstm_preds.squeeze().cpu().numpy()
+
+ #lstm_model.to_onnx("model/lstm.onnx", torch.randn((1, 200, 21)), export_params=True)
+
+ del lstm_model
+
+ #################################################
+ ## LSTM CPU Inference
+ #################################################
+ lstm_cpu_ckpt_file = os.path.join("model", "lstm.ckpt")
+ lstm_cpu_model = LSTMModel.load_from_checkpoint(lstm_cpu_ckpt_file)
+ lstm_cpu_model.to("cpu")
+ lstm_cpu_model.eval()
+ lstm_cpu_start_time = time.time()
+ with torch.no_grad():
+     lstm_cpu_preds = lstm_cpu_model(input_data[:, 100:300, :])
+ lstm_cpu_end_time = time.time()
+
+ del lstm_cpu_model
+
+ #################################################
+ ## TCN GPU Inference
+ #################################################
+ tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
+ tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
+ tcn_model.eval()
+ tcn_start_time = time.time()
+ with torch.no_grad():
+     input_data_cuda = input_data[:, :300, :].cuda()
+     y_hat = tcn_model(input_data_cuda)
+     for i in range(1, slider_predict_step):
+         gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).cuda()
+         output = tcn_model(torch.cat([input_data[:, i:300, :].cuda(), gd_y_hat], dim=1)).cuda()
+         y_hat = torch.cat([y_hat, output], dim=1)
+ tcn_end_time = time.time()
+ tcn_preds = y_hat
+ tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
+ tcn_normal_preds = tcn_preds.squeeze().cpu().numpy()
+
+ #tcn_model.to_onnx("model/tcn.onnx", torch.randn((1, 300, 21)), export_params=True)
+
+ del tcn_model, y_hat
+ if slider_predict_step > 1:  # gd_y_hat and output only exist once the rollout loop has run
+     del gd_y_hat, output
+
+ #################################################
+ ## TCN CPU Inference
+ #################################################
+ input_data_cpu = input_data.to("cpu")
+ tcn_cpu_ckpt_file = os.path.join("model", "tcn.ckpt")
+ tcn_cpu_model = TCNModel.load_from_checkpoint(tcn_cpu_ckpt_file)
+ move_custom_layers_to_device(tcn_cpu_model, "cpu")
+ tcn_cpu_model.eval()
+ tcn_cpu_start_time = time.time()
+ with torch.no_grad():
+     y_hat = None
+     for i in range(slider_predict_step):
+         if i == 0:
+             y_hat = tcn_cpu_model(input_data_cpu[:, :300, :])
+         else:
+             gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).to('cpu')
+             output = tcn_cpu_model(torch.cat([input_data_cpu[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
+             y_hat = torch.cat([y_hat, output], dim=1)
+ tcn_cpu_preds = y_hat
+ tcn_cpu_end_time = time.time()
+
+ del tcn_cpu_model
+
+ st.subheader("Normalized Prediction")
+
+ # chart channels 1..6 of the 7 selected columns (channel 0 is not plotted)
+ i = 1
+ for each_col in st.columns(6):
+     with each_col:
+         raw_data = selected_data[:, i]
+         lstm_data = [np.nan] * 300 + lstm_normal_preds[:slider_predict_step, :][:, i].tolist()
+         tcn_data = [np.nan] * 300 + tcn_normal_preds[:, i].tolist()
+         st.markdown(f"<div style='text-align: center'>{each_feature_name[i - 1]}</div>", unsafe_allow_html=True)
+         #st.write(np.array(raw_data).shape, np.array(lstm_data).shape, np.array(tcn_data).shape)
+         st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
+                       color=["#EE4035", "#0077BB", "#7BC043"])
+     i += 1
+
+ # with st.sidebar:
+ #     st.write("Predicted Shape: ", lstm_preds.shape)
+
+ st.subheader("Inverse Normalized Prediction")
+
+ i = 1
+ for each_col in st.columns(6):
+     with each_col:
+         raw_data = original_selected_data[:, i]
+         lstm_data = [np.nan] * 300 + lstm_innv_preds[:slider_predict_step, :][:, i].tolist()
+         tcn_data = [np.nan] * 300 + tcn_innv_preds[:, i].tolist()
+         st.markdown(f"<div style='text-align: center'>{each_feature_name[i - 1]}</div>", unsafe_allow_html=True)
+         st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
+                       color=["#EE4035", "#0077BB", "#7BC043"])
+     i += 1
+
+ LSTM_SMAPE, LSTM_MSE, LSTM_RMSE, LSTM_MAE, LSTM_R2, LSTM_PSD = calculate_metrics(selected_data[300:300 + slider_predict_step, :].cpu().numpy(), lstm_normal_preds[:slider_predict_step, :])
+ TCN_SMAPE, TCN_MSE, TCN_RMSE, TCN_MAE, TCN_R2, TCN_PSD = calculate_metrics(selected_data[300:300 + slider_predict_step, :].cpu().numpy(), tcn_normal_preds)
+
+ results_df = pd.DataFrame({
+     "Model": ["LSTM", "TCN"],
+     "SMAPE": [LSTM_SMAPE, TCN_SMAPE],
+     "MSE": [LSTM_MSE, TCN_MSE],
+     "RMSE": [LSTM_RMSE, TCN_RMSE],
+     "MAE": [LSTM_MAE, TCN_MAE],
+     "R2": [LSTM_R2, TCN_R2],
+     "PSD": [LSTM_PSD, TCN_PSD]
+ })
+
+ time_df = pd.DataFrame({
+     "Model": ["LSTM-GPU", "TCN-GPU", "LSTM-CPU", "TCN-CPU"],
+     "Time(ms)": [(lstm_end_time - lstm_start_time) * 1000,
+                  (tcn_end_time - tcn_start_time) * 1000,
+                  (lstm_cpu_end_time - lstm_cpu_start_time) * 1000,
+                  (tcn_cpu_end_time - tcn_cpu_start_time) * 1000]
+ })
+
+ col1, col2 = st.columns(2)
+ with col1:
+     st.subheader("Evaluation Metrics")
+     st.write(results_df)
+
+ with col2:
+     st.subheader("Prediction Time")
+     st.write(time_df)
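Note: both TCN blocks above roll the network forward autoregressively: the first call consumes the full 300-step window, and each later call drops the oldest i steps and appends the gradient-augmented predictions made so far. A minimal sketch of that rollout, assuming (as the loop above implies) that the model emits one 7-feature step per call; `model` and `window` are placeholders, and compute_gradient is the helper from utils/transform.py:

    import torch
    from utils.transform import compute_gradient

    def rollout(model, window, n_steps):
        # window: (1, 300, 21) gradient-augmented history
        with torch.no_grad():
            y_hat = model(window)                                 # first predicted step, (1, 1, 7)
            for i in range(1, n_steps):
                gd = compute_gradient(y_hat[:, :i, :], degree=2)  # re-augment predictions to 21 channels
                step = model(torch.cat([window[:, i:, :], gd], dim=1))
                y_hat = torch.cat([y_hat, step], dim=1)
        return y_hat                                              # (1, n_steps, 7)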
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ pandas==1.5.3
+ scikit-learn==1.3.2
+ torch==2.1.1
+ streamlit==1.32.2
+ keras==3.0.0
+ torchvision==0.16.1
+ pytorch-lightning==2.1.2
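With these pins the environment can be set up with pip install -r requirements.txt. Note that sympy, imported by utils/highlevel.py below, is not listed here and would need to be installed separately; numpy is pulled in transitively by pandas and scikit-learn.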
utils/__pycache__/highlevel.cpython-310.pyc ADDED
Binary file (4.08 kB). View file
 
utils/__pycache__/lowlevel.cpython-310.pyc ADDED
Binary file (6.04 kB). View file
 
utils/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (2.22 kB). View file
 
utils/__pycache__/midpoint.cpython-310.pyc ADDED
Binary file (6.02 kB). View file
 
utils/__pycache__/transform.cpython-310.pyc ADDED
Binary file (448 Bytes). View file
 
utils/highlevel.py ADDED
@@ -0,0 +1,160 @@
+ import sympy as sp
+
+ class HighLevel():
+
+     def __init__(self, j):
+         self.j = j
+
+     def initial(self):
+         j = self.j
+
+         # init parameters
+         h, b, n, u = 0.1, None, None, None
+         x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
+         xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
+
+         if j == 1:
+             b = 5.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 10.0
+             py = 0.5
+         elif j == 2:
+             b = 3.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 8.3
+             py = 0.6
+         elif j == 3:
+             b = 3.0 / 2
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 12.0
+             py = 0.4
+         elif j == 4:
+             b = 7.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 15.0
+             py = 0.35
+         elif j == 5:
+             b = 1.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 18.0
+             py = 0.3
+         elif j == 6:
+             b = 3.0 / 5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 20.0
+             py = 0.25
+         elif j == 7:
+             b = 5.0 / 7
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 22.0
+             py = 0.22
+         elif j == 8:
+             b = 2.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 26.0
+             py = 0.2
+         elif j == 9:
+             b = 0.5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 30.0
+             y = 0.5
+             z = 0.1
+             pz = 0.01
+         elif j == 10:
+             b = 5.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 35.0
+             y = 2.0
+             z = 0.1
+             pz = 0.03
+             py = 0.15
+
+         xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
+         return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
+
+     def f(self, x, y, z, px, py, pz, b):
+         x_val, y_val, z_val, px_val, py_val, pz_val, b_val = x, y, z, px, py, pz, b
+         x, y, z, px, py, pz, b = sp.symbols('x y z px py pz b')
+
+         c = 1.0
+
+         u = 1 / (1 / b + b + 2)
+         ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
+         hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
+         h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
+             (px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
+                 1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
+             x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
+             px ** 2 + py ** 2 + pz ** 2) ** 2
+
+         e = ht + hv + h1pn
+
+         de_dx = sp.diff(e, x)
+         de_dy = sp.diff(e, y)
+         de_dz = sp.diff(e, z)
+         de_dpx = sp.diff(e, px)
+         de_dpy = sp.diff(e, py)
+         de_dpz = sp.diff(e, pz)
+
+         vals = {x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val}
+         de_dx_val = de_dx.subs(vals)
+         de_dy_val = de_dy.subs(vals)
+         de_dz_val = de_dz.subs(vals)
+         de_dpx_val = de_dpx.subs(vals)
+         de_dpy_val = de_dpy.subs(vals)
+         de_dpz_val = de_dpz.subs(vals)
+
+         e_val = e.subs(vals)
+
+         return de_dx_val, de_dy_val, de_dz_val, de_dpx_val, de_dpy_val, de_dpz_val, e_val
+
+     def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
+         # average the two phase-space copies and resynchronise them
+         x = (x + xa) / 2
+         y = (y + ya) / 2
+         z = (z + za) / 2
+
+         px = (px + pxa) / 2
+         py = (py + pya) / 2
+         pz = (pz + pza) / 2
+         xa = x
+         ya = y
+         za = z
+         pxa = px
+         pya = py
+         pza = pz
+
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
+
+     def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
+         # half step: gradients at (xa, p) advance the positions x and auxiliary momenta pa
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         # full step: gradients at (x, pa) advance the auxiliary positions xa and momenta p
+         vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
+         xa = xa + h * vpxa
+         ya = ya + h * vpya
+         za = za + h * vpza
+         px = px - h * vx
+         py = py - h * vy
+         pz = pz - h * vz
+
+         # mirror half step to complete the time-symmetric update
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
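HighLevel rebuilds the post-Newtonian-style Hamiltonian (h1pn) symbolically and calls sp.diff on every evaluation of f, so it is exact but slow compared to the hand-coded derivatives in utils/lowlevel.py. A minimal driver sketch against this interface (the step count and the per-step rejust call are assumptions, not taken from this repo):

    from utils.highlevel import HighLevel

    integrator = HighLevel(j=1)
    j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = integrator.initial()
    for _ in range(100):  # 100 steps of size h = 0.1
        (x, y, z, px, py, pz,
         xa, ya, za, pxa, pya, pza) = integrator.symplectic(h, x, y, z, px, py, pz,
                                                            xa, ya, za, pxa, pya, pza, b)
        (x, y, z, px, py, pz,
         xa, ya, za, pxa, pya, pza) = integrator.rejust(x, y, z, px, py, pz,
                                                        xa, ya, za, pxa, pya, pza)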
utils/lowlevel.py ADDED
@@ -0,0 +1,158 @@
+ class LowLevel():
+
+     def __init__(self, j):
+         self.j = j
+
+     def initial(self):
+         j = self.j
+
+         # init parameters
+         h, b, n, u = 0.1, None, None, None
+         x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
+         xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
+
+         if j == 1:
+             b = 5.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 10.0
+             py = 0.5
+         elif j == 2:
+             b = 3.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 8.3
+             py = 0.6
+         elif j == 3:
+             b = 3.0 / 2
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 12.0
+             py = 0.4
+         elif j == 4:
+             b = 7.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 15.0
+             py = 0.35
+         elif j == 5:
+             b = 1.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 18.0
+             py = 0.3
+         elif j == 6:
+             b = 3.0 / 5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 20.0
+             py = 0.25
+         elif j == 7:
+             b = 5.0 / 7
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 22.0
+             py = 0.22
+         elif j == 8:
+             b = 2.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 26.0
+             py = 0.2
+         elif j == 9:
+             b = 0.5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 30.0
+             y = 0.5
+             z = 0.1
+             pz = 0.01
+         elif j == 10:
+             b = 5.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 35.0
+             y = 2.0
+             z = 0.1
+             pz = 0.03
+             py = 0.15
+
+         xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
+         return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
+
+     def f(self, x, y, z, px, py, pz, b):
+         n = b / (1 + b)**2
+         u = 1 / (1 / b + b + 2)
+         ht = px**2 / 2 + py**2 / 2 + pz**2 / 2
+         hv = -1 / (x**2 + y**2 + z**2)**(1/2)
+         h1pn = 1/(2*x**2 + 2*y**2 + 2*z**2) - (((u + 3)*(px**2 + py**2 + pz**2))/2 + (u*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2)/(x**2 + y**2 + z**2)**(1/2) + ((3*u)/8 - 1/8)*(px**2 + py**2 + pz**2)**2
+
+         e = ht + hv + h1pn
+
+         # hand-coded gradients of e: Newtonian part (vn*) plus 1PN correction (v1pn*)
+         vnpx = px
+         v1pnpx = 4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpx = vnpx + v1pnpx
+
+         vnpy = py
+         v1pnpy = 4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpy = vnpy + v1pnpy
+
+         vnpz = pz
+         v1pnpz = 4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpz = vnpz + v1pnpz
+
+         vnx = x/(x**2 + y**2 + z**2)**(3/2)
+         v1pnx = (x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) - px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vx = vnx + v1pnx
+
+         vny = y/(x**2 + y**2 + z**2)**(3/2)
+         v1pny = (y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) - py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vy = vny + v1pny
+
+         vnz = z/(x**2 + y**2 + z**2)**(3/2)
+         v1pnz = (z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) - pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vz = vnz + v1pnz
+
+         return vx, vy, vz, vpx, vpy, vpz, e
+
+     def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
+         x = (x + xa) / 2
+         y = (y + ya) / 2
+         z = (z + za) / 2
+
+         px = (px + pxa) / 2
+         py = (py + pya) / 2
+         pz = (pz + pza) / 2
+         xa = x
+         ya = y
+         za = z
+         pxa = px
+         pya = py
+         pza = pz
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
+
+     def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
+         xa = xa + h * vpxa
+         ya = ya + h * vpya
+         za = za + h * vpza
+         px = px - h * vx
+         py = py - h * vy
+         pz = pz - h * vz
+
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
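LowLevel hard-codes the derivatives that HighLevel obtains via sp.diff; note that its n = b/(1+b)**2 is algebraically the same quantity as HighLevel's u = 1/(1/b + b + 2), so the two f implementations should agree. A quick cross-check sketch, using the j == 1 initial conditions as sample arguments:

    from utils.highlevel import HighLevel
    from utils.lowlevel import LowLevel

    args = (10.0, 0.1, 0.001, 0.01, 0.5, 0.0001)  # x, y, z, px, py, pz for j == 1
    b = 5.0 / 4
    hi = HighLevel(1).f(*args, b)  # sympy derivatives, evaluated via subs
    lo = LowLevel(1).f(*args, b)   # hand-coded floating-point derivatives
    for name, a, c in zip(("vx", "vy", "vz", "vpx", "vpy", "vpz", "e"), hi, lo):
        print(name, float(a), c)   # the two columns should agree to rounding error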
utils/metrics.py ADDED
@@ -0,0 +1,28 @@
+ import numpy as np
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+
+ def calculate_metrics(y, y_hat, y_train=None):
+     def smape(a, f):
+         return 1/len(a) * np.sum(2 * np.abs(f - a) / (np.abs(a) + np.abs(f) + np.finfo(float).eps))
+
+     def mase(y_actual, y_pred, y_train):
+         n = y_train.shape[1]
+         d = np.abs(np.diff(y_train)).sum() / (n - 1)
+         errors = np.abs(y_actual - y_pred)
+         return errors.mean() / d
+
+     def phase_space_distance(y_actual, y_pred):
+         return np.sqrt(np.sum(np.square(y_actual - y_pred)))
+
+     SMAPE = np.mean([smape(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
+     MSE = np.mean([mean_squared_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
+     RMSE = np.mean([np.sqrt(mean_squared_error(yi.reshape(-1), y_hati.reshape(-1))) for yi, y_hati in zip(y, y_hat)])
+     MAE = np.mean([mean_absolute_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
+     R2 = np.mean([r2_score(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
+     PSD = np.mean([phase_space_distance(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
+
+     if y_train is None:
+         return SMAPE, MSE, RMSE, MAE, R2, PSD
+     else:
+         MASE = np.mean([mase(yi, y_hati, yt) for yi, y_hati, yt in zip(y, y_hat, y_train)])
+         return SMAPE, MSE, RMSE, MAE, R2, MASE, PSD
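calculate_metrics scores each element along the first axis of y / y_hat after flattening it, then averages, and only computes MASE when a training series is supplied. A small usage sketch with synthetic arrays shaped like the app's (predicted steps, features) inputs:

    import numpy as np
    from utils.metrics import calculate_metrics

    rng = np.random.default_rng(0)
    y = rng.random((20, 7))                          # 20 predicted steps, 7 features
    y_hat = y + 0.01 * rng.standard_normal((20, 7))  # slightly perturbed "predictions"
    SMAPE, MSE, RMSE, MAE, R2, PSD = calculate_metrics(y, y_hat)
    print(f"SMAPE={SMAPE:.4f} RMSE={RMSE:.4f} R2={R2:.4f}")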
utils/midpoint.py ADDED
@@ -0,0 +1,164 @@
+
+ class MidPoint():
+
+     def __init__(self, j):
+         self.j = j
+
+     def initial(self):
+         j = self.j
+
+         # init parameters
+         h, b, n, u = 0.1, None, None, None
+         x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
+         xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
+
+         if j == 1:
+             b = 5.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 10.0
+             py = 0.5
+         elif j == 2:
+             b = 3.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 8.3
+             py = 0.6
+         elif j == 3:
+             b = 3.0 / 2
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 12.0
+             py = 0.4
+         elif j == 4:
+             b = 7.0 / 4
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 15.0
+             py = 0.35
+         elif j == 5:
+             b = 1.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 18.0
+             py = 0.3
+         elif j == 6:
+             b = 3.0 / 5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 20.0
+             py = 0.25
+         elif j == 7:
+             b = 5.0 / 7
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 22.0
+             py = 0.22
+         elif j == 8:
+             b = 2.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 26.0
+             py = 0.2
+         elif j == 9:
+             b = 0.5
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 30.0
+             y = 0.5
+             z = 0.1
+             pz = 0.01
+         elif j == 10:
+             b = 5.0
+             n = b / (1 + b) ** 2
+             u = 1.0 / (1.0 / b + b + 2.0)
+             x = 35.0
+             y = 2.0
+             z = 0.1
+             pz = 0.03
+             py = 0.15
+
+         xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
+         return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
+
+     def f(self, x, y, z, px, py, pz, b):
+
+         n = b / (1 + b) ** 2
+         u = 1 / (1 / b + b + 2)
+         ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
+         hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
+         h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
+             (px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
+                 1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
+             x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
+             px ** 2 + py ** 2 + pz ** 2) ** 2
+
+         e = ht + hv + h1pn
+
+         vnpx = px
+         v1pnpx = 4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpx = vnpx + v1pnpx
+
+         vnpy = py
+         v1pnpy = 4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpy = vnpy + v1pnpy
+
+         vnpz = pz
+         v1pnpz = 4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
+         vpz = vnpz + v1pnpz
+
+         vnx = x/(x**2 + y**2 + z**2)**(3/2)
+         v1pnx = (x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) - px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vx = vnx + v1pnx
+
+         vny = y/(x**2 + y**2 + z**2)**(3/2)
+         v1pny = (y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) - py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vy = vny + v1pny
+
+         vnz = z/(x**2 + y**2 + z**2)**(3/2)
+         v1pnz = (z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) - (4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) - pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 + z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
+         vz = vnz + v1pnz
+
+         return vx, vy, vz, vpx, vpy, vpz, e
+
+     def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
+
+         x = (x + xa) / 2
+         y = (y + ya) / 2
+         z = (z + za) / 2
+
+         px = (px + pxa) / 2
+         py = (py + pya) / 2
+         pz = (pz + pza) / 2
+         xa = x
+         ya = y
+         za = z
+         pxa = px
+         pya = py
+         pza = pz
+
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
+
+     def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
+
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
+         xa = xa + h * vpxa
+         ya = ya + h * vpya
+         za = za + h * vpza
+         px = px - h * vx
+         py = py - h * vy
+         pz = pz - h * vz
+
+         vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
+         x = x + h / 2 * vpx
+         y = y + h / 2 * vpy
+         z = z + h / 2 * vpz
+         pxa = pxa - h / 2 * vxa
+         pya = pya - h / 2 * vya
+         pza = pza - h / 2 * vza
+
+         return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
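MidPoint's initial, rejust, and symplectic bodies are functionally identical to LowLevel's; the averaging in rejust is what gives it its midpoint character. Since f also returns the energy e, long-run conservation is easy to monitor (a sketch; the step count is an arbitrary choice):

    from utils.midpoint import MidPoint

    mp = MidPoint(j=1)
    _, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = mp.initial()
    *_, e0 = mp.f(x, y, z, px, py, pz, b)  # initial energy
    for _ in range(1000):
        (x, y, z, px, py, pz,
         xa, ya, za, pxa, pya, pza) = mp.symplectic(h, x, y, z, px, py, pz,
                                                    xa, ya, za, pxa, pya, pza, b)
    *_, e1 = mp.f(x, y, z, px, py, pz, b)  # energy after 1000 steps
    print("relative energy drift:", abs((e1 - e0) / e0))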
utils/transform.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+
+ def compute_gradient(x, degree):
+     # stack x with its successive differences (up to `degree` orders) along the last axis
+     gradients = [x]
+     for i in range(degree):
+         x = torch.diff(x, dim=-1, prepend=x[..., 0:1])  # prepend keeps the length unchanged
+         gradients.append(x)
+     return torch.cat(gradients, dim=-1)
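compute_gradient concatenates the input with its first- and higher-order finite differences along the last axis, so the 7 raw data columns become the 21 input channels the LSTM and TCN expect (7 x (degree + 1) with degree=2). For example:

    import torch
    from utils.transform import compute_gradient

    x = torch.randn(5, 7)                # 5 time steps, 7 features
    out = compute_gradient(x, degree=2)
    print(out.shape)                     # torch.Size([5, 21]): x, then its 1st and 2nd differences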