azrai99 committed on
Commit
700a890
·
verified ·
1 Parent(s): 8b13fd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -583
app.py CHANGED
@@ -1,670 +1,130 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
- import pytorch_lightning as pl
5
  from neuralforecast.core import NeuralForecast
6
  from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
7
  from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
- import time
10
- from st_aggrid import AgGrid
11
- from nixtla import NixtlaClient
12
- import os
13
 
14
  st.set_page_config(layout='wide')
15
-
16
- @st.cache_resource
17
- def load_model(path, freq):
18
- nf = NeuralForecast.load(path=path)
19
- return nf
20
-
21
- @st.cache_resource
22
- def load_all_models():
23
- nhits_paths = {
24
- 'D': './M4/NHITS/daily',
25
- 'M': './M4/NHITS/monthly',
26
- 'H': './M4/NHITS/hourly',
27
- 'W': './M4/NHITS/weekly',
28
- 'Y': './M4/NHITS/yearly'
29
- }
30
-
31
- timesnet_paths = {
32
- 'D': './M4/TimesNet/daily',
33
- 'M': './M4/TimesNet/monthly',
34
- 'H': './M4/TimesNet/hourly',
35
- 'W': './M4/TimesNet/weekly',
36
- 'Y': './M4/TimesNet/yearly'
37
- }
38
-
39
- lstm_paths = {
40
- 'D': './M4/LSTM/daily',
41
- 'M': './M4/LSTM/monthly',
42
- 'H': './M4/LSTM/hourly',
43
- 'W': './M4/LSTM/weekly',
44
- 'Y': './M4/LSTM/yearly'
45
- }
46
-
47
- tft_paths = {
48
- 'D': './M4/TFT/daily',
49
- 'M': './M4/TFT/monthly',
50
- 'H': './M4/TFT/hourly',
51
- 'W': './M4/TFT/weekly',
52
- 'Y': './M4/TFT/yearly'
53
- }
54
- nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
55
- timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
56
- lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
57
- tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}
58
 
59
- return nhits_models, timesnet_models, lstm_models, tft_models
60
-
61
- def generate_forecast(model, df,tag=False):
62
  if tag == 'retrain':
63
- forecast_df = model.predict()
64
- else:
65
- forecast_df = model.predict(df=df)
66
- return forecast_df
67
 
68
  def determine_frequency(df):
69
  df['ds'] = pd.to_datetime(df['ds'])
70
- df = df.drop_duplicates(subset='ds')
71
- df = df.set_index('ds')
72
-
73
- # # Create a complete date range
74
- # full_range = pd.date_range(start=df.index.min(), end=df.index.max(),freq=freq)
75
-
76
- # # Reindex the DataFrame to this full date range
77
- # df_full = df.reindex(full_range)
78
-
79
- # Infer the frequency
80
- # freq = pd.infer_freq(df_full.index)
81
-
82
  freq = pd.infer_freq(df.index)
83
  if not freq:
84
- st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.',icon="⚠️")
85
  freq = 'D'
86
-
87
  return freq
88
 
89
- def plot_forecasts_matplotlib(forecast_df, train_df, title):
90
- fig, ax = plt.subplots(1, 1, figsize=(20, 7))
91
- plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
92
- historical_col = 'y'
93
- forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
94
- lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
95
- hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
96
- if forecast_col is None:
97
- raise KeyError("No forecast column found in the data.")
98
- plot_df[[historical_col, forecast_col]].plot(ax=ax, linewidth=2, label=['Historical', 'Forecast'])
99
- if lo_col and hi_col:
100
- ax.fill_between(
101
- plot_df.index,
102
- plot_df[lo_col],
103
- plot_df[hi_col],
104
- color='blue',
105
- alpha=0.3,
106
- label='90% Confidence Interval'
107
- )
108
- ax.set_title(title, fontsize=22)
109
- ax.set_ylabel('Value', fontsize=20)
110
- ax.set_xlabel('Timestamp [t]', fontsize=20)
111
- ax.legend(prop={'size': 15})
112
- ax.grid()
113
- st.pyplot(fig)
114
-
115
- import plotly.graph_objects as go
116
-
117
  def plot_forecasts(forecast_df, train_df, title):
118
- # Combine historical and forecast data
119
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
120
-
121
- # Find relevant columns
122
  historical_col = 'y'
123
  forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
124
  lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
125
  hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
126
-
127
  if forecast_col is None:
128
  raise KeyError("No forecast column found in the data.")
129
-
130
- # Create Plotly figure
131
  fig = go.Figure()
132
-
133
- # Add historical data
134
  fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
135
-
136
- # Add forecast data
137
  fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
138
-
139
- # Add confidence interval if available
140
- if lo_col and hi_col:
141
- fig.add_trace(go.Scatter(
142
- x=plot_df.index,
143
- y=plot_df[hi_col],
144
- mode='lines',
145
- line=dict(color='rgba(0,100,80,0.2)'),
146
- showlegend=False
147
- ))
148
- fig.add_trace(go.Scatter(
149
- x=plot_df.index,
150
- y=plot_df[lo_col],
151
- mode='lines',
152
- line=dict(color='rgba(0,100,80,0.2)'),
153
- fill='tonexty',
154
- fillcolor='rgba(0,100,80,0.2)',
155
- name='90% Confidence Interval'
156
- ))
157
-
158
- # Update layout
159
- fig.update_layout(
160
- title=title,
161
- xaxis_title='Timestamp [t]',
162
- yaxis_title='Value',
163
- template='plotly_white'
164
- )
165
-
166
- # Display the plot
167
- st.plotly_chart(fig)
168
 
 
 
 
169
 
170
- def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
171
- if freq == 'D':
172
- return nhits_models['D'], timesnet_models['D'], lstm_models['D'], tft_models['D']
173
- elif freq == 'ME':
174
- return nhits_models['M'], timesnet_models['M'], lstm_models['M'], tft_models['M']
175
- elif freq == 'H':
176
- return nhits_models['H'], timesnet_models['H'], lstm_models['H'], tft_models['H']
177
- elif freq in ['W', 'W-SUN']:
178
- return nhits_models['W'], timesnet_models['W'], lstm_models['W'], tft_models['W']
179
- elif freq in ['Y', 'Y-DEC']:
180
- return nhits_models['Y'], timesnet_models['Y'], lstm_models['Y'], tft_models['Y']
181
- else:
182
- raise ValueError(f"Unsupported frequency: {freq}")
183
 
184
  def select_model(horizon, model_type, max_steps=50):
185
  if model_type == 'NHITS':
186
- return NHITS(input_size=5 * horizon,
187
- h=horizon,
188
- max_steps=max_steps,
189
- stack_types=3*['identity'],
190
- n_blocks=3*[1],
191
- mlp_units=[[256, 256] for _ in range(3)],
192
- n_pool_kernel_size=3*[1],
193
- batch_size=32,
194
- scaler_type='standard',
195
- n_freq_downsample=[12, 4, 1],
196
- loss=HuberMQLoss(level=[90]))
197
  elif model_type == 'TimesNet':
198
- return TimesNet(h=horizon,
199
- input_size=horizon * 5,
200
- hidden_size=32,
201
- conv_hidden_size=64,
202
- loss=HuberMQLoss(level=[90]),
203
- scaler_type='standard',
204
- learning_rate=1e-3,
205
- max_steps=max_steps,
206
- val_check_steps=200,
207
- valid_batch_size=64,
208
- windows_batch_size=128,
209
- inference_windows_batch_size=512)
210
  elif model_type == 'LSTM':
211
- return LSTM(h=horizon,
212
- input_size=horizon * 5,
213
- loss=HuberMQLoss(level=[90]),
214
- scaler_type='standard',
215
- encoder_n_layers=3,
216
- encoder_hidden_size=256,
217
- context_size=10,
218
- decoder_hidden_size=256,
219
- decoder_layers=3,
220
- max_steps=max_steps)
221
  elif model_type == 'TFT':
222
- return TFT(h=horizon,
223
- input_size=horizon*5,
224
- hidden_size=96,
225
- loss=HuberMQLoss(level=[90]),
226
- learning_rate=0.005,
227
- scaler_type='standard',
228
- windows_batch_size=128,
229
- max_steps=max_steps,
230
- val_check_steps=200,
231
- valid_batch_size=64,
232
- enable_progress_bar=True)
233
  else:
234
  raise ValueError(f"Unsupported model type: {model_type}")
235
 
236
- def model_train(df,model, freq):
237
  nf = NeuralForecast(models=[model], freq=freq)
238
  df['ds'] = pd.to_datetime(df['ds'])
239
  nf.fit(df)
240
  return nf
241
 
242
- def forecast_time_series(df, model_type, horizon, max_steps,y_col):
243
- start_time = time.time() # Start timing
244
  freq = determine_frequency(df)
245
  st.sidebar.write(f"Data frequency: {freq}")
246
 
247
  selected_model = select_model(horizon, model_type, max_steps)
248
- st.spinner(f"Training {model_type} model...")
249
- model = model_train(df, selected_model,freq)
250
 
251
- forecast_results = {}
252
- forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
253
-
254
  st.session_state.forecast_results = forecast_results
255
-
256
  for model_name, forecast_df in forecast_results.items():
257
  plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
258
-
259
- end_time = time.time() # End timing
260
- time_taken = end_time - start_time
261
  st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
 
262
  if 'forecast_results' in st.session_state:
263
- forecast_results = st.session_state.forecast_results
264
-
265
- st.markdown('You can download Input and Forecast Data below')
266
- tab_insample, tab_forecast = st.tabs(
267
- ["Input data", "Forecast"]
268
- )
269
-
270
  with tab_insample:
271
  df_grid = df.drop(columns="unique_id")
272
  st.write(df_grid)
273
- # grid_table = AgGrid(
274
- # df_grid,
275
- # theme="alpine",
276
- # )
277
-
278
  with tab_forecast:
279
  if model_type in forecast_results:
280
  df_grid = forecast_results[model_type]
281
  st.write(df_grid)
282
- # grid_table = AgGrid(
283
- # df_grid,
284
- # theme="alpine",
285
- # )
286
 
287
  @st.cache_data
288
  def load_default():
289
- df = AirPassengersDF.copy()
290
- return df
291
-
292
- def transfer_learning_forecasting():
293
- st.title("Zero-shot Forecasting")
294
- st.markdown("""
295
- Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
296
- """)
297
-
298
- nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
299
-
300
- with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
301
- if 'uploaded_file' not in st.session_state:
302
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
303
- if uploaded_file:
304
- df = pd.read_csv(uploaded_file)
305
- st.session_state.df = df
306
- st.session_state.uploaded_file = uploaded_file
307
- else:
308
- df = load_default()
309
- st.session_state.df = df
310
- else:
311
- if st.checkbox("Upload a new file (CSV)"):
312
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
313
- if uploaded_file:
314
- df = pd.read_csv(uploaded_file)
315
- st.session_state.df = df
316
- st.session_state.uploaded_file = uploaded_file
317
- else:
318
- df = st.session_state.df
319
- else:
320
- df = st.session_state.df
321
-
322
- columns = df.columns.tolist()
323
- ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
324
- target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
325
- y_col = st.selectbox("Select Target column", options=target_columns, index=0)
326
-
327
- st.session_state.ds_col = ds_col
328
- st.session_state.y_col = y_col
329
-
330
- # Model selection and forecasting
331
- st.sidebar.subheader("Model Selection and Forecasting")
332
- model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
333
- horizon = st.sidebar.number_input("Forecast horizon", value=12)
334
-
335
- df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
336
- df['unique_id']=1
337
- df = df[['unique_id','ds','y']]
338
 
339
- # Determine frequency of data
340
- frequency = determine_frequency(df)
341
- st.sidebar.write(f"Detected frequency: {frequency}")
342
-
343
-
344
- nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
345
- forecast_results = {}
346
-
347
-
348
-
349
- if st.sidebar.button("Submit"):
350
- start_time = time.time() # Start timing
351
- if model_choice == "NHITS":
352
- forecast_results['NHITS'] = generate_forecast(nhits_model, df)
353
- elif model_choice == "TimesNet":
354
- forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
355
- elif model_choice == "LSTM":
356
- forecast_results['LSTM'] = generate_forecast(lstm_model, df)
357
- elif model_choice == "TFT":
358
- forecast_results['TFT'] = generate_forecast(tft_model, df)
359
-
360
- st.session_state.forecast_results = forecast_results
361
- for model_name, forecast_df in forecast_results.items():
362
- plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')
363
-
364
- end_time = time.time() # End timing
365
- time_taken = end_time - start_time
366
- st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
367
-
368
- if 'forecast_results' in st.session_state:
369
- forecast_results = st.session_state.forecast_results
370
-
371
- st.markdown('You can download Input and Forecast Data below')
372
- tab_insample, tab_forecast = st.tabs(
373
- ["Input data", "Forecast"]
374
- )
375
-
376
- with tab_insample:
377
- df_grid = df.drop(columns="unique_id")
378
- st.write(df_grid)
379
- # grid_table = AgGrid(
380
- # df_grid,
381
- # theme="alpine",
382
- # )
383
-
384
- with tab_forecast:
385
- if model_choice in forecast_results:
386
- df_grid = forecast_results[model_choice]
387
- st.write(df_grid)
388
- # grid_table = AgGrid(
389
- # df_grid,
390
- # theme="alpine",
391
- # )
392
-
393
-
394
- def dynamic_forecasting():
395
  st.title("Personalized Neural Forecasting")
396
- st.markdown("""
397
- Train time series forecasting model from scratch and provide forecasts/visualization by using various deep neural network-based model trained on user data.
398
-
399
- Forecasting speed depends on CPU/GPU availabilty.
400
- """)
401
-
402
  with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
403
- if 'uploaded_file' not in st.session_state:
404
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
405
- if uploaded_file:
406
- df = pd.read_csv(uploaded_file)
407
- st.session_state.df = df
408
- st.session_state.uploaded_file = uploaded_file
409
- else:
410
- df = load_default()
411
- st.session_state.df = df
412
- else:
413
- if st.checkbox("Upload a new file (CSV)"):
414
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
415
- if uploaded_file:
416
- df = pd.read_csv(uploaded_file)
417
- st.session_state.df = df
418
- st.session_state.uploaded_file = uploaded_file
419
- else:
420
- df = st.session_state.df
421
- else:
422
- df = st.session_state.df
423
 
424
  columns = df.columns.tolist()
425
  ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
426
- target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
427
  y_col = st.selectbox("Select Target column", options=target_columns, index=0)
428
 
429
- st.session_state.ds_col = ds_col
430
- st.session_state.y_col = y_col
431
 
432
- df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
433
-
434
- df['unique_id']=1
435
- df = df[['unique_id','ds','y']]
436
-
437
-
438
- # Dynamic forecasting
439
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
440
- dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
441
  dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=12)
442
  dynamic_max_steps = st.sidebar.number_input('Max steps', value=20)
443
 
444
  if st.sidebar.button("Submit"):
445
- with st.spinner('Training model. This may take few minutes...'):
446
- forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
447
-
448
- def timegpt_fcst():
449
- nixtla_token = os.environ.get("NIXTLA_API_KEY")
450
- nixtla_client = NixtlaClient(
451
- api_key = nixtla_token
452
- )
453
-
454
-
455
- st.title("TimeGPT Forecasting")
456
- st.markdown("""
457
- Instant time series forecasting and visualization by using the TimeGPT API provided by Nixtla.
458
- """)
459
- with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
460
- if 'uploaded_file' not in st.session_state:
461
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
462
- if uploaded_file:
463
- df = pd.read_csv(uploaded_file)
464
- st.session_state.df = df
465
- st.session_state.uploaded_file = uploaded_file
466
- else:
467
- df = load_default()
468
- st.session_state.df = df
469
- else:
470
- if st.checkbox("Upload a new file (CSV)"):
471
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
472
- if uploaded_file:
473
- df = pd.read_csv(uploaded_file)
474
- st.session_state.df = df
475
- st.session_state.uploaded_file = uploaded_file
476
- else:
477
- df = st.session_state.df
478
- else:
479
- df = st.session_state.df
480
-
481
- columns = df.columns.tolist()
482
- ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
483
- target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
484
- y_col = st.selectbox("Select Target column", options=target_columns, index=0)
485
- h = st.number_input("Forecast horizon", value=14)
486
-
487
- df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
488
-
489
-
490
- id_col = 'ts_test'
491
- df['unique_id']=id_col
492
- df = df[['unique_id','ds','y']]
493
-
494
-
495
- freq = determine_frequency(df)
496
-
497
- df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)
498
-
499
- plot_type = st.sidebar.selectbox("Select Visualization", ["Matplotlib", "Plotly"])
500
- if st.sidebar.button("Submit"):
501
- start_time = time.time()
502
- forecast_df = nixtla_client.forecast(
503
- df=df,
504
- h=h,
505
- freq=freq,
506
- level=[90]
507
- )
508
- st.session_state.forecast_df = forecast_df
509
-
510
-
511
- if 'forecast_df' in st.session_state:
512
- forecast_df = st.session_state.forecast_df
513
-
514
- if plot_type == "Matplotlib":
515
- # Convert the Plotly figure to a Matplotlib figure if needed
516
- # Note: You may need to handle this conversion depending on your specific use case
517
- # For now, this example assumes that you are using a Matplotlib figure
518
- fig = nixtla_client.plot(df, forecast_df, level=[90], engine='matplotlib')
519
- st.pyplot(fig)
520
- elif plot_type == "Plotly":
521
- # Plotly figure directly
522
- fig = nixtla_client.plot(df, forecast_df, level=[90], engine='plotly')
523
- st.plotly_chart(fig)
524
-
525
- end_time = time.time() # End timing
526
- time_taken = end_time - start_time
527
- st.success(f"Time taken for TimeGPT forecast: {time_taken:.2f} seconds")
528
-
529
- if 'forecast_df' in st.session_state:
530
- forecast_df = st.session_state.forecast_df
531
-
532
- st.markdown('You can download Input and Forecast Data below')
533
- tab_insample, tab_forecast = st.tabs(
534
- ["Input data", "Forecast"]
535
- )
536
-
537
- with tab_insample:
538
- df_grid = df.drop(columns="unique_id")
539
- st.write(df_grid)
540
- # grid_table = AgGrid(
541
- # df_grid,
542
- # theme="alpine",
543
- # )
544
-
545
- with tab_forecast:
546
- df_grid = forecast_df
547
- st.write(df_grid)
548
- # grid_table = AgGrid(
549
- # df_grid,
550
- # theme="alpine",
551
- # )
552
-
553
-
554
-
555
- def timegpt_anom():
556
- nixtla_token = os.environ.get("NIXTLA_API_KEY")
557
- nixtla_client = NixtlaClient(
558
- api_key = nixtla_token
559
- )
560
-
561
-
562
- st.title("TimeGPT Anomaly Detection")
563
- st.markdown("""
564
- Instant time series anomaly detection and visualization by using the TimeGPT API provided by Nixtla.
565
- """)
566
- with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
567
- if 'uploaded_file' not in st.session_state:
568
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
569
- if uploaded_file:
570
- df = pd.read_csv(uploaded_file)
571
- st.session_state.df = df
572
- st.session_state.uploaded_file = uploaded_file
573
- else:
574
- df = load_default()
575
- st.session_state.df = df
576
- else:
577
- if st.checkbox("Upload a new file (CSV)"):
578
- uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
579
- if uploaded_file:
580
- df = pd.read_csv(uploaded_file)
581
- st.session_state.df = df
582
- st.session_state.uploaded_file = uploaded_file
583
- else:
584
- df = st.session_state.df
585
- else:
586
- df = st.session_state.df
587
-
588
- columns = df.columns.tolist()
589
- ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
590
- target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
591
- y_col = st.selectbox("Select Target column", options=target_columns, index=0)
592
-
593
- df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
594
-
595
- id_col = 'ts_test'
596
- df['unique_id']=id_col
597
- df = df[['unique_id','ds','y']]
598
-
599
- freq = determine_frequency(df)
600
-
601
- df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)
602
-
603
- plot_type = st.sidebar.selectbox("Select Visualization", ["Matplotlib", "Plotly"])
604
- if st.sidebar.button("Submit"):
605
- start_time=time.time()
606
- anom_df = nixtla_client.detect_anomalies(
607
- df=df,
608
- freq=freq,
609
- level=90
610
- )
611
- st.session_state.anom_df = anom_df
612
-
613
- if 'anom_df' in st.session_state:
614
- anom_df = st.session_state.anom_df
615
-
616
- if plot_type == "Matplotlib":
617
- # Convert the Plotly figure to a Matplotlib figure if needed
618
- # Note: You may need to handle this conversion depending on your specific use case
619
- # For now, this example assumes that you are using a Matplotlib figure
620
- fig = nixtla_client.plot(df, anom_df, level=[90], engine='matplotlib')
621
- st.pyplot(fig)
622
- elif plot_type == "Plotly":
623
- # Plotly figure directly
624
- fig = nixtla_client.plot(df, anom_df, level=[90], engine='plotly')
625
- st.plotly_chart(fig)
626
-
627
- end_time = time.time() # End timing
628
- time_taken = end_time - start_time
629
- st.success(f"Time taken for TimeGPT forecast: {time_taken:.2f} seconds")
630
-
631
-
632
- st.markdown('You can download Input and Forecast Data below')
633
- tab_insample, tab_forecast = st.tabs(
634
- ["Input data", "Forecast"]
635
- )
636
-
637
- with tab_insample:
638
- df_grid = df.drop(columns="unique_id")
639
- st.write(df_grid)
640
- # grid_table = AgGrid(
641
- # df_grid,
642
- # theme="alpine",
643
- # )
644
-
645
- with tab_forecast:
646
- df_grid = anom_df
647
- st.write(df_grid)
648
- # grid_table = AgGrid(
649
- # df_grid,
650
- # theme="alpine",
651
- # )
652
-
653
-
654
-
655
 
656
  pg = st.navigation({
657
  "Neuralforecast": [
658
- # Load pages from functions
659
- st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
660
- st.Page(dynamic_forecasting, title="Personalized Neural Forecasting", icon=":material/monitoring:"),
661
  ],
662
- "TimeGPT": [
663
- # Load pages from functions
664
- st.Page(timegpt_fcst, title="TimeGPT Forecast", icon=":material/smart_toy:"),
665
- st.Page(timegpt_anom, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:")
666
- ]
667
  })
668
 
669
  pg.run()
670
-
 
1
  import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ import time
5
  from neuralforecast.core import NeuralForecast
6
  from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
7
  from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
+ import plotly.graph_objects as go
 
 
 
10
 
11
  st.set_page_config(layout='wide')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
def generate_forecast(model, df, tag=False):
    """Ask ``model`` for a forecast.

    Args:
        model: Object exposing a ``predict`` method.
        df: Frame forwarded as ``model.predict(df=df)`` unless ``tag`` is
            ``'retrain'``.
        tag: When equal to the string ``'retrain'``, ``model.predict()`` is
            called with no arguments; any other value forwards ``df``.

    Returns:
        Whatever ``model.predict`` returns.
    """
    predict_kwargs = {} if tag == 'retrain' else {'df': df}
    return model.predict(**predict_kwargs)
 
 
17
 
18
  def determine_frequency(df):
19
  df['ds'] = pd.to_datetime(df['ds'])
20
+ df = df.drop_duplicates(subset='ds').set_index('ds')
 
 
 
 
 
 
 
 
 
 
 
21
  freq = pd.infer_freq(df.index)
22
  if not freq:
23
+ st.warning('Defaulting to Daily frequency due to date inconsistencies. Please check your data.', icon="⚠️")
24
  freq = 'D'
 
25
  return freq
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def plot_forecasts(forecast_df, train_df, title):
 
28
  plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
 
 
29
  historical_col = 'y'
30
  forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
31
  lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
32
  hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
33
+
34
  if forecast_col is None:
35
  raise KeyError("No forecast column found in the data.")
36
+
 
37
  fig = go.Figure()
 
 
38
  fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[historical_col], mode='lines', name='Historical'))
 
 
39
  fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[forecast_col], mode='lines', name='Forecast'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ if lo_col and hi_col:
42
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[hi_col], mode='lines', line=dict(color='rgba(0,100,80,0.2)'), showlegend=False))
43
+ fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df[lo_col], mode='lines', line=dict(color='rgba(0,100,80,0.2)'), fill='tonexty', fillcolor='rgba(0,100,80,0.2)', name='90% Confidence Interval'))
44
 
45
+ fig.update_layout(title=title, xaxis_title='Timestamp [t]', yaxis_title='Value', template='plotly_white')
46
+ st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def select_model(horizon, model_type, max_steps=50):
49
  if model_type == 'NHITS':
50
+ return NHITS(input_size=5 * horizon, h=horizon, max_steps=max_steps, stack_types=3*['identity'], n_blocks=3*[1], mlp_units=[[256, 256] for _ in range(3)], batch_size=32, scaler_type='standard', loss=HuberMQLoss(level=[90]))
 
 
 
 
 
 
 
 
 
 
51
  elif model_type == 'TimesNet':
52
+ return TimesNet(h=horizon, input_size=horizon * 5, hidden_size=32, conv_hidden_size=64, loss=HuberMQLoss(level=[90]), scaler_type='standard', learning_rate=1e-3, max_steps=max_steps)
 
 
 
 
 
 
 
 
 
 
 
53
  elif model_type == 'LSTM':
54
+ return LSTM(h=horizon, input_size=horizon * 5, loss=HuberMQLoss(level=[90]), scaler_type='standard', encoder_n_layers=3, encoder_hidden_size=256, context_size=10, decoder_hidden_size=256, decoder_layers=3, max_steps=max_steps)
 
 
 
 
 
 
 
 
 
55
  elif model_type == 'TFT':
56
+ return TFT(h=horizon, input_size=horizon*5, hidden_size=96, loss=HuberMQLoss(level=[90]), learning_rate=0.005, scaler_type='standard', windows_batch_size=128, max_steps=max_steps)
 
 
 
 
 
 
 
 
 
 
57
  else:
58
  raise ValueError(f"Unsupported model type: {model_type}")
59
 
60
def model_train(df, model, freq):
    """Fit a single model on ``df`` and return the fitted forecaster.

    Args:
        df: Training frame; its ``ds`` column is converted to datetimes
            in place before fitting.
        model: Model instance to wrap in a ``NeuralForecast`` object.
        freq: Frequency alias handed to ``NeuralForecast``.

    Returns:
        The ``NeuralForecast`` object after ``fit`` has been called.
    """
    forecaster = NeuralForecast(models=[model], freq=freq)
    parsed_dates = pd.to_datetime(df['ds'])
    df['ds'] = parsed_dates
    forecaster.fit(df)
    return forecaster
65
 
66
def forecast_time_series(df, model_type, horizon, max_steps, y_col):
    """Train the chosen model on ``df``, then plot and tabulate its forecast.

    Steps, in order: detect the data frequency and show it in the sidebar,
    build and fit the model, store the forecast in
    ``st.session_state.forecast_results``, plot it, report the elapsed time,
    and finally show the input and forecast frames in two tabs.

    Args:
        df: Frame with ``unique_id``, ``ds`` and ``y`` columns.
        model_type: Model name understood by ``select_model``.
        horizon: Forecast horizon.
        max_steps: Training step budget.
        y_col: Original target column name, used only in the plot title.
    """
    started_at = time.time()

    detected_freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {detected_freq}")

    architecture = select_model(horizon, model_type, max_steps)
    fitted = model_train(df, architecture, detected_freq)

    forecast_results = {}
    forecast_results[model_type] = generate_forecast(fitted, df, tag='retrain')
    st.session_state.forecast_results = forecast_results

    for name in forecast_results:
        plot_forecasts(forecast_results[name], df, f'{name} Forecast for {y_col}')

    elapsed = time.time() - started_at
    st.success(f"Time taken for {model_type} forecast: {elapsed:.2f} seconds")

    # Nothing to tabulate unless results were stored for this session.
    if 'forecast_results' not in st.session_state:
        return

    st.markdown('Download Input and Forecast Data below')
    input_tab, output_tab = st.tabs(["Input data", "Forecast"])

    with input_tab:
        st.write(df.drop(columns="unique_id"))

    with output_tab:
        if model_type in forecast_results:
            st.write(forecast_results[model_type])
 
 
 
 
95
 
96
  @st.cache_data
97
  def load_default():
98
+ return AirPassengersDF.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
def personalized_forecasting():
    """Streamlit page that trains a model on user data and shows its forecast.

    The sidebar lets the user upload a CSV (the AirPassengers sample is used
    when nothing is uploaded), pick the date/time and target columns, choose
    a model, horizon and step budget, and start training with "Submit".
    """
    st.title("Personalized Neural Forecasting")
    st.markdown("Train a time series forecasting model from scratch using various deep neural network models.")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()

        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)

        # 'unique_id' is a series identifier, not a value to forecast. Leaving
        # it in the list made it the default target for the sample data.
        target_columns = [col for col in columns if col not in (ds_col, 'unique_id')]
        if not target_columns:
            st.error("The dataset needs at least one column besides the date/time column to forecast.")
            return
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)

        # Keep only the two chosen columns before renaming, so an unrelated
        # column that is already called 'ds' or 'y' cannot create duplicates.
        df = (
            df[[ds_col, y_col]]
            .rename(columns={ds_col: 'ds', y_col: 'y'})
            .assign(unique_id=1)[['unique_id', 'ds', 'y']]
        )

    st.sidebar.subheader("Dynamic Model Selection and Forecasting")
    dynamic_model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
    # A horizon or step budget below 1 cannot produce a model, so reject it in the widget.
    dynamic_horizon = st.sidebar.number_input("Forecast horizon", min_value=1, value=12)
    dynamic_max_steps = st.sidebar.number_input('Max steps', min_value=1, value=20)

    if st.sidebar.button("Submit"):
        with st.spinner('Training model...'):
            forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, y_col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  pg = st.navigation({
125
  "Neuralforecast": [
126
+ st.Page(personalized_forecasting, title="Personalized Forecasting", icon=":star:")
 
 
127
  ],
 
 
 
 
 
128
  })
129
 
130
  pg.run()