azrai99 commited on
Commit
7cecffc
·
verified ·
1 Parent(s): 70a8bde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -4
app.py CHANGED
@@ -8,6 +8,8 @@ from neuralforecast.losses.pytorch import HuberMQLoss
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
  from st_aggrid import AgGrid
 
 
11
 
12
  @st.cache_resource
13
  def load_model(path, freq):
@@ -107,7 +109,7 @@ def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_mo
107
  else:
108
  raise ValueError(f"Unsupported frequency: {freq}")
109
 
110
- def select_model(horizon, model_type, max_steps=200):
111
  if model_type == 'NHITS':
112
  return NHITS(input_size=5 * horizon,
113
  h=horizon,
@@ -304,17 +306,45 @@ def dynamic_forecasting():
304
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
305
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
306
  dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
307
- dynamic_max_steps = st.sidebar.number_input('Max steps', value=200)
308
 
309
  if st.sidebar.button("Submit"):
310
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  pg = st.navigation({
313
- "Overview": [
314
  # Load pages from functions
315
  st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
316
  st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
317
- ]
 
 
 
 
 
318
  })
319
 
320
  try:
 
8
  from neuralforecast.utils import AirPassengersDF
9
  import time
10
  from st_aggrid import AgGrid
11
+ from nixtla import NixtlaClient
12
+
13
 
14
  @st.cache_resource
15
  def load_model(path, freq):
 
109
  else:
110
  raise ValueError(f"Unsupported frequency: {freq}")
111
 
112
+ def select_model(horizon, model_type, max_steps=50):
113
  if model_type == 'NHITS':
114
  return NHITS(input_size=5 * horizon,
115
  h=horizon,
 
306
  st.sidebar.subheader("Dynamic Model Selection and Forecasting")
307
  dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
308
  dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
309
+ dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)
310
 
311
  if st.sidebar.button("Submit"):
312
  forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
313
 
314
+ def timegpt():
315
+ st.title("TimeGPT Forecasting")
316
+ with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
317
+ uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
318
+ if uploaded_file:
319
+ df = pd.read_csv(uploaded_file)
320
+ st.session_state.df = df
321
+ else:
322
+ df = load_default()
323
+ st.session_state.df = df
324
+
325
+ # Column selection
326
+ columns = df.columns.tolist() # Convert Index to list
327
+ ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
328
+ y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
  pg = st.navigation({
338
+ "NeuralForecast": [
339
  # Load pages from functions
340
  st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
341
  st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
342
+ ],
343
+ "TimeGPT": [
344
+ # Load pages from functions
345
+ st.Page(timegpt, title="TimeGPT Forecast", icon=":material/smart_toy:")
346
+ st.Page(timegpt, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:")
347
+ ]
348
  })
349
 
350
  try: