CJRobert commited on
Commit
3bd93ea
·
verified ·
1 Parent(s): 851b070

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -49,7 +49,20 @@ def lstm_prediction(file, epochs):
49
  model.add(LSTM(4, input_shape=(1, look_back)))
50
  model.add(Dense(1))
51
  model.compile(loss='mean_squared_error', optimizer='adam')
52
- model.fit(trainX, trainY, epochs=epochs, batch_size=1, verbose=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Make predictions
55
  trainPredict = model.predict(trainX)
 
49
  model.add(LSTM(4, input_shape=(1, look_back)))
50
  model.add(Dense(1))
51
  model.compile(loss='mean_squared_error', optimizer='adam')
52
+
53
+ # Set up a callback to update Streamlit during training
54
+ class StreamlitCallback(tf.keras.callbacks.Callback):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.epoch_bar = st.progress(0)
58
+ self.loss_placeholder = st.empty()
59
+
60
+ def on_epoch_end(self, epoch, logs=None):
61
+ self.epoch_bar.progress((epoch + 1) / epochs)
62
+ self.loss_placeholder.text(f'Epoch {epoch + 1}/{epochs}, Loss: {logs["loss"]:.4f}')
63
+
64
+ # Fit the model
65
+ model.fit(trainX, trainY, epochs=epochs, batch_size=1, verbose=0, callbacks=[StreamlitCallback()])
66
 
67
  # Make predictions
68
  trainPredict = model.predict(trainX)