# Source: app.py as uploaded by OwusuDynamo (commit a5cde1f, "Upload app.py", 3.44 kB).
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 1 13:52:42 2023
@author: ME
"""
import base64
import os
import pickle

import joblib
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st

from src.preprocess import preprocess_data
# Artifact locations. The defaults are the original author's local paths; set
# WEATHER_FEATURE_DICT_PATH / WEATHER_MODEL_PATH to run the app on another
# machine without editing the source.
# NOTE: despite its name, `json_path` points at a joblib file, not JSON; the
# name is kept so any other code referring to it keeps working.
json_path = os.environ.get(
    "WEATHER_FEATURE_DICT_PATH",
    r"C:/Users/ME/Desktop/Blessing_AI/Weather_Prediction/Artifacts/feature_dict.joblib",
)
model_path = os.environ.get(
    "WEATHER_MODEL_PATH",
    r"C:/Users/ME/Desktop/Blessing_AI/Weather_Prediction/Artifacts/trained_prophet_model.pkl",
)


@st.cache_resource
def _load_feature_dict(path):
    """Load the joblib-serialized feature dictionary stored at *path*.

    Cached with ``st.cache_resource`` so that Streamlit's rerun-on-every-
    interaction does not reload it from disk each time. The cached object is
    shared between reruns and sessions, so treat it as read-only.

    SECURITY: joblib files are pickle-based; only load trusted artifacts.
    """
    return joblib.load(path)


@st.cache_resource
def _load_model(path):
    """Unpickle the trained model stored at *path* (a Prophet model, per the default file name).

    Cached like the feature dictionary; the returned model object is shared.

    SECURITY: unpickling can execute arbitrary code, and the path can now be
    set from the environment; only point it at a trusted artifact.
    """
    with open(path, "rb") as file_:
        return pickle.load(file_)


loaded_data = _load_feature_dict(json_path)
model = _load_model(model_path)
st.title("Weather Forecast: Average Daily Temperature Projections for Nigerian Airports")

# Forecast window: start and end pickers laid out side by side.
col1, col2 = st.columns(2)
start_d = col1.date_input("Select start date")
end_d = col2.date_input("Select end date")

# Width of the chosen window in whole days (negative if the end precedes the start).
delta = end_d - start_d
days = delta.days

# (year, month, day) tuples are the form in which the dates are handed to
# preprocess_data; the individual components are also kept as named values.
start_date = (start_d.year, start_d.month, start_d.day)
end_date = (end_d.year, end_d.month, end_d.day)
y1, m1, d1 = start_date
y2, m2, d2 = end_date
def _to_display_frame(forecast):
    """Turn the frame returned by ``model.predict`` into the table shown to users.

    Keeps only the ``ds`` and ``yhat`` columns, renamed to ``Date`` and
    ``Prediction``. Dates become ``YYYY-MM-DD`` strings and predictions are
    rounded to 2 decimal places.
    """
    # rename() returns a new DataFrame, so the assignments below modify that
    # frame rather than a column slice of `forecast` (no chained-assignment warning).
    table = forecast[["ds", "yhat"]].rename(columns={"ds": "Date", "yhat": "Prediction"})
    table["Date"] = table["Date"].dt.strftime("%Y-%m-%d")
    table["Prediction"] = table["Prediction"].round(2)
    return table


def _csv_download_link(table):
    """Return an HTML ``<a>`` tag that downloads *table* as ``data.csv``.

    The CSV text is embedded as a base64 ``data:`` URL, so the browser handles
    the download on its own.
    """
    encoded = base64.b64encode(table.to_csv(index=False).encode()).decode()
    # text/csv is the registered CSV media type; "file/csv" is not a valid MIME type.
    return f'<a href="data:text/csv;base64,{encoded}" download="data.csv">Download CSV File</a>'


def _plot_predictions(table):
    """Return a figure plotting ``Prediction`` against ``Date`` as a dashed line with markers."""
    # Seaborn styles are applied through rcParams, which are read when axes are
    # created; the style must therefore be set *before* plt.subplots().
    sns.set_style("whitegrid")
    fig, ax = plt.subplots()
    ax.plot(table["Date"], table["Prediction"], marker="o", linestyle="--")
    ax.set_xlabel("Date")
    ax.set_ylabel("Average temperature")
    # Axis limits are left to Matplotlib's autoscaling. Pinning them to the exact
    # data min/max clipped the end markers and, for a single-day forecast,
    # produced identical lower and upper limits.
    fig.autofmt_xdate()
    return fig


if days < 0:
    # A reversed range gives a negative day count, which would otherwise pass
    # the upper-bound check below and reach preprocess_data.
    st.write("End date must be on or after the start date")
elif days > 10:
    st.write("Select a lower number of days range (10 and below)")
else:
    name = st.selectbox(
        "Select airport of interest",
        options=tuple(loaded_data["Airport_name"].keys()),
    )
    data = preprocess_data(
        start_d=start_date,
        end_d=end_date,
        airport_name=name,
    )
    forecast_bn = st.button("Forecast")
    if forecast_bn:
        pred_df = _to_display_frame(model.predict(data))
        tab1, tab2 = st.tabs(["Predictions", "Check plot"])
        tab1.write(pred_df)
        st.markdown(_csv_download_link(pred_df), unsafe_allow_html=True)
        # The former mplcursors hover annotations were removed: st.pyplot shows
        # the figure as a static image, so hover callbacks could never fire.
        fig = _plot_predictions(pred_df)
        tab2.pyplot(fig)
        # plt.subplots() registers the figure with pyplot until it is closed.
        # Streamlit reruns this script on every interaction, so close it to keep
        # figures from accumulating in the server process.
        plt.close(fig)