azrai99 commited on
Commit
340176f
·
verified ·
1 Parent(s): 48a5b9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -35
app.py CHANGED
@@ -8,48 +8,53 @@ from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
 
11
- # Paths for saving models
12
- nhits_paths = {
13
- 'D': './M4/NHITS/daily',
14
- 'M': './M4/NHITS/monthly',
15
- 'H': './M4/NHITS/hourly',
16
- 'W': './M4/NHITS/weekly',
17
- 'Y': './M4/NHITS/yearly'
18
- }
19
 
20
- timesnet_paths = {
21
- 'D': './M4/TimesNet/daily',
22
- 'M': './M4/TimesNet/monthly',
23
- 'H': './M4/TimesNet/hourly',
24
- 'W': './M4/TimesNet/weekly',
25
- 'Y': './M4/TimesNet/yearly'
26
- }
27
-
28
- lstm_paths = {
29
- 'D': './M4/LSTM/daily',
30
- 'M': './M4/LSTM/monthly',
31
- 'H': './M4/LSTM/hourly',
32
- 'W': './M4/LSTM/weekly',
33
- 'Y': './M4/LSTM/yearly'
34
- }
35
-
36
- tft_paths = {
37
- 'D': './M4/TFT/daily',
38
- 'M': './M4/TFT/monthly',
39
- 'H': './M4/TFT/hourly',
40
- 'W': './M4/TFT/weekly',
41
- 'Y': './M4/TFT/yearly'
42
- }
43
 
44
  @st.cache_resource
45
  def load_model(path, freq):
46
  nf = NeuralForecast.load(path=path)
47
  return nf
48
 
49
- nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
50
- timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
51
- lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
52
- tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def generate_forecast(model, df):
55
  forecast_df = model.predict(df=df)
@@ -189,6 +194,8 @@ def transfer_learning_forecasting():
189
  else:
190
  df = load_default()
191
 
 
 
192
  # Model selection and forecasting
193
  st.subheader("Model Selection and Forecasting")
194
  model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
@@ -251,6 +258,7 @@ pg = st.navigation({
251
  ]
252
  })
253
 
 
254
  try:
255
  pg.run()
256
  except Exception as e:
 
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @st.cache_resource
14
  def load_model(path, freq):
15
  nf = NeuralForecast.load(path=path)
16
  return nf
17
 
18
+ @st.cache_resource
19
+ def load_all_models():
20
+ # Paths for saving models
21
+ nhits_paths = {
22
+ 'D': './M4/NHITS/daily',
23
+ 'M': './M4/NHITS/monthly',
24
+ 'H': './M4/NHITS/hourly',
25
+ 'W': './M4/NHITS/weekly',
26
+ 'Y': './M4/NHITS/yearly'
27
+ }
28
+
29
+ timesnet_paths = {
30
+ 'D': './M4/TimesNet/daily',
31
+ 'M': './M4/TimesNet/monthly',
32
+ 'H': './M4/TimesNet/hourly',
33
+ 'W': './M4/TimesNet/weekly',
34
+ 'Y': './M4/TimesNet/yearly'
35
+ }
36
+
37
+ lstm_paths = {
38
+ 'D': './M4/LSTM/daily',
39
+ 'M': './M4/LSTM/monthly',
40
+ 'H': './M4/LSTM/hourly',
41
+ 'W': './M4/LSTM/weekly',
42
+ 'Y': './M4/LSTM/yearly'
43
+ }
44
+
45
+ tft_paths = {
46
+ 'D': './M4/TFT/daily',
47
+ 'M': './M4/TFT/monthly',
48
+ 'H': './M4/TFT/hourly',
49
+ 'W': './M4/TFT/weekly',
50
+ 'Y': './M4/TFT/yearly'
51
+ }
52
+ nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
53
+ timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
54
+ lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
55
+ tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
56
+
57
+ return nhits_models, timesnet_models, lstm_models, tft_models
58
 
59
  def generate_forecast(model, df):
60
  forecast_df = model.predict(df=df)
 
194
  else:
195
  df = load_default()
196
 
197
+ nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
198
+
199
  # Model selection and forecasting
200
  st.subheader("Model Selection and Forecasting")
201
  model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
 
258
  ]
259
  })
260
 
261
+
262
  try:
263
  pg.run()
264
  except Exception as e: