Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -174,10 +174,10 @@ def forecast_time_series(df, model_type, horizon, max_steps,y_col):
|
|
174 |
st.sidebar.write(f"Data frequency: {freq}")
|
175 |
|
176 |
selected_model = select_model(horizon, model_type, max_steps)
|
|
|
177 |
model = model_train(df, selected_model,freq)
|
178 |
|
179 |
forecast_results = {}
|
180 |
-
st.sidebar.write(f"Generating forecast using {model_type} model...")
|
181 |
forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
|
182 |
|
183 |
for model_name, forecast_df in forecast_results.items():
|
@@ -221,7 +221,7 @@ def transfer_learning_forecasting():
|
|
221 |
# Model selection and forecasting
|
222 |
st.sidebar.subheader("Model Selection and Forecasting")
|
223 |
model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
224 |
-
horizon = st.sidebar.number_input("Forecast horizon", value=
|
225 |
|
226 |
df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
|
227 |
df['unique_id']=1
|
@@ -312,8 +312,8 @@ def dynamic_forecasting():
|
|
312 |
# Dynamic forecasting
|
313 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
314 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
315 |
-
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=
|
316 |
-
dynamic_max_steps = st.sidebar.number_input('Max steps', value=
|
317 |
|
318 |
if st.sidebar.button("Submit"):
|
319 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
@@ -349,13 +349,14 @@ def timegpt_fcst():
|
|
349 |
id_col = 'ts_test'
|
350 |
df['unique_id']=id_col
|
351 |
freq = determine_frequency(df)
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
359 |
|
360 |
nixtla_client.plot(
|
361 |
forecast_df,
|
@@ -394,15 +395,16 @@ def timegpt_anom():
|
|
394 |
id_col = 'ts_test'
|
395 |
df['unique_id']=id_col
|
396 |
freq = determine_frequency(df)
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
403 |
|
404 |
nixtla_client.plot(
|
405 |
-
|
406 |
anom_df
|
407 |
)
|
408 |
|
|
|
174 |
st.sidebar.write(f"Data frequency: {freq}")
|
175 |
|
176 |
selected_model = select_model(horizon, model_type, max_steps)
|
177 |
+
st.spinner(f"Training {model_type} model...")
|
178 |
model = model_train(df, selected_model,freq)
|
179 |
|
180 |
forecast_results = {}
|
|
|
181 |
forecast_results[model_type] = generate_forecast(model, df, tag='retrain')
|
182 |
|
183 |
for model_name, forecast_df in forecast_results.items():
|
|
|
221 |
# Model selection and forecasting
|
222 |
st.sidebar.subheader("Model Selection and Forecasting")
|
223 |
model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
224 |
+
horizon = st.sidebar.number_input("Forecast horizon", value=12)
|
225 |
|
226 |
df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
|
227 |
df['unique_id']=1
|
|
|
312 |
# Dynamic forecasting
|
313 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
314 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
315 |
+
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=12)
|
316 |
+
dynamic_max_steps = st.sidebar.number_input('Max steps', value=20)
|
317 |
|
318 |
if st.sidebar.button("Submit"):
|
319 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
|
|
349 |
id_col = 'ts_test'
|
350 |
df['unique_id']=id_col
|
351 |
freq = determine_frequency(df)
|
352 |
+
|
353 |
+
if st.sidebar.button("Submit"):
|
354 |
+
forecast_df = nixtla_client.forecast(
|
355 |
+
df=df,
|
356 |
+
h=7,
|
357 |
+
freq=freq,
|
358 |
+
level=[90]
|
359 |
+
)
|
360 |
|
361 |
nixtla_client.plot(
|
362 |
forecast_df,
|
|
|
395 |
id_col = 'ts_test'
|
396 |
df['unique_id']=id_col
|
397 |
freq = determine_frequency(df)
|
398 |
+
|
399 |
+
if st.sidebar.button("Submit"):
|
400 |
+
anom_df = nixtla_client.detect_anomalies(
|
401 |
+
df=df,
|
402 |
+
freq=freq,
|
403 |
+
level=90
|
404 |
+
)
|
405 |
|
406 |
nixtla_client.plot(
|
407 |
+
df,
|
408 |
anom_df
|
409 |
)
|
410 |
|