azrai99 commited on
Commit
d511c44
·
verified ·
1 Parent(s): c0809e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -66
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
 
4
  from neuralforecast.core import NeuralForecast
5
  from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
6
  from neuralforecast.losses.pytorch import HuberMQLoss
@@ -9,35 +10,35 @@ import time
9
 
10
  # Paths for saving models
11
  nhits_paths = {
12
- 'D': './M4/NHITS/daily',
13
- 'M': './M4/NHITS/monthly',
14
- 'H': './M4/NHITS/hourly',
15
- 'W': './M4/NHITS/weekly',
16
- 'Y': './M4/NHITS/yearly'
17
  }
18
 
19
  timesnet_paths = {
20
- 'D': './M4/TimesNet/daily',
21
- 'M': './M4/TimesNet/monthly',
22
- 'H': './M4/TimesNet/hourly',
23
- 'W': './M4/TimesNet/weekly',
24
- 'Y': './M4/TimesNet/yearly'
25
  }
26
 
27
  lstm_paths = {
28
- 'D': './M4/LSTM/daily',
29
- 'M': './M4/LSTM/monthly',
30
- 'H': './M4/LSTM/hourly',
31
- 'W': './M4/LSTM/weekly',
32
- 'Y': './M4/LSTM/yearly'
33
  }
34
 
35
  tft_paths = {
36
- 'D': './M4/TFT/daily',
37
- 'M': './M4/TFT/monthly',
38
- 'H': './M4/TFT/hourly',
39
- 'W': './M4/TFT/weekly',
40
- 'Y': './M4/TFT/yearly'
41
  }
42
 
43
  @st.cache_resource
@@ -164,7 +165,7 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
164
  model = select_model(horizon, model_type, max_steps)
165
  forecast_results = {}
166
  st.write(f"Generating forecast using {model_type} model...")
167
- forecast_results[model_type] = generate_forecast(model, df, freq)
168
 
169
  for model_name, forecast_df in forecast_results.items():
170
  plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
@@ -173,49 +174,82 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
173
  time_taken = end_time - start_time
174
  st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
175
 
176
- # Streamlit App
177
- st.title("Dynamic and Automatic Time Series Forecasting")
178
-
179
- # Upload dataset
180
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
181
- if uploaded_file:
182
- df = pd.read_csv(uploaded_file)
183
- else:
184
- st.warning("Using default data")
185
- df = AirPassengersDF.copy()
186
-
187
- # Model selection and forecasting
188
- st.subheader("Transfer Learning Forecasting")
189
- model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
190
- horizon = st.slider("Forecast horizon", 1, 100, 10)
191
-
192
- # Determine frequency of data
193
- frequency = determine_frequency(df)
194
- st.write(f"Detected frequency: {frequency}")
195
-
196
- # Load pre-trained models
197
- nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
198
- forecast_results = {}
199
-
200
- start_time = time.time() # Start timing
201
- if model_choice == "NHITS":
202
- forecast_results['NHITS'] = generate_forecast(nhits_model, df)
203
- elif model_choice == "TimesNet":
204
- forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
205
- elif model_choice == "LSTM":
206
- forecast_results['LSTM'] = generate_forecast(lstm_model, df)
207
- elif model_choice == "TFT":
208
- forecast_results['TFT'] = generate_forecast(tft_model, df)
209
-
210
- for model_name, forecast_df in forecast_results.items():
211
- plot_forecasts(forecast_df, df, f'{model_name} Forecast')
212
 
213
- end_time = time.time() # End timing
214
- time_taken = end_time - start_time
215
- st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
216
-
217
- # Dynamic forecasting
218
- st.subheader("Dynamic Forecasting")
219
- dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
220
- dynamic_horizon = st.slider("Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon")
221
- forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ import pytorch_lightning as pl
5
  from neuralforecast.core import NeuralForecast
6
  from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
7
  from neuralforecast.losses.pytorch import HuberMQLoss
 
10
 
11
  # Paths for saving models
12
  nhits_paths = {
13
+ 'D': './results/M4/NHITS/daily',
14
+ 'M': './results/M4/NHITS/monthly',
15
+ 'H': './results/M4/NHITS/hourly',
16
+ 'W': './results/M4/NHITS/weekly',
17
+ 'Y': './results/M4/NHITS/yearly'
18
  }
19
 
20
  timesnet_paths = {
21
+ 'D': './results/M4/TimesNet/daily',
22
+ 'M': './results/M4/TimesNet/monthly',
23
+ 'H': './results/M4/TimesNet/hourly',
24
+ 'W': './results/M4/TimesNet/weekly',
25
+ 'Y': './results/M4/TimesNet/yearly'
26
  }
27
 
28
  lstm_paths = {
29
+ 'D': './results/M4/LSTM/daily',
30
+ 'M': './results/M4/LSTM/monthly',
31
+ 'H': './results/M4/LSTM/hourly',
32
+ 'W': './results/M4/LSTM/weekly',
33
+ 'Y': './results/M4/LSTM/yearly'
34
  }
35
 
36
  tft_paths = {
37
+ 'D': './results/M4/TFT/daily',
38
+ 'M': './results/M4/TFT/monthly',
39
+ 'H': './results/M4/TFT/hourly',
40
+ 'W': './results/M4/TFT/weekly',
41
+ 'Y': './results/M4/TFT/yearly'
42
  }
43
 
44
  @st.cache_resource
 
165
  model = select_model(horizon, model_type, max_steps)
166
  forecast_results = {}
167
  st.write(f"Generating forecast using {model_type} model...")
168
+ forecast_results[model_type] = generate_forecast(model, df)
169
 
170
  for model_name, forecast_df in forecast_results.items():
171
  plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
 
174
  time_taken = end_time - start_time
175
  st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
176
 
177
+ @st.cache_data
178
+ def load_default():
179
+ df = AirPassengersDf.copy()
180
+ return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ def transfer_learning_forecasting():
183
+ st.title("Transfer Learning Forecasting")
184
+
185
+ # Upload dataset
186
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
187
+ if uploaded_file:
188
+ df = pd.read_csv(uploaded_file)
189
+ else:
190
+ df = load_default()
191
+
192
+ # Model selection and forecasting
193
+ st.subheader("Model Selection and Forecasting")
194
+ model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
195
+ horizon = st.slider("Forecast horizon", 1, 100, 10)
196
+
197
+ # Determine frequency of data
198
+ frequency = determine_frequency(df)
199
+ st.write(f"Detected frequency: {frequency}")
200
+
201
+ # Load pre-trained models
202
+ nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
203
+ forecast_results = {}
204
+
205
+ start_time = time.time() # Start timing
206
+ if model_choice == "NHITS":
207
+ forecast_results['NHITS'] = generate_forecast(nhits_model, df)
208
+ elif model_choice == "TimesNet":
209
+ forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
210
+ elif model_choice == "LSTM":
211
+ forecast_results['LSTM'] = generate_forecast(lstm_model, df)
212
+ elif model_choice == "TFT":
213
+ forecast_results['TFT'] = generate_forecast(tft_model, df)
214
+
215
+ for model_name, forecast_df in forecast_results.items():
216
+ plot_forecasts(forecast_df, df, f'{model_name} Forecast')
217
+
218
+ end_time = time.time() # End timing
219
+ time_taken = end_time - start_time
220
+ st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
221
+
222
+ def dynamic_forecasting():
223
+ st.title("Dynamic Forecasting")
224
+
225
+ # Upload dataset
226
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
227
+ if uploaded_file:
228
+ df = pd.read_csv(uploaded_file)
229
+ else:
230
+ df = load_default()
231
+
232
+ # Dynamic forecasting
233
+ st.subheader("Dynamic Model Selection and Forecasting")
234
+ dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
235
+ dynamic_horizon = st.slider("Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon")
236
+
237
+ # Determine frequency of data
238
+ frequency = determine_frequency(df)
239
+ st.write(f"Detected frequency: {frequency}")
240
+
241
+ forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon
242
+
243
+ # Define the main navigation
244
+ pg = st.navigation({
245
+ "Overview": [
246
+ # Load pages from functions
247
+ st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/library_books:"),
248
+ st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/person:"),
249
+ ]
250
+ })
251
+
252
+ try:
253
+ pg.run()
254
+ except Exception as e:
255
+ st.error(f"Something went wrong: {str(e)}", icon=":material/error:")