RaulSalinasHerr commited on
Commit
21ab5cd
·
1 Parent(s): de80710

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sktime.forecasting.theta import ThetaForecaster
4
+ from sktime.forecasting.base import ForecastingHorizon
5
+ from sktime.utils.plotting import plot_series
6
+
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+
10
+ import gradio as gr
11
+ from enum import Enum
12
+ import pickle as pkl
13
+ import os
14
+
15
+
16
+ matplotlib.use("Agg")
17
+
18
+ class Sector(Enum):
19
+ Region = "Region"
20
+ Provincia = "Provincia"
21
+ Comuna = "Comuna"
22
+
23
+ class TypePrediction(Enum):
24
+ Origin = "Total Origin"
25
+ Destiny = "Total Destiny"
26
+ OriginDestiny = "Origin and destiny"
27
+
28
+ type_choices = [x.value for x in TypePrediction]
29
+
30
+
31
+ def load_data(path:str = "./data/trips.csv") -> pd.DataFrame:
32
+ """load trips.csv data from path"""
33
+
34
+ read_params = {
35
+ "encoding": "latin_1",
36
+ "sep": ";",
37
+ "decimal": ","
38
+ }
39
+
40
+ return pd.read_csv(path, **read_params)
41
+
42
+ def to_date(month: int,year:int):
43
+ return pd.Timestamp(day=1, month=month, year = year)
44
+
45
+ def to_date_row(x):
46
+ month = x["month_value"]
47
+ year = x["Anio"]
48
+ return to_date(month, year)
49
+
50
+ def preprocess_data(data: pd.DataFrame, sector: Sector = Sector.Region) -> pd.DataFrame:
51
+ """preprocess data, choose sector to get value"""
52
+
53
+ data.columns = [x.strip() for x in data.columns]
54
+
55
+ col_melt = list(data.columns[-12:]) #months as cols
56
+
57
+ mn = "month_name"
58
+
59
+ code_month = pd.DataFrame.from_dict(
60
+ data={
61
+ mn: col_melt,
62
+ "month_value": list(range(1,13))})
63
+
64
+ col_maintain = list(data.columns[:-12])
65
+
66
+ data_long = data.melt(
67
+ id_vars=col_maintain,
68
+ value_vars=col_melt,
69
+ var_name=mn).merge(
70
+ code_month,
71
+ how="left",
72
+ on=mn)
73
+
74
+ data_long["time_stamp"] = data_long.apply(
75
+ to_date_row,axis = 1)
76
+
77
+ unused_date_cols = ["Anio", "month_name", "month_value"]
78
+
79
+ data_long.drop(columns= unused_date_cols, inplace=True)
80
+
81
+ sector_name = sector.name
82
+ cut_sector = ["CUT {} Origen", "CUT {} Destino"]
83
+
84
+ col_sector = [x.format(sector_name) for x in cut_sector]
85
+ kv = ["time_stamp", "value"]
86
+
87
+ cols = col_sector.copy()
88
+
89
+ for lcol in kv:
90
+ cols.append(lcol)
91
+ # cols = col_sector.extend(kv)
92
+
93
+ data_sector = data_long[cols].copy()
94
+
95
+ col_sector.append("time_stamp")
96
+ data_agg = data_sector.groupby(by = col_sector).sum()
97
+ data_agg.value = np.int32(data_agg.value.values)
98
+ data_agg.query("value > 0", inplace = True)
99
+
100
+
101
+ data_agg.reset_index(inplace=True)
102
+ data_agg.set_index("time_stamp", inplace=True)
103
+ data_agg = data_agg.to_period("M")
104
+
105
+ renamer = {
106
+ data_agg.columns[0]: "sector_origin",
107
+ data_agg.columns[1]: "sector_destiny"
108
+ }
109
+
110
+ data_agg.rename(columns=renamer, inplace=True)
111
+ data_agg.reset_index(inplace=True)
112
+
113
+
114
+ data_agg.set_index(["sector_origin", "sector_destiny","time_stamp"],inplace=True)
115
+
116
+ return data_agg
117
+
118
+
119
+ def predict_dataframe(data_agg: pd.DataFrame, h:int = 12, prd:int = 24) -> ThetaForecaster:
120
+ """predict dataframe
121
+
122
+ Args:
123
+ data_grouped (pd.DataFrame): grouped dataframe with values
124
+ h (int, optional): Window to forecast, in months
125
+
126
+ Returns:
127
+ pd.Series: _description_
128
+ """
129
+
130
+
131
+
132
+
133
+ data_flat = data_agg.reset_index().copy()
134
+ periods = data_flat["time_stamp"].unique()
135
+ last_2years = periods[-prd:]
136
+ data_flat = data_flat[data_flat["time_stamp"].isin(last_2years)]
137
+ data_flat.set_index(["sector_origin", "sector_destiny", "time_stamp"], inplace=True)
138
+
139
+
140
+ # forecasters = [
141
+ # # ("TBATS", TBATS(sp = 12)),
142
+ # ("Theta", ThetaForecaster(sp= 12)),
143
+ # ("ETS", AutoETS(sp = 12))
144
+ # ]
145
+
146
+
147
+ # forecaster = AutoEnsembleForecaster(
148
+ # forecasters=forecasters, test_size= 0.15)
149
+
150
+ forecaster = ThetaForecaster(sp=12)
151
+
152
+ fh = ForecastingHorizon(np.arange(1,h), is_relative= True)
153
+
154
+ forecaster = forecaster.fit(y = data_flat, fh=fh)
155
+ return forecaster
156
+
157
+
158
+ def create_plot(
159
+ data_agg: pd.DataFrame,
160
+ pred: pd.DataFrame,
161
+ # pred_inter: pd.DataFrame,
162
+ sector_origin: int | None = 1,
163
+ sector_destiny: int | None = 1,
164
+ type_prediction: str = TypePrediction.OriginDestiny.value):
165
+
166
+
167
+ def to_series(
168
+ data: pd.DataFrame,
169
+ is_multi: bool = False) -> pd.DataFrame:
170
+
171
+ df = data.reset_index().copy()
172
+ if type_prediction == TypePrediction.Destiny.value:
173
+ qry = "sector_destiny == {}".format(sector_destiny)
174
+
175
+ elif type_prediction == TypePrediction.Origin.value:
176
+ qry = "sector_origin == {}".format(sector_origin)
177
+
178
+ else:
179
+ qry = "sector_origin == {} & sector_destiny == {}".format(sector_origin, sector_destiny)
180
+
181
+ if not is_multi:
182
+ df.query(qry, inplace=True)
183
+ else:
184
+ df = df[(df.iloc[:, 0] == sector_origin) & (df.iloc[:, 1] == sector_destiny)]
185
+
186
+ if type_prediction == TypePrediction.Origin.value:
187
+ df = df.groupby(["sector_origin", "time_stamp"]).sum(numeric_only= True).reset_index()
188
+
189
+ elif type_prediction == TypePrediction.Destiny.value:
190
+ df = df.groupby(["sector_destiny", "time_stamp"]).sum(numeric_only=True).reset_index()
191
+
192
+
193
+ drop_cols = ["sector_origin", "sector_destiny"]
194
+
195
+ if is_multi:
196
+ drop_cols = [(x, "", "") for x in drop_cols]
197
+
198
+ return df.drop(columns=drop_cols).set_index("time_stamp").squeeze()
199
+
200
+
201
+ x = to_series(data_agg)
202
+ y = to_series(pred)
203
+
204
+ if type_prediction == TypePrediction.Destiny.value:
205
+
206
+ title = "Total monthly touristic travels to region {}".format(sector_destiny)
207
+
208
+
209
+ if type_prediction == TypePrediction.Origin.value:
210
+ title = "Total monthly touristic travels from region {}".format(sector_origin)
211
+
212
+
213
+ elif type_prediction == TypePrediction.OriginDestiny.value:
214
+
215
+
216
+ title = "Monthly touristic travels from region {} to region {}".format(
217
+ sector_origin, sector_destiny)
218
+
219
+ fig, _ = plot_series(x,y,labels=["value", "forecast"],title=title)
220
+
221
+ return fig
222
+
223
+ def save_object(object, path:str):
224
+ with open(path, "wb") as file:
225
+ pkl.dump(object, file)
226
+
227
+ def load_object(path: str):
228
+ with open(path, "rb") as file:
229
+ return pkl.load(file)
230
+
231
+
232
+
233
+
234
+ def run(argv = None):
235
+
236
+ path_raw = "./data/data_raw.pkl"
237
+ path_preprocessed = "./data/data_preprocessed.pkl"
238
+ path_forecaster = "./data/forecaster.pkl"
239
+
240
+
241
+ if not os.path.exists(path_raw):
242
+ data = load_data()
243
+ # save_object(data, path_raw)
244
+
245
+ else:
246
+ data = load_object(path_raw)
247
+
248
+ if not os.path.exists(path_preprocessed):
249
+ data_preprocessed = preprocess_data(data)
250
+ save_object(data_preprocessed, path_preprocessed)
251
+ else:
252
+ data_preprocessed = load_object(path_preprocessed)
253
+
254
+ if not os.path.exists(path_forecaster):
255
+ forecaster = predict_dataframe(data_preprocessed)
256
+ save_object(forecaster, path_forecaster)
257
+ else:
258
+
259
+ forecaster = load_object(path_forecaster)
260
+
261
+ pred = forecaster.predict()
262
+
263
+ def wrapper(sector_origin, sector_destiny, type_prediction):
264
+
265
+ sector_origin = int(sector_origin)
266
+ sector_destiny = int(sector_destiny)
267
+
268
+ return create_plot(
269
+ data_preprocessed, pred,
270
+ sector_origin, sector_destiny, type_prediction)
271
+
272
+
273
+
274
+ params_slider = {
275
+ "minimum": 1,
276
+ "maximum": 16,
277
+ "step": 1
278
+ }
279
+
280
+
281
+
282
+
283
+ port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
284
+
285
+
286
+
287
+ with gr.Blocks() as app:
288
+
289
+ input_type = gr.components.Radio(
290
+ choices= type_choices,
291
+ value = type_choices[-1],
292
+ type = "value",
293
+ label = "Prediction Aggregation")
294
+
295
+ with gr.Tab("Region"):
296
+
297
+ input_origin = gr.components.Slider(
298
+ **params_slider, label = "Origin"
299
+ )
300
+
301
+ input_destiny = gr.components.Slider(
302
+ **params_slider, label= "Destiny")
303
+
304
+ predict_region_btn = gr.Button("Predict region")
305
+
306
+ output_plot = gr.Plot()
307
+
308
+ predict_region_btn.click(
309
+ fn = wrapper,
310
+ inputs = [input_origin, input_destiny, input_type],
311
+ outputs = output_plot,
312
+ api_name= "predict_region"
313
+ )
314
+
315
+
316
+ app.launch(
317
+ server_name= "0.0.0.0",
318
+ server_port=port,
319
+ share=False)
320
+
321
+
322
+ if __name__ == "__main__":
323
+ run()