Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,48 +8,53 @@ from neuralforecast.losses.pytorch import HuberMQLoss
|
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
10 |
|
11 |
-
# Paths for saving models
|
12 |
-
nhits_paths = {
|
13 |
-
'D': './M4/NHITS/daily',
|
14 |
-
'M': './M4/NHITS/monthly',
|
15 |
-
'H': './M4/NHITS/hourly',
|
16 |
-
'W': './M4/NHITS/weekly',
|
17 |
-
'Y': './M4/NHITS/yearly'
|
18 |
-
}
|
19 |
|
20 |
-
timesnet_paths = {
|
21 |
-
'D': './M4/TimesNet/daily',
|
22 |
-
'M': './M4/TimesNet/monthly',
|
23 |
-
'H': './M4/TimesNet/hourly',
|
24 |
-
'W': './M4/TimesNet/weekly',
|
25 |
-
'Y': './M4/TimesNet/yearly'
|
26 |
-
}
|
27 |
-
|
28 |
-
lstm_paths = {
|
29 |
-
'D': './M4/LSTM/daily',
|
30 |
-
'M': './M4/LSTM/monthly',
|
31 |
-
'H': './M4/LSTM/hourly',
|
32 |
-
'W': './M4/LSTM/weekly',
|
33 |
-
'Y': './M4/LSTM/yearly'
|
34 |
-
}
|
35 |
-
|
36 |
-
tft_paths = {
|
37 |
-
'D': './M4/TFT/daily',
|
38 |
-
'M': './M4/TFT/monthly',
|
39 |
-
'H': './M4/TFT/hourly',
|
40 |
-
'W': './M4/TFT/weekly',
|
41 |
-
'Y': './M4/TFT/yearly'
|
42 |
-
}
|
43 |
|
44 |
@st.cache_resource
|
45 |
def load_model(path, freq):
|
46 |
nf = NeuralForecast.load(path=path)
|
47 |
return nf
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def generate_forecast(model, df):
|
55 |
forecast_df = model.predict(df=df)
|
@@ -189,6 +194,8 @@ def transfer_learning_forecasting():
|
|
189 |
else:
|
190 |
df = load_default()
|
191 |
|
|
|
|
|
192 |
# Model selection and forecasting
|
193 |
st.subheader("Model Selection and Forecasting")
|
194 |
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
@@ -251,6 +258,7 @@ pg = st.navigation({
|
|
251 |
]
|
252 |
})
|
253 |
|
|
|
254 |
try:
|
255 |
pg.run()
|
256 |
except Exception as e:
|
|
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
@st.cache_resource
|
14 |
def load_model(path, freq):
|
15 |
nf = NeuralForecast.load(path=path)
|
16 |
return nf
|
17 |
|
18 |
+
@st.cache_resource
|
19 |
+
def load_all_models():
|
20 |
+
# Paths for saving models
|
21 |
+
nhits_paths = {
|
22 |
+
'D': './M4/NHITS/daily',
|
23 |
+
'M': './M4/NHITS/monthly',
|
24 |
+
'H': './M4/NHITS/hourly',
|
25 |
+
'W': './M4/NHITS/weekly',
|
26 |
+
'Y': './M4/NHITS/yearly'
|
27 |
+
}
|
28 |
+
|
29 |
+
timesnet_paths = {
|
30 |
+
'D': './M4/TimesNet/daily',
|
31 |
+
'M': './M4/TimesNet/monthly',
|
32 |
+
'H': './M4/TimesNet/hourly',
|
33 |
+
'W': './M4/TimesNet/weekly',
|
34 |
+
'Y': './M4/TimesNet/yearly'
|
35 |
+
}
|
36 |
+
|
37 |
+
lstm_paths = {
|
38 |
+
'D': './M4/LSTM/daily',
|
39 |
+
'M': './M4/LSTM/monthly',
|
40 |
+
'H': './M4/LSTM/hourly',
|
41 |
+
'W': './M4/LSTM/weekly',
|
42 |
+
'Y': './M4/LSTM/yearly'
|
43 |
+
}
|
44 |
+
|
45 |
+
tft_paths = {
|
46 |
+
'D': './M4/TFT/daily',
|
47 |
+
'M': './M4/TFT/monthly',
|
48 |
+
'H': './M4/TFT/hourly',
|
49 |
+
'W': './M4/TFT/weekly',
|
50 |
+
'Y': './M4/TFT/yearly'
|
51 |
+
}
|
52 |
+
nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
|
53 |
+
timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
|
54 |
+
lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
|
55 |
+
tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
|
56 |
+
|
57 |
+
return nhits_models, timesnet_models, lstm_models, tft_models
|
58 |
|
59 |
def generate_forecast(model, df):
|
60 |
forecast_df = model.predict(df=df)
|
|
|
194 |
else:
|
195 |
df = load_default()
|
196 |
|
197 |
+
nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
|
198 |
+
|
199 |
# Model selection and forecasting
|
200 |
st.subheader("Model Selection and Forecasting")
|
201 |
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
|
|
258 |
]
|
259 |
})
|
260 |
|
261 |
+
|
262 |
try:
|
263 |
pg.run()
|
264 |
except Exception as e:
|