import base64
import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
from keras.optimizers import Adam
import os
import json
import pickle
from sklearn.preprocessing import OneHotEncoder
from keras.models import model_from_json
# Page header, rendered as raw HTML (hence unsafe_allow_html=True).
# Each string must be a single valid literal; a single-quoted string cannot
# span several source lines.
st.markdown('<h1>CNN Image classification model</h1>', unsafe_allow_html=True)
st.markdown(
    '<p>The image classification model classifies images into zebra and horse</p>',
    unsafe_allow_html=True,
)
@st.cache(allow_output_mutation=True)
def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` as a base64-encoded ASCII string.

    The ``@`` makes ``st.cache`` actually decorate this function: the file
    is read and encoded once per path instead of on every script rerun.

    Args:
        bin_file: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()
def set_png_as_page_bg(png_file, mime_type=None):
    """Use an image file as the page background of the Streamlit app.

    The file is base64-encoded with ``get_base64_of_bin_file`` and embedded
    as a CSS ``data:`` URI, which is injected via ``st.markdown``.

    Args:
        png_file: Path to the image file. Despite the name it need not be a
            PNG (this script passes a ``.webp`` file).
        mime_type: MIME type used in the data URI. When ``None`` it is
            guessed from the file extension, falling back to ``image/png``.
    """
    import mimetypes

    if mime_type is None:
        mime_type = mimetypes.guess_type(png_file)[0] or 'image/png'
    bin_str = get_base64_of_bin_file(png_file)
    # The template must contain the placeholders consumed by ``%`` below;
    # a template without them makes the formatting raise TypeError.
    # NOTE(review): ``.stApp`` is assumed to be the app's root container
    # class — confirm against the deployed Streamlit version.
    page_bg_img = '''
<style>
.stApp {
    background-image: url("data:%s;base64,%s");
    background-size: cover;
}
</style>
''' % (mime_type, bin_str)
    st.markdown(page_bg_img, unsafe_allow_html=True)
# Apply 'background.webp' as the page background.
set_png_as_page_bg(png_file='background.webp')
@st.cache(allow_output_mutation=True)
def load_model(json_path='model.json', weights_path='model.h5'):
    """Rebuild the Keras classifier from its saved architecture and weights.

    Cached with ``st.cache`` so the model is read from disk once rather than
    on every script rerun; ``allow_output_mutation`` skips hashing the
    returned model object.

    Args:
        json_path: Path to the model architecture JSON (parsed with
            ``model_from_json``).
        weights_path: Path to the weights file passed to
            ``model.load_weights``.

    Returns:
        The reconstructed model with its weights loaded.
    """
    with open(json_path, 'r') as f:
        model = model_from_json(f.read())
    # load_weights mutates the model in place; keep using ``model`` itself.
    model.load_weights(weights_path)
    return model
def image_transformation(image, target_size=None):
    """Convert an uploaded image into a NumPy array for the classifier.

    Args:
        image: A PIL image, or anything ``np.array`` accepts (e.g. an array
            that has already been transformed).
        target_size: Optional ``(width, height)`` to resize a PIL image to.
            ``None`` (the default) keeps the original size.

    Returns:
        A NumPy array of the pixels. PIL images come back as 3-channel RGB,
        i.e. shape ``(height, width, 3)``.
    """
    if isinstance(image, Image.Image):
        # The uploader accepts JPEG (normally RGB) and PNG, which may be RGBA,
        # greyscale or palette-based; normalise so every upload has the same
        # number of channels.
        image = image.convert('RGB')
        if target_size is not None:
            image = image.resize(target_size)
    return np.array(image)
def image_prediction(image, model):
    """Classify a single image and return the predicted class index.

    Args:
        image: A PIL image or array-like image, as accepted by
            ``image_transformation``.
        model: A Keras model; ``model.predict`` is called on a batch of one.

    Returns:
        The index of the highest-scoring output, as a string (e.g. ``"0"``).
    """
    batch = image_transformation(image=image)
    # Keras predicts on batches: wrap a single (H, W, C) image as (1, H, W, C).
    if batch.ndim == 3:
        batch = np.expand_dims(batch, axis=0)
    # Resize to the spatial size the model was built for, when it is fixed.
    input_shape = getattr(model, 'input_shape', None)
    if isinstance(input_shape, tuple) and len(input_shape) == 4:
        height, width = input_shape[1], input_shape[2]
        if (height is not None and width is not None
                and tuple(batch.shape[1:3]) != (height, width)):
            batch = tf.image.resize(batch, (height, width))
    outputs = model.predict(batch)
    # NumPy's ndarray.max(1) returns only the maxima (not a PyTorch-style
    # (values, indices) pair), so take the argmax over the class axis.
    # NOTE(review): assumes one score per class (e.g. softmax); a single
    # sigmoid output would always yield "0" here — confirm against the model.
    predicted_idx = np.argmax(outputs, axis=1)
    return str(predicted_idx.item())
def main():
    """Render the upload/predict UI and show the predicted class index.

    Once an image is uploaded it is shown in the left column. Pressing
    "Predict" loads the model, classifies the image and writes the predicted
    class index in the right column.
    """
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return
    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)
    # Streamlit re-executes the script on every widget interaction, so only
    # load the model and run inference when a prediction is requested.
    if st.button("Predict"):
        model = load_model()
        # image_prediction performs the array conversion itself.
        predicted_idx = image_prediction(Image.open(image_file), model)
        right_column.title("Prediction")
        # NOTE(review): this is the raw class index; how it maps to
        # "Zebra"/"Horse" depends on the training label order — confirm
        # before displaying a class name.
        right_column.write(predicted_idx)
# Build the UI when this file is executed as a script.
if __name__ == "__main__":
    main()