Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import matplotlib.pyplot as plt
|
|
|
4 |
from neuralforecast.core import NeuralForecast
|
5 |
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
6 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
@@ -9,35 +10,35 @@ import time
|
|
9 |
|
10 |
# Paths for saving models
|
11 |
nhits_paths = {
|
12 |
-
'D': './M4/NHITS/daily',
|
13 |
-
'M': './M4/NHITS/monthly',
|
14 |
-
'H': './M4/NHITS/hourly',
|
15 |
-
'W': './M4/NHITS/weekly',
|
16 |
-
'Y': './M4/NHITS/yearly'
|
17 |
}
|
18 |
|
19 |
timesnet_paths = {
|
20 |
-
'D': './M4/TimesNet/daily',
|
21 |
-
'M': './M4/TimesNet/monthly',
|
22 |
-
'H': './M4/TimesNet/hourly',
|
23 |
-
'W': './M4/TimesNet/weekly',
|
24 |
-
'Y': './M4/TimesNet/yearly'
|
25 |
}
|
26 |
|
27 |
lstm_paths = {
|
28 |
-
'D': './M4/LSTM/daily',
|
29 |
-
'M': './M4/LSTM/monthly',
|
30 |
-
'H': './M4/LSTM/hourly',
|
31 |
-
'W': './M4/LSTM/weekly',
|
32 |
-
'Y': './M4/LSTM/yearly'
|
33 |
}
|
34 |
|
35 |
tft_paths = {
|
36 |
-
'D': './M4/TFT/daily',
|
37 |
-
'M': './M4/TFT/monthly',
|
38 |
-
'H': './M4/TFT/hourly',
|
39 |
-
'W': './M4/TFT/weekly',
|
40 |
-
'Y': './M4/TFT/yearly'
|
41 |
}
|
42 |
|
43 |
@st.cache_resource
|
@@ -164,7 +165,7 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
|
|
164 |
model = select_model(horizon, model_type, max_steps)
|
165 |
forecast_results = {}
|
166 |
st.write(f"Generating forecast using {model_type} model...")
|
167 |
-
forecast_results[model_type] = generate_forecast(model, df
|
168 |
|
169 |
for model_name, forecast_df in forecast_results.items():
|
170 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|
@@ -173,49 +174,82 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
|
|
173 |
time_taken = end_time - start_time
|
174 |
st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
181 |
-
if uploaded_file:
|
182 |
-
df = pd.read_csv(uploaded_file)
|
183 |
-
else:
|
184 |
-
st.warning("Using default data")
|
185 |
-
df = AirPassengersDF.copy()
|
186 |
-
|
187 |
-
# Model selection and forecasting
|
188 |
-
st.subheader("Transfer Learning Forecasting")
|
189 |
-
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
190 |
-
horizon = st.slider("Forecast horizon", 1, 100, 10)
|
191 |
-
|
192 |
-
# Determine frequency of data
|
193 |
-
frequency = determine_frequency(df)
|
194 |
-
st.write(f"Detected frequency: {frequency}")
|
195 |
-
|
196 |
-
# Load pre-trained models
|
197 |
-
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
198 |
-
forecast_results = {}
|
199 |
-
|
200 |
-
start_time = time.time() # Start timing
|
201 |
-
if model_choice == "NHITS":
|
202 |
-
forecast_results['NHITS'] = generate_forecast(nhits_model, df)
|
203 |
-
elif model_choice == "TimesNet":
|
204 |
-
forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
|
205 |
-
elif model_choice == "LSTM":
|
206 |
-
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
207 |
-
elif model_choice == "TFT":
|
208 |
-
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
209 |
-
|
210 |
-
for model_name, forecast_df in forecast_results.items():
|
211 |
-
plot_forecasts(forecast_df, df, f'{model_name} Forecast')
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import matplotlib.pyplot as plt
|
4 |
+
import pytorch_lightning as pl
|
5 |
from neuralforecast.core import NeuralForecast
|
6 |
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
|
|
10 |
|
11 |
# Paths for saving models
|
12 |
nhits_paths = {
|
13 |
+
'D': './results/M4/NHITS/daily',
|
14 |
+
'M': './results/M4/NHITS/monthly',
|
15 |
+
'H': './results/M4/NHITS/hourly',
|
16 |
+
'W': './results/M4/NHITS/weekly',
|
17 |
+
'Y': './results/M4/NHITS/yearly'
|
18 |
}
|
19 |
|
20 |
timesnet_paths = {
|
21 |
+
'D': './results/M4/TimesNet/daily',
|
22 |
+
'M': './results/M4/TimesNet/monthly',
|
23 |
+
'H': './results/M4/TimesNet/hourly',
|
24 |
+
'W': './results/M4/TimesNet/weekly',
|
25 |
+
'Y': './results/M4/TimesNet/yearly'
|
26 |
}
|
27 |
|
28 |
lstm_paths = {
|
29 |
+
'D': './results/M4/LSTM/daily',
|
30 |
+
'M': './results/M4/LSTM/monthly',
|
31 |
+
'H': './results/M4/LSTM/hourly',
|
32 |
+
'W': './results/M4/LSTM/weekly',
|
33 |
+
'Y': './results/M4/LSTM/yearly'
|
34 |
}
|
35 |
|
36 |
tft_paths = {
|
37 |
+
'D': './results/M4/TFT/daily',
|
38 |
+
'M': './results/M4/TFT/monthly',
|
39 |
+
'H': './results/M4/TFT/hourly',
|
40 |
+
'W': './results/M4/TFT/weekly',
|
41 |
+
'Y': './results/M4/TFT/yearly'
|
42 |
}
|
43 |
|
44 |
@st.cache_resource
|
|
|
165 |
model = select_model(horizon, model_type, max_steps)
|
166 |
forecast_results = {}
|
167 |
st.write(f"Generating forecast using {model_type} model...")
|
168 |
+
forecast_results[model_type] = generate_forecast(model, df)
|
169 |
|
170 |
for model_name, forecast_df in forecast_results.items():
|
171 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|
|
|
174 |
time_taken = end_time - start_time
|
175 |
st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
|
176 |
|
177 |
+
@st.cache_data
|
178 |
+
def load_default():
|
179 |
+
df = AirPassengersDf.copy()
|
180 |
+
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
+
def transfer_learning_forecasting():
|
183 |
+
st.title("Transfer Learning Forecasting")
|
184 |
+
|
185 |
+
# Upload dataset
|
186 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
187 |
+
if uploaded_file:
|
188 |
+
df = pd.read_csv(uploaded_file)
|
189 |
+
else:
|
190 |
+
df = load_default()
|
191 |
+
|
192 |
+
# Model selection and forecasting
|
193 |
+
st.subheader("Model Selection and Forecasting")
|
194 |
+
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
195 |
+
horizon = st.slider("Forecast horizon", 1, 100, 10)
|
196 |
+
|
197 |
+
# Determine frequency of data
|
198 |
+
frequency = determine_frequency(df)
|
199 |
+
st.write(f"Detected frequency: {frequency}")
|
200 |
+
|
201 |
+
# Load pre-trained models
|
202 |
+
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
203 |
+
forecast_results = {}
|
204 |
+
|
205 |
+
start_time = time.time() # Start timing
|
206 |
+
if model_choice == "NHITS":
|
207 |
+
forecast_results['NHITS'] = generate_forecast(nhits_model, df)
|
208 |
+
elif model_choice == "TimesNet":
|
209 |
+
forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
|
210 |
+
elif model_choice == "LSTM":
|
211 |
+
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
212 |
+
elif model_choice == "TFT":
|
213 |
+
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
214 |
+
|
215 |
+
for model_name, forecast_df in forecast_results.items():
|
216 |
+
plot_forecasts(forecast_df, df, f'{model_name} Forecast')
|
217 |
+
|
218 |
+
end_time = time.time() # End timing
|
219 |
+
time_taken = end_time - start_time
|
220 |
+
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
221 |
+
|
222 |
+
def dynamic_forecasting():
|
223 |
+
st.title("Dynamic Forecasting")
|
224 |
+
|
225 |
+
# Upload dataset
|
226 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
227 |
+
if uploaded_file:
|
228 |
+
df = pd.read_csv(uploaded_file)
|
229 |
+
else:
|
230 |
+
df = load_default()
|
231 |
+
|
232 |
+
# Dynamic forecasting
|
233 |
+
st.subheader("Dynamic Model Selection and Forecasting")
|
234 |
+
dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
235 |
+
dynamic_horizon = st.slider("Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon")
|
236 |
+
|
237 |
+
# Determine frequency of data
|
238 |
+
frequency = determine_frequency(df)
|
239 |
+
st.write(f"Detected frequency: {frequency}")
|
240 |
+
|
241 |
+
forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon
|
242 |
+
|
243 |
+
# Define the main navigation
|
244 |
+
pg = st.navigation({
|
245 |
+
"Overview": [
|
246 |
+
# Load pages from functions
|
247 |
+
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/library_books:"),
|
248 |
+
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/person:"),
|
249 |
+
]
|
250 |
+
})
|
251 |
+
|
252 |
+
try:
|
253 |
+
pg.run()
|
254 |
+
except Exception as e:
|
255 |
+
st.error(f"Something went wrong: {str(e)}", icon=":material/error:")
|