File size: 2,274 Bytes
ef94dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pandas as pd
import streamlit as st
from model import Model
from plots import Plots
from stock_data_loader import StockDataLoader

class StockModelPage:
    """Streamlit page for loading stock prices and training a forecasting model.

    The sidebar collects a ticker, a date range and a "Load Data" click. Loaded
    data is kept in ``st.session_state['stock_data']`` so it is still available
    on later reruns of the script (e.g. the rerun triggered by clicking
    "Train Model"), when the "Load Data" button no longer reads as clicked.
    """

    # Horizontal rule written between page sections.
    _SEPARATOR = "--------------------------------------------"
    # Page title. The <div> is closed so the HTML handed to st.write with
    # unsafe_allow_html=True is well-formed.
    _TITLE_HTML = '<div style="font-size:50px">🤖 Real-Time Stock Prediction</div>'

    def __init__(self):
        self.tickers = ['NVDA', 'AAPL', 'GOOGL', 'MSFT', 'AMZN']
        self.setup_sidebar()

    def setup_sidebar(self):
        """Render the sidebar widgets and keep their current values on ``self``."""
        self.ticker = st.sidebar.selectbox('Choose Stock Ticker', self.tickers)
        self.start_date = st.sidebar.date_input('Start Date', value=pd.to_datetime('2010-01-01'))
        self.end_date = st.sidebar.date_input('End Date', value=pd.to_datetime('today'))
        self.load_button_clicked = st.sidebar.button('Load Data')

    @staticmethod
    def _is_valid_range(start_date, end_date):
        """Return True when *start_date* is not after *end_date*."""
        return start_date <= end_date

    def load_data(self):
        """Fetch data for the selected ticker and range when "Load Data" was clicked.

        The result of ``StockDataLoader.get_stock_data()`` is stored in
        ``st.session_state['stock_data']``. The ticker it was loaded for is
        stored in ``st.session_state['stock_ticker']`` so a later change of the
        sidebar selection can be detected.
        """
        if not self.load_button_clicked:
            return
        if not self._is_valid_range(self.start_date, self.end_date):
            # Refuse an inverted range instead of passing it to the loader.
            st.error("Start Date must be on or before End Date.")
            return
        loader = StockDataLoader(self.ticker, self.start_date, self.end_date)
        st.session_state['stock_data'] = loader.get_stock_data()
        st.session_state['stock_ticker'] = self.ticker
        st.write(self._SEPARATOR)
        st.write(f"Data for {self.ticker} from {self.start_date} to {self.end_date} loaded successfully!")

    def handle_model_training(self):
        """Offer a "Train Model" button once data has been loaded.

        On click, builds a ``Model`` from the session's data, calls
        ``train_lstm()``, ``make_predictions()`` and ``forecast_future(days=5)``,
        and plots the results. Without loaded data, asks the user to load first.
        """
        if 'stock_data' not in st.session_state:
            st.write(self._SEPARATOR)
            st.write("Please load data before training the model.")
            return

        stock_data = st.session_state['stock_data']
        loaded_ticker = st.session_state.get('stock_ticker')
        if loaded_ticker is not None and loaded_ticker != self.ticker:
            # Session data persists across reruns, so the sidebar selection can
            # drift away from what was actually loaded.
            st.warning(
                f"Loaded data is for {loaded_ticker}; click 'Load Data' to use {self.ticker}."
            )

        if not st.button('Train Model'):
            st.write("Click the button above to train the model.")
            return

        st.write("Training Model...")
        model = Model(stock_data)
        model.train_lstm()
        predictions = model.make_predictions()
        future_predictions = model.forecast_future(days=5)
        self.plot_predictions(stock_data, predictions, future_predictions)

    def plot_predictions(self, stock_data, predictions, future_predictions):
        """Delegate plotting of in-sample predictions and the forecast to ``Plots``."""
        plot_instance = Plots(stock_data)
        plot_instance.plot_predictions(predictions, future_predictions)

    def run(self):
        """Render the page: title, then the data-loading and training sections."""
        st.write(self._SEPARATOR)
        st.write(self._TITLE_HTML, unsafe_allow_html=True)
        self.load_data()
        self.handle_model_training()