KaburaJ commited on
Commit
ff3df08
·
1 Parent(s): b13ee49

binary image classification

Browse files
Files changed (3) hide show
  1. Classification_app.py +112 -0
  2. background.webp +0 -0
  3. requirements.txt +7 -0
Classification_app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import base64
3
+ import streamlit as st
4
+ import tensorflow as tf
5
+ from PIL import Image
6
+ import numpy as np
7
+ from keras.optimizers import Adam
8
+ import os
9
+ import json
10
+ import pickle
11
+ from sklearn.preprocessing import OneHotEncoder
12
+ from keras.models import model_from_json
13
+
14
+ st.markdown('<h1 style="color:white;">CNN Image classification model</h1>', unsafe_allow_html=True)
15
+ st.markdown('<h2 style="color:white;">The image classification model classifies images into zebra and horse</h2>', unsafe_allow_html=True)
16
+
17
+ st.cache(allow_output_mutation=True)
18
+ def get_base64_of_bin_file(bin_file):
19
+ with open(bin_file, 'rb') as f:
20
+ data = f.read()
21
+ return base64.b64encode(data).decode()
22
+
23
+ def set_png_as_page_bg(png_file):
24
+ bin_str = get_base64_of_bin_file(png_file)
25
+ page_bg_img = '''
26
+ <style>
27
+ .stApp {
28
+ background-image: url("data:image/png;base64,%s");
29
+ background-size: cover;
30
+ background-repeat: no-repeat;
31
+ background-attachment: scroll; # doesn't work
32
+ }
33
+ </style>
34
+ ''' % bin_str
35
+
36
+ st.markdown(page_bg_img, unsafe_allow_html=True)
37
+ return
38
+
39
+ set_png_as_page_bg('background.webp')
40
+
41
+
42
+ # def load_model():
43
+ # # load json and create model
44
+ # json_file = open('model.json', 'r')
45
+ # loaded_model_json = json_file.read()
46
+ # json_file.close()
47
+ # CNN_class_index = model_from_json(loaded_model_json)
48
+ # # load weights into new model
49
+ # model = CNN_class_index.load_weights("model.h5")
50
+
51
+ # #model= tf.keras.load_model('model.h5')
52
+ # #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
53
+ # return model, CNN_class_index
54
+ def load_model():
55
+ # Load the model architecture
56
+ with open('model.json', 'r') as f:
57
+ model = model_from_json(f.read())
58
+
59
+ # Load the model weights
60
+ model.load_weights('model.h5')
61
+ #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
62
+ return model
63
+
64
+
65
+ def image_transformation(image):
66
+ image = Image._resize_dispatcher(image, (256, 256))
67
+ # image= np.resize((256,256))
68
+ image = np.array(image)
69
+ np.save('images.npy', image)
70
+ image = np.load('images.npy', allow_pickle=True)
71
+
72
+ return image
73
+
74
+
75
+ def image_prediction(image, model):
76
+ image = image_transformation(image=image)
77
+ outputs = model.predict(image)
78
+ _, y_hat = outputs.max(1)
79
+ predicted_idx = str(y_hat.item())
80
+ return predicted_idx
81
+
82
+ def main():
83
+
84
+ image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
85
+
86
+ if image_file:
87
+
88
+ left_column, right_column = st.columns(2)
89
+ left_column.image(image_file, caption="Uploaded image", use_column_width=True)
90
+ image = Image.open(image_file)
91
+ image = image_transformation(image=image)
92
+
93
+
94
+ pred_button = st.button("Predict")
95
+
96
+ model = load_model()
97
+ # label = ['Zebra', 'Horse']
98
+ # label = np.array(label).reshape(1, -1)
99
+ # ohe= OneHotEncoder()
100
+ # labels = ohe.fit_transform(label).toarray()
101
+
102
+ if pred_button:
103
+ image_prediction(image, model)
104
+ outputs = model.predict(image)
105
+ _, y_hat = outputs.max(1)
106
+ predicted_idx = str(y_hat.item())
107
+ right_column.title("Prediction")
108
+ right_column.write(predicted_idx)
109
+
110
+
111
+ if __name__ == '__main__':
112
+ main()
background.webp ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ base64
4
+ tensorflow
5
+ PIL
6
+ numpy
7
+ keras.models