jonathan-cristovao's picture
Upload 10 files
ef94dec verified
raw
history blame
2.27 kB
import pandas as pd
import streamlit as st
from model import Model
from plots import Plots
from stock_data_loader import StockDataLoader
class StockModelPage:
    """Streamlit page that loads stock price data, trains a model and plots predictions.

    Sidebar widgets (ticker, date range, "Load Data" button) are created when the
    page is constructed; ``run()`` renders the main area. Loaded data is kept in
    ``st.session_state['stock_data']`` so it survives the Streamlit rerun that
    happens when the "Train Model" button is clicked.
    """

    # Tickers offered in the sidebar when the caller does not supply its own list.
    DEFAULT_TICKERS = ('NVDA', 'AAPL', 'GOOGL', 'MSFT', 'AMZN')
    # Earliest selectable date for both pickers. Without explicit bounds,
    # st.date_input uses value +/- 10 years, which capped the start date at
    # 2020-01-01 and limited the end date to the last 10 years.
    EARLIEST_DATE = '2000-01-01'
    # Visual divider written between sections of the page.
    _SEPARATOR = "--------------------------------------------"

    def __init__(self, tickers=None, forecast_days=5):
        """Create the page and its sidebar controls.

        Args:
            tickers: Ticker symbols offered in the selectbox; defaults to
                ``DEFAULT_TICKERS``.
            forecast_days: Number of days passed to ``Model.forecast_future``.
        """
        self.tickers = list(self.DEFAULT_TICKERS if tickers is None else tickers)
        self.forecast_days = forecast_days
        self.setup_sidebar()

    def setup_sidebar(self):
        """Create the sidebar widgets and store their current values on ``self``."""
        today = pd.Timestamp.today().date()
        earliest = pd.Timestamp(self.EARLIEST_DATE).date()
        self.ticker = st.sidebar.selectbox('Choose Stock Ticker', self.tickers)
        # Explicit bounds: no future dates, and the same lower bound for both pickers.
        self.start_date = st.sidebar.date_input(
            'Start Date',
            value=pd.to_datetime('2010-01-01'),
            min_value=earliest,
            max_value=today,
        )
        self.end_date = st.sidebar.date_input(
            'End Date',
            value=today,
            min_value=earliest,
            max_value=today,
        )
        # st.button returns True only on the rerun triggered by the click.
        self.load_button_clicked = st.sidebar.button('Load Data')

    def load_data(self):
        """Load data for the selected ticker and date range when "Load Data" was clicked.

        On success the result is stored in ``st.session_state['stock_data']``.
        On an invalid date range or an empty result an error is shown, and any
        previously loaded data is discarded so it is not trained on by mistake.
        """
        if not self.load_button_clicked:
            return
        if self.start_date > self.end_date:
            st.error("Start date must be on or before end date.")
            return

        loader = StockDataLoader(self.ticker, self.start_date, self.end_date)
        stock_data = loader.get_stock_data()
        # NOTE(review): assumes get_stock_data returns a pandas object (with
        # `.empty`) or None on failure -- confirm against StockDataLoader.
        if stock_data is None or getattr(stock_data, 'empty', False):
            if 'stock_data' in st.session_state:
                del st.session_state['stock_data']
            st.error(
                f"No data returned for {self.ticker} "
                f"from {self.start_date} to {self.end_date}."
            )
            return

        st.session_state['stock_data'] = stock_data
        st.write(self._SEPARATOR)
        st.write(f"Data for {self.ticker} from {self.start_date} to {self.end_date} loaded successfully!")

    def handle_model_training(self):
        """Offer a "Train Model" button; on click, train, predict and plot."""
        if 'stock_data' not in st.session_state:
            st.write(self._SEPARATOR)
            st.write("Please load data before training the model.")
            return

        stock_data = st.session_state['stock_data']
        if not st.button('Train Model'):
            st.write("Click the button above to train the model.")
            return

        st.write("Training Model...")
        model = Model(stock_data)
        model.train_lstm()
        predictions = model.make_predictions()
        future_predictions = model.forecast_future(days=self.forecast_days)
        self.plot_predictions(stock_data, predictions, future_predictions)

    def plot_predictions(self, stock_data, predictions, future_predictions):
        """Plot in-sample predictions and the future forecast over ``stock_data``."""
        plot_instance = Plots(stock_data)
        plot_instance.plot_predictions(predictions, future_predictions)

    def run(self):
        """Render the page title, then the data-loading and training sections."""
        st.write(self._SEPARATOR)
        # The closing </div> was missing, which left the HTML element unterminated.
        st.write('<div style="font-size:50px">🤖 Real-Time Stock Prediction</div>', unsafe_allow_html=True)
        self.load_data()
        self.handle_model_training()