azrai99 commited on
Commit
65275be
·
verified ·
1 Parent(s): 27e1b78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -18
app.py CHANGED
@@ -174,10 +174,10 @@ def forecast_time_series(df, model_type, horizon, max_steps,y_col):
174
  st.sidebar.write(f"Data frequency: {freq}")
175
 
176
  selected_model = select_model(horizon, model_type, max_steps)
 
177
  model = model_train(df, selected_model,freq)
178
 
179
  forecast_results = {}
180
- st.sidebar.write(f"Generating forecast using {model_type} model...")
181
  forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
182
 
183
  for model_name, forecast_df in forecast_results.items():
@@ -221,7 +221,7 @@ def transfer_learning_forecasting():
221
  # Model selection and forecasting
222
  st.sidebar.subheader("Model Selection and Forecasting")
223
  model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
224
- horizon = st.sidebar.number_input("Forecast horizon", value=18)
225
 
226
  df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
227
  df['unique_id']=1
@@ -312,8 +312,8 @@ def dynamic_forecasting():
312
  # Dynamic forecasting
313
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
314
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
315
- dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
316
- dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)
317
 
318
  if st.sidebar.button("Submit"):
319
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
@@ -349,13 +349,14 @@ def timegpt_fcst():
349
  id_col = 'ts_test'
350
  df['unique_id']=id_col
351
  freq = determine_frequency(df)
352
-
353
- forecast_df = nixtla_client.forecast(
354
- df=df,
355
- h=7,
356
- freq=freq,
357
- level=[90]
358
- )
 
359
 
360
  nixtla_client.plot(
361
  forecast_df,
@@ -394,15 +395,16 @@ def timegpt_anom():
394
  id_col = 'ts_test'
395
  df['unique_id']=id_col
396
  freq = determine_frequency(df)
397
-
398
- anom_df = nixtla_client.detect_anomalies(
399
- df=df,
400
- freq=freq,
401
- level=90
402
- )
 
403
 
404
  nixtla_client.plot(
405
- forecast_df,
406
  anom_df
407
  )
408
 
 
174
  st.sidebar.write(f"Data frequency: {freq}")
175
 
176
  selected_model = select_model(horizon, model_type, max_steps)
177
+ st.spinner(f"Training {model_type} model...")
178
  model = model_train(df, selected_model,freq)
179
 
180
  forecast_results = {}
 
181
  forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
182
 
183
  for model_name, forecast_df in forecast_results.items():
 
221
  # Model selection and forecasting
222
  st.sidebar.subheader("Model Selection and Forecasting")
223
  model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
224
+ horizon = st.sidebar.number_input("Forecast horizon", value=12)
225
 
226
  df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
227
  df['unique_id']=1
 
312
  # Dynamic forecasting
313
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
314
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
315
+ dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=12)
316
+ dynamic_max_steps = st.sidebar.number_input('Max steps', value=20)
317
 
318
  if st.sidebar.button("Submit"):
319
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
 
349
  id_col = 'ts_test'
350
  df['unique_id']=id_col
351
  freq = determine_frequency(df)
352
+
353
+ if st.sidebar.button("Submit"):
354
+ forecast_df = nixtla_client.forecast(
355
+ df=df,
356
+ h=7,
357
+ freq=freq,
358
+ level=[90]
359
+ )
360
 
361
  nixtla_client.plot(
362
  forecast_df,
 
395
  id_col = 'ts_test'
396
  df['unique_id']=id_col
397
  freq = determine_frequency(df)
398
+
399
+ if st.sidebar.button("Submit"):
400
+ anom_df = nixtla_client.detect_anomalies(
401
+ df=df,
402
+ freq=freq,
403
+ level=90
404
+ )
405
 
406
  nixtla_client.plot(
407
+ df,
408
  anom_df
409
  )
410