|
|
|
""" |
|
Created on Thu Jun 1 13:52:42 2023 |
|
|
|
@author: ME |
|
""" |
|
|
|
import base64
import os
import pickle
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st

from src.preprocess import preprocess_data
|
|
|
|
|
|
|
# Directory holding the serialized artifacts. Defaults to the original
# author's location; set the WEATHER_ARTIFACTS_DIR environment variable to
# run the app from any other machine or checkout.
ARTIFACTS_DIR = Path(
    os.environ.get(
        "WEATHER_ARTIFACTS_DIR",
        r"C:/Users/ME/Desktop/Blessing_AI/Weather_Prediction/Artifacts",
    )
)

# Feature metadata (a joblib file despite the variable name);
# loaded_data["Airport_name"] supplies the airport choices shown in the UI.
json_path = ARTIFACTS_DIR / "feature_dict.joblib"
loaded_data = joblib.load(json_path)

# Trained forecasting model. pickle.load can execute arbitrary code embedded
# in the file, so only point this at artifacts you produced yourself.
model_path = ARTIFACTS_DIR / "trained_prophet_model.pkl"
with open(model_path, 'rb') as file_:
    model = pickle.load(file_)
|
|
|
|
|
st.title("Weather Forecast: Average Daily Temperature Projections for Nigerian Airports")

# Two side-by-side pickers define the forecast window.
col1, col2 = st.columns(2)
start_d = col1.date_input("Select start date")
end_d = col2.date_input("Select end date")

# Window length in whole days; negative when the end precedes the start.
days = (end_d - start_d).days

# Both ends of the window are handed on as (year, month, day) tuples.
start_date, end_date = (
    (picked.year, picked.month, picked.day) for picked in (start_d, end_d)
)
|
|
|
# Longest forecast window (end_d - start_d, in days) the app accepts.
MAX_FORECAST_DAYS = 10


def _tidy_predictions(raw):
    """Return a display-ready copy of the model's forecast frame.

    Keeps only the ``ds`` and ``yhat`` columns of *raw*, renames them to
    ``Date`` and ``Prediction``, formats the dates as ``YYYY-MM-DD`` strings
    and rounds the predictions to two decimals.
    """
    # Take an explicit copy: writing into a column subset of ``raw`` directly
    # triggers pandas' SettingWithCopyWarning.
    tidy = raw[["ds", "yhat"]].copy()
    tidy.columns = ["Date", "Prediction"]
    tidy["Date"] = tidy["Date"].dt.strftime('%Y-%m-%d')
    tidy["Prediction"] = tidy["Prediction"].round(2)
    return tidy


def _csv_download_link(frame):
    """Return an HTML ``<a>`` tag that downloads *frame* as ``data.csv``.

    The CSV text (without the index) is embedded in the link as a base64
    data URI.
    """
    csv = frame.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    return f'<a href="data:file/csv;base64,{b64}" download="data.csv">Download CSV File</a>'


def _plot_predictions(frame):
    """Return a dashed line-and-marker chart of Prediction against Date."""
    # Set the style *before* creating the axes: seaborn styles only affect
    # axes created after the call.
    sns.set_style('whitegrid')
    fig, ax = plt.subplots()
    ax.plot(frame["Date"], frame["Prediction"], marker='o', linestyle='--')
    ax.set_xlabel('Date')
    ax.set_ylabel('Average temperature')
    # Limits are left to autoscaling. Pinning them to the exact min/max cut
    # the first/last markers in half and gave identical (warned) limits when
    # the forecast has a single point.
    fig.autofmt_xdate()
    return fig


if days < 0:
    # A reversed range would otherwise be passed straight to preprocess_data.
    st.write("End date must not be earlier than start date")

elif days > MAX_FORECAST_DAYS:
    st.write(f"Select a lower number of days range ({MAX_FORECAST_DAYS} and below)")

else:
    name = st.selectbox(
        "Select airport of interest",
        options=tuple(loaded_data["Airport_name"].keys()),
    )

    # start_date / end_date are the (year, month, day) tuples built above.
    data = preprocess_data(
        start_d=start_date,
        end_d=end_date,
        airport_name=name,
    )

    forecast_bn = st.button("Forecast")

    if forecast_bn:
        pred_df = _tidy_predictions(model.predict(data))

        tab1, tab2 = st.tabs(["Predictions", "Check plot"])
        tab1.write(pred_df)
        st.markdown(_csv_download_link(pred_df), unsafe_allow_html=True)

        # st.pyplot renders a static image, so the former mplcursors hover
        # cursor could never fire; it has been dropped.
        fig = _plot_predictions(pred_df)
        tab2.pyplot(fig)
        # pyplot keeps every figure alive until it is closed, and Streamlit
        # re-executes this script on each interaction, so close it explicitly.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|