louiecerv commited on
Commit
781fd86
·
1 Parent(s): 20322d0

sync with remote

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -123,10 +123,11 @@ tf.keras.backend.clear_session()
123
 
124
  # set up the model architecture
125
  model = tf.keras.models.Sequential([
126
- Flatten(input_shape=[256, 256, 1]),
127
- Dense(64, activation='relu'),
128
- Dense(256*256*2, activation='softmax'),
129
- Reshape((256, 256, 2))
 
130
  ])
131
 
132
  # specify how to train the model with algorithm, the loss function and metrics
@@ -140,4 +141,7 @@ model_summary = StringIO()
140
  model.summary(print_fn=lambda x: model_summary.write(x + '\n'))
141
 
142
  # Display the model summary in Streamlit
143
- st.markdown(model_summary.getvalue())
 
 
 
 
123
 
124
  # set up the model architecture
125
  model = tf.keras.models.Sequential([
126
+ tf.keras.layers.Input(shape=(256, 256, 1)), # Define input shape
127
+ tf.keras.layers.Flatten(),
128
+ tf.keras.layers.Dense(64, activation='relu'),
129
+ tf.keras.layers.Dense(256*256*2, activation='softmax'),
130
+ tf.keras.layers.Reshape((256, 256, 2))
131
  ])
132
 
133
  # specify how to train the model with algorithm, the loss function and metrics
 
141
  model.summary(print_fn=lambda x: model_summary.write(x + '\n'))
142
 
143
  # Display the model summary in Streamlit
144
+ st.markdown(model_summary.getvalue())
145
+
146
+ # plot the model including the sizes of the model
147
+ tf.keras.utils.plot_model(model, show_shapes=True)