import pandas as pd
import streamlit as st

from model import Model
from plots import Plots
from stock_data_loader import StockDataLoader


class StockModelPage:
    """Streamlit page that loads stock data, trains a model on it and plots predictions.

    Constructing the page renders the sidebar (ticker + date range + 'Load Data').
    ``run()`` then renders the header, the data-loading step and the training step.
    Loaded data is cached in ``st.session_state['stock_data']`` so that it survives
    the rerun Streamlit performs when 'Train Model' is clicked.
    """

    # Visual separator written between page sections.
    _SEPARATOR = "--------------------------------------------"

    def __init__(self):
        self.tickers = ['NVDA', 'AAPL', 'GOOGL', 'MSFT', 'AMZN']
        self.setup_sidebar()

    def setup_sidebar(self):
        """Render the sidebar widgets and store the current selections on ``self``."""
        self.ticker = st.sidebar.selectbox('Choose Stock Ticker', self.tickers)
        self.start_date = st.sidebar.date_input('Start Date', value=pd.to_datetime('2010-01-01'))
        self.end_date = st.sidebar.date_input('End Date', value=pd.to_datetime('today'))
        # st.button returns True only on the rerun triggered by the click itself.
        self.load_button_clicked = st.sidebar.button('Load Data')

    def _current_selection(self):
        """Return the (ticker, start_date, end_date) currently chosen in the sidebar."""
        return (self.ticker, self.start_date, self.end_date)

    def load_data(self):
        """Load data for the sidebar selection when 'Load Data' was clicked.

        On success the data is stored in ``st.session_state['stock_data']``. The
        selection it was loaded for goes in ``st.session_state['stock_data_params']``,
        so the training step can tell when the cached data no longer matches the
        sidebar. An inverted date range is rejected with an error message and
        nothing is loaded.
        """
        if not self.load_button_clicked:
            return
        if self.start_date > self.end_date:
            # Don't ask the loader for an impossible range.
            st.write(self._SEPARATOR)
            st.error("Start Date must not be after End Date.")
            return
        loader = StockDataLoader(self.ticker, self.start_date, self.end_date)
        st.session_state['stock_data'] = loader.get_stock_data()
        st.session_state['stock_data_params'] = self._current_selection()
        st.write(self._SEPARATOR)
        st.write(f"Data for {self.ticker} from {self.start_date} to {self.end_date} loaded successfully!")

    def handle_model_training(self, forecast_days=5):
        """Offer a 'Train Model' button and, when clicked, train and plot on the cached data.

        Args:
            forecast_days: number of future days passed to ``Model.forecast_future``.
        """
        if 'stock_data' not in st.session_state:
            st.write(self._SEPARATOR)
            st.write("Please load data before training the model.")
            return
        stock_data = st.session_state['stock_data']

        # The cache persists across reruns, so the sidebar may have changed since the
        # data was loaded; warn instead of silently training on a different dataset.
        loaded_for = st.session_state.get('stock_data_params')
        if loaded_for is not None and loaded_for != self._current_selection():
            ticker, start, end = loaded_for
            st.warning(
                f"Loaded data is for {ticker} from {start} to {end}, which differs from "
                "the current sidebar selection. Click 'Load Data' to refresh it."
            )

        if not st.button('Train Model'):
            st.write("Click the button above to train the model.")
            return
        st.write("Training Model...")
        model = Model(stock_data)
        model.train_lstm()
        predictions = model.make_predictions()
        future_predictions = model.forecast_future(days=forecast_days)
        self.plot_predictions(stock_data, predictions, future_predictions)

    def plot_predictions(self, stock_data, predictions, future_predictions):
        """Plot in-sample and future predictions against ``stock_data`` via ``Plots``."""
        plot_instance = Plots(stock_data)
        plot_instance.plot_predictions(predictions, future_predictions)

    def run(self):
        """Render the page header, then the data-loading and training sections."""
        st.write(self._SEPARATOR)
        # The title used to be a single-quoted f-string broken across two lines
        # (a SyntaxError) and had no placeholders, so it is now a plain literal.
        # NOTE(review): unsafe_allow_html is kept in case HTML markup (e.g. a heading
        # tag) was meant to wrap this title — confirm and restore if so.
        st.write('🤖 Real-Time Stock Prediction', unsafe_allow_html=True)
        self.load_data()
        self.handle_model_training()