Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -49,7 +49,20 @@ def lstm_prediction(file, epochs):
|
|
| 49 |
model.add(LSTM(4, input_shape=(1, look_back)))
|
| 50 |
model.add(Dense(1))
|
| 51 |
model.compile(loss='mean_squared_error', optimizer='adam')
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Make predictions
|
| 55 |
trainPredict = model.predict(trainX)
|
|
|
|
| 49 |
model.add(LSTM(4, input_shape=(1, look_back)))
|
| 50 |
model.add(Dense(1))
|
| 51 |
model.compile(loss='mean_squared_error', optimizer='adam')
|
| 52 |
+
|
| 53 |
+
# Set up a callback to update Streamlit during training
|
| 54 |
+
class StreamlitCallback(tf.keras.callbacks.Callback):
|
| 55 |
+
def __init__(self):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.epoch_bar = st.progress(0)
|
| 58 |
+
self.loss_placeholder = st.empty()
|
| 59 |
+
|
| 60 |
+
def on_epoch_end(self, epoch, logs=None):
|
| 61 |
+
self.epoch_bar.progress((epoch + 1) / epochs)
|
| 62 |
+
self.loss_placeholder.text(f'Epoch {epoch + 1}/{epochs}, Loss: {logs["loss"]:.4f}')
|
| 63 |
+
|
| 64 |
+
# Fit the model
|
| 65 |
+
model.fit(trainX, trainY, epochs=epochs, batch_size=1, verbose=0, callbacks=[StreamlitCallback()])
|
| 66 |
|
| 67 |
# Make predictions
|
| 68 |
trainPredict = model.predict(trainX)
|