Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,8 @@ from neuralforecast.losses.pytorch import HuberMQLoss
|
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
10 |
from st_aggrid import AgGrid
|
|
|
|
|
11 |
|
12 |
@st.cache_resource
|
13 |
def load_model(path, freq):
|
@@ -107,7 +109,7 @@ def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_mo
|
|
107 |
else:
|
108 |
raise ValueError(f"Unsupported frequency: {freq}")
|
109 |
|
110 |
-
def select_model(horizon, model_type, max_steps=
|
111 |
if model_type == 'NHITS':
|
112 |
return NHITS(input_size=5 * horizon,
|
113 |
h=horizon,
|
@@ -304,17 +306,45 @@ def dynamic_forecasting():
|
|
304 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
305 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
306 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
307 |
-
dynamic_max_steps = st.sidebar.number_input('Max steps', value=
|
308 |
|
309 |
if st.sidebar.button("Submit"):
|
310 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
pg = st.navigation({
|
313 |
-
"
|
314 |
# Load pages from functions
|
315 |
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
|
316 |
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
|
317 |
-
]
|
|
|
|
|
|
|
|
|
|
|
318 |
})
|
319 |
|
320 |
try:
|
|
|
8 |
from neuralforecast.utils import AirPassengersDF
|
9 |
import time
|
10 |
from st_aggrid import AgGrid
|
11 |
+
from nixtla import NixtlaClient
|
12 |
+
|
13 |
|
14 |
@st.cache_resource
|
15 |
def load_model(path, freq):
|
|
|
109 |
else:
|
110 |
raise ValueError(f"Unsupported frequency: {freq}")
|
111 |
|
112 |
+
def select_model(horizon, model_type, max_steps=50):
|
113 |
if model_type == 'NHITS':
|
114 |
return NHITS(input_size=5 * horizon,
|
115 |
h=horizon,
|
|
|
306 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
307 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
308 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
309 |
+
dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)
|
310 |
|
311 |
if st.sidebar.button("Submit"):
|
312 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
313 |
|
314 |
+
def timegpt():
|
315 |
+
st.title("TimeGPT Forecasting")
|
316 |
+
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
317 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
318 |
+
if uploaded_file:
|
319 |
+
df = pd.read_csv(uploaded_file)
|
320 |
+
st.session_state.df = df
|
321 |
+
else:
|
322 |
+
df = load_default()
|
323 |
+
st.session_state.df = df
|
324 |
+
|
325 |
+
# Column selection
|
326 |
+
columns = df.columns.tolist() # Convert Index to list
|
327 |
+
ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
|
328 |
+
y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
|
335 |
+
|
336 |
+
|
337 |
pg = st.navigation({
|
338 |
+
"NeuralForecast": [
|
339 |
# Load pages from functions
|
340 |
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
|
341 |
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
|
342 |
+
],
|
343 |
+
"TimeGPT": [
|
344 |
+
# Load pages from functions
|
345 |
+
st.Page(timegpt, title="TimeGPT Forecast", icon=":material/smart_toy:")
|
346 |
+
st.Page(timegpt, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:")
|
347 |
+
]
|
348 |
})
|
349 |
|
350 |
try:
|