import base64
import json
import mimetypes
import os
import pickle

import numpy as np
import streamlit as st
import tensorflow as tf
from keras.models import model_from_json
from keras.optimizers import Adam
from PIL import Image
from sklearn.preprocessing import OneHotEncoder

st.markdown(
    '<h1>CNN Image classification model</h1>',
    unsafe_allow_html=True,
)
st.markdown(
    '<p>The image classification model classifies images into zebra and horse</p>',
    unsafe_allow_html=True,
)


def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` encoded as a base64 ASCII string."""
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()


def set_png_as_page_bg(png_file):
    """Use the image at ``png_file`` as the full-page background of the app.

    The image is inlined as a base64 ``data:`` URI inside a ``<style>`` block
    rendered through ``st.markdown``, so no static file hosting is needed.
    """
    bin_str = get_base64_of_bin_file(png_file)
    # Despite the function name, take the MIME type from the file extension
    # (the app ships a .webp); fall back to PNG when it cannot be guessed.
    mime_type = mimetypes.guess_type(png_file)[0] or "image/png"
    page_bg_img = """
<style>
.stApp {
    background-image: url("data:%s;base64,%s");
    background-size: cover;
}
</style>
""" % (mime_type, bin_str)
    st.markdown(page_bg_img, unsafe_allow_html=True)


set_png_as_page_bg('background.webp')


# Streamlit >= 1.18 offers st.cache_resource for shared, unhashable objects such
# as models; older releases only have st.cache(allow_output_mutation=True).
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Build the Keras model from ``model.json`` and load ``model.h5`` weights.

    Cached so the architecture and weights are read once per server process
    rather than on every Streamlit rerun.

    Returns:
        The (uncompiled) Keras model with its weights loaded.
    """
    with open('model.json', 'r') as f:
        model = model_from_json(f.read())
    model.load_weights('model.h5')
    return model


def _model_input_size(model):
    """Return the ``(width, height)`` ``model`` expects, or None if not fixed."""
    shape = getattr(model, "input_shape", None)
    # Assumes a single channels-last input: (batch, height, width, channels).
    if not isinstance(shape, tuple) or len(shape) != 4:
        return None
    height, width = shape[1], shape[2]
    if height is None or width is None:
        return None
    return (width, height)


def image_transformation(image, target_size=None):
    """Turn an image into a float32 batch of one, ready for ``model.predict``.

    Args:
        image: A PIL image, or an array already produced by this function.
        target_size: Optional ``(width, height)`` to resize a PIL image to, so
            arbitrary uploads match the network's spatial input size.

    Returns:
        ``np.ndarray``; for a PIL input its shape is ``(1, height, width, 3)``.
    """
    if isinstance(image, Image.Image):
        # Uploads may be RGBA (PNG), palette or grayscale; force three channels.
        # NOTE(review): assumes the model was trained on RGB input — confirm.
        image = image.convert("RGB")
        if target_size is not None:
            image = image.resize(tuple(target_size))
    array = np.asarray(image, dtype=np.float32)
    # Keras predicts on batches; add the leading batch axis for one image.
    if array.ndim == 3:
        array = np.expand_dims(array, axis=0)
    # NOTE(review): pixel values are left in 0-255 (no rescaling), as before;
    # confirm this matches the training-time preprocessing.
    return array


def image_prediction(image, model):
    """Classify ``image`` with ``model`` and return the predicted class index.

    Args:
        image: A PIL image (or an array from ``image_transformation``).
        model: A Keras model as returned by ``load_model``.

    Returns:
        The predicted class index as a string, e.g. ``"0"`` or ``"1"``.
    """
    batch = image_transformation(image=image, target_size=_model_input_size(model))
    outputs = np.asarray(model.predict(batch))
    scores = outputs.reshape(outputs.shape[0], -1)[0]
    if scores.size == 1:
        # A single sigmoid unit: argmax would always be 0, so threshold instead.
        predicted = int(scores[0] >= 0.5)
    else:
        predicted = int(np.argmax(scores))
    return str(predicted)


def main():
    """Render the upload widget and, when requested, show the prediction."""
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return
    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)
    image = Image.open(image_file)
    pred_button = st.button("Predict")
    if pred_button:
        model = load_model()
        predicted_idx = image_prediction(image, model)
        right_column.title("Prediction")
        right_column.write(predicted_idx)


if __name__ == '__main__':
    main()