Spaces:
Runtime error
Runtime error
import base64 | |
import streamlit as st | |
import tensorflow as tf | |
from PIL import Image | |
import numpy as np | |
from keras.optimizers import Adam | |
import os | |
import json | |
import pickle | |
from sklearn.preprocessing import OneHotEncoder | |
from keras.models import model_from_json | |
# Page headings, injected as raw HTML so the text can be coloured white.
for _heading_html in (
    '<h1 style="color:white;">CNN Image classification model</h1>',
    '<h2 style="color:white;">The image classification model classifies images into zebra and horse</h2>',
):
    st.markdown(_heading_html, unsafe_allow_html=True)
# Must be applied as a decorator: a bare ``st.cache(...)`` call on its own line
# only builds the caching wrapper and discards it, so nothing gets cached.
@st.cache_data
def get_base64_of_bin_file(bin_file):
    """Return the raw bytes of *bin_file* encoded as a base64 ASCII string.

    The result is cached by Streamlit, keyed on ``bin_file``, so the file is
    read and encoded once instead of on every script rerun.

    Args:
        bin_file: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's contents, as ``str``.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    # base64 output is pure ASCII, so decoding as ASCII is exact.
    return base64.b64encode(data).decode('ascii')
def set_png_as_page_bg(png_file):
    """Use the image at *png_file* as the full-page background of the app.

    The image is inlined as a base64 ``data:`` URI inside a ``<style>`` block
    that targets the ``.stApp`` container, injected via ``st.markdown``.

    Args:
        png_file: Path to the image. Despite the parameter name, any image
            format works: the MIME type is guessed from the file extension
            and falls back to ``image/png`` when it cannot be determined.

    Raises:
        OSError: Propagated from reading the file (e.g. FileNotFoundError).
    """
    import mimetypes

    bin_str = get_base64_of_bin_file(png_file)
    # The caller passes a .webp file, so a hard-coded image/png label was wrong.
    mime_type = mimetypes.guess_type(png_file)[0] or 'image/png'
    # Keep <style>/</style> flush-left: Markdown treats lines indented by
    # 4+ spaces as a code block, which would print the CSS instead of applying it.
    # (A "#" comment inside CSS is not a comment; it produced an invalid
    # declaration, so the old "# doesn't work" note was removed from the string.)
    page_bg_img = '''
<style>
.stApp {
background-image: url("data:%s;base64,%s");
background-size: cover;
background-repeat: no-repeat;
background-attachment: scroll;
}
</style>
''' % (mime_type, bin_str)
    st.markdown(page_bg_img, unsafe_allow_html=True)


set_png_as_page_bg('background.webp')
@st.cache_resource
def load_model():
    """Build the Keras model from ``model.json`` and load ``model.h5`` weights.

    Both paths are resolved relative to the current working directory.
    Decorated with ``st.cache_resource`` so the model is built once per
    server process and reused across reruns, instead of being rebuilt on
    every interaction.

    NOTE(review): cached resources are shared across user sessions; this
    assumes concurrent ``model.predict`` calls on one Keras model are safe
    for this deployment. Confirm.

    Returns:
        The model returned by ``model_from_json`` with its weights loaded.

    Raises:
        OSError: If ``model.json`` or ``model.h5`` cannot be read.
    """
    # The architecture comes from JSON; the weights live in a separate HDF5 file.
    with open('model.json', 'r') as f:
        model = model_from_json(f.read())
    model.load_weights('model.h5')
    return model
def image_transformation(image):
    """Resize *image* to 256x256 RGB and return its pixels as a NumPy array.

    Args:
        image: A PIL image (anything with ``convert`` and ``resize`` methods
            behaving like ``PIL.Image.Image``).

    Returns:
        ``numpy.ndarray`` of the resized image's pixels, shape (256, 256, 3)
        for a PIL RGB image. This is a single image with no batch axis.
    """
    # NOTE(review): assumes the model was trained on 3-channel RGB input.
    # Converting makes RGBA PNGs and greyscale uploads produce 3 channels too.
    rgb_image = image.convert("RGB")
    # PIL's resize takes (width, height) and returns a new image.
    resized = rgb_image.resize((256, 256))
    # Converted in memory. The old np.save/np.load round-trip through a shared
    # 'images.npy' file could hand one session's image to another.
    return np.array(resized)
def image_prediction(image, model):
    """Classify a single raw image and return the predicted class index.

    Args:
        image: The raw uploaded image; preprocessing is done here via
            ``image_transformation``, so callers must not pre-transform it.
        model: A Keras model exposing ``predict``.

    Returns:
        The index of the highest-scoring output, as a string (e.g. ``"0"``).
    """
    pixels = image_transformation(image=image)
    # Keras predicts on batches: wrap the single image as a batch of one.
    batch = np.expand_dims(pixels, axis=0)
    outputs = model.predict(batch)
    # `_, y_hat = outputs.max(1)` was a PyTorch idiom; NumPy's max returns only
    # the values, so take the arg-max over the class axis instead.
    # NOTE(review): assumes outputs are (batch, num_classes) scores. Confirm.
    y_hat = np.argmax(outputs, axis=1)
    predicted_idx = str(int(y_hat[0]))
    return predicted_idx
def main():
    """Render the upload UI and, when requested, display the model's prediction.

    An uploaded jpg/jpeg/png is shown in the left column. Pressing "Predict"
    loads the model, classifies the image, and writes the predicted class
    index into the right column.
    """
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return

    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)
    image = Image.open(image_file)

    if st.button("Predict"):
        # Only load the model when a prediction is actually requested.
        model = load_model()
        # image_prediction preprocesses internally, so pass the raw PIL image.
        # Pre-transforming it here meant it was transformed twice.
        predicted_idx = image_prediction(image, model)
        right_column.title("Prediction")
        right_column.write(predicted_idx)


if __name__ == '__main__':
    main()