Ayushs799 commited on
Commit
08ef9f5
·
1 Parent(s): 9639a6f

Classification Conformal

Browse files
Files changed (4) hide show
  1. app.py +114 -14
  2. cerr.npy +3 -0
  3. final_images.npy +3 -0
  4. final_pred.npy +3 -0
app.py CHANGED
@@ -8,24 +8,16 @@ from sklearn.linear_model import LinearRegression
8
  from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score,mean_absolute_percentage_error
9
  from matplotlib import pyplot as plt
10
  import matplotlib.image as mpimg
11
- import pickle
 
12
 
13
  load_reg_data = False
14
  load_class_data = False
15
- california_img=plt.imread("./california.png")
16
-
17
-
18
- # plt.rcParams['axes.spines.right'] = False
19
- # plt.rcParams['axes.spines.top'] = False
20
- # plt.rcParams['axes.spines.left'] = False
21
- # plt.rcParams['axes.spines.bottom'] = False
22
 
23
 
24
 
25
  def conformal_Predict(cal_err,alpha = 0.8):
26
-
27
  assert alpha != None, " Provide a value of alpha "
28
- # assert cal_err!= [] or None, "Provide the caliberation errors"
29
  idx = int(alpha*len(cal_err))
30
  return cal_err[idx]
31
 
@@ -40,13 +32,17 @@ if __name__ == '__main__':
40
  y_test = np.load("./Reg_Test_y.npy")
41
  err_calib = np.load("./Reg_calib_err.npy")
42
  y_pred = np.load("./Reg_y_pred.npy")
 
43
  load_reg_data = True
44
 
45
 
46
  if not(load_class_data):
47
- pass
48
-
 
 
49
 
 
50
 
51
  st.title("Conformal Prediction")
52
  intro_tab , reg_tab , class_tab = st.tabs(["Introduction","Regression", "Classification"])
@@ -117,8 +113,112 @@ if __name__ == '__main__':
117
 
118
 
119
  with class_tab:
120
- st.subheader("Classification")
121
- st.write("This is the classification tab")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
 
 
8
  from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score,mean_absolute_percentage_error
9
  from matplotlib import pyplot as plt
10
  import matplotlib.image as mpimg
11
+ import torch
12
+
13
 
14
  load_reg_data = False
15
  load_class_data = False
 
 
 
 
 
 
 
16
 
17
 
18
 
19
  def conformal_Predict(cal_err,alpha = 0.8):
 
20
  assert alpha != None, " Provide a value of alpha "
 
21
  idx = int(alpha*len(cal_err))
22
  return cal_err[idx]
23
 
 
32
  y_test = np.load("./Reg_Test_y.npy")
33
  err_calib = np.load("./Reg_calib_err.npy")
34
  y_pred = np.load("./Reg_y_pred.npy")
35
+ california_img=plt.imread("./california.png")
36
  load_reg_data = True
37
 
38
 
39
  if not(load_class_data):
40
+ img = np.load("./final_images.npy")
41
+ pred = np.load("./final_pred.npy")
42
+ cls_calib = np.load("./cerr.npy")
43
+ load_class_data = True
44
 
45
+
46
 
47
  st.title("Conformal Prediction")
48
  intro_tab , reg_tab , class_tab = st.tabs(["Introduction","Regression", "Classification"])
 
113
 
114
 
115
  with class_tab:
116
+
117
+ st.write("For Classification we are using Fashion-MNIST dataset. Fashion-MNIST is a dataset of Zalando's article images. Zalando intends Fashion-MNIST to serve as a direct drop-in for benchmarking machine learning algorithms. Each example is assigned to one of the following labels: 0 T-shirt/top, 1 Trouser,2 Pullover, 3 Dress, 4 Coat, 5 Sandal, 6 Shirt, 7 Sneaker, 8 Bag, 9 Ankle boot")
118
+
119
+ st.write("---")
120
+
121
+ st.write("Lets assume you have a model trained for Object detection but you cant just rely on the softmax output for that model. This is where conformal prediction comes into play. We can use the alpha value to pick up a threshold. When softmax scores go beyond this threshold score then onlt that label is considered as the predicted class.")
122
+
123
+ st.write("The higher the value of alpha more the model is certain about its prediction")
124
+
125
+
126
+
127
+ c1,c2,c3 = st.columns(3)
128
+ with c2:
129
+ alpha1 = st.slider('Select a value of alpha for the Model',min_value=0.1,max_value=.99,value=0.5,step=0.05)
130
+ sigma = conformal_Predict(cls_calib,alpha1)
131
+
132
+ labels = np.array(['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'])
133
+
134
+ with st.container():
135
+
136
+ c1,col1, col2, col3,c2 = st.columns([0.2,0.3,0.3,0.3,0.2])
137
+
138
+ with col1:
139
+ fig1, ax1 = plt.figure(), plt.gca()
140
+ ax1.imshow(torch.tensor(img[0]).permute(1,2,0),cmap='gray')
141
+ ax1.spines['top'].set_visible(False)
142
+ ax1.spines['bottom'].set_visible(False)
143
+ ax1.spines['right'].set_visible(False)
144
+ ax1.spines['left'].set_visible(False)
145
+ ax1.set_xticks([])
146
+ ax1.set_yticks([])
147
+ st.pyplot(fig1)
148
+
149
+ fig1, ax1 = plt.figure(), plt.gca()
150
+ ax1.bar(range(10),pred[0])
151
+ ax1.axhline(y=sigma,linestyle='dashed',c='r')
152
+ ax1.set_xlabel("Classe Labels")
153
+ ax1.set_ylabel("SoftMax Probabilities")
154
+ ax1.set_title("Class Scores with Threshold")
155
+ ax1.set_xticks([i for i in range(10)])
156
+ st.pyplot(fig1)
157
+
158
+ out_labels = labels[pred[0]>sigma]
159
+ if len(out_labels)==0:
160
+ out_labels = ["None"]
161
+ out_labels = ",".join(out_labels)
162
+ st.write("Ouput Labels : "+out_labels)
163
+ st.write("True Label : Coat")
164
+
165
+
166
+ with col2:
167
+ fig1, ax1 = plt.figure(), plt.gca()
168
+ ax1.imshow(torch.tensor(img[1]).permute(1,2,0),cmap='gray')
169
+ ax1.spines['top'].set_visible(False)
170
+ ax1.spines['bottom'].set_visible(False)
171
+ ax1.spines['right'].set_visible(False)
172
+ ax1.spines['left'].set_visible(False)
173
+ ax1.set_xticks([])
174
+ ax1.set_yticks([])
175
+ st.pyplot(fig1)
176
+
177
+ fig1, ax1 = plt.figure(), plt.gca()
178
+ ax1.bar(range(10),pred[1])
179
+ ax1.axhline(y=sigma,linestyle='dashed',c='r')
180
+ ax1.set_xlabel("Classe Labels")
181
+ ax1.set_ylabel("SoftMax Probabilities")
182
+ ax1.set_title("Class Scores with Threshold")
183
+ ax1.set_xticks([i for i in range(10)])
184
+ st.pyplot(fig1)
185
+
186
+ out_labels = labels[pred[1]>sigma]
187
+ if len(out_labels)==0:
188
+ out_labels = ["None"]
189
+ out_labels = ",".join(out_labels)
190
+ st.write("Ouput Labels : "+out_labels)
191
+ st.write("True Label : Ankle Boot")
192
+
193
+ with col3:
194
+ fig1, ax1 = plt.figure(), plt.gca()
195
+ ax1.imshow(torch.tensor(img[2]).permute(1,2,0),cmap='gray')
196
+ ax1.spines['top'].set_visible(False)
197
+ ax1.spines['bottom'].set_visible(False)
198
+ ax1.spines['right'].set_visible(False)
199
+ ax1.spines['left'].set_visible(False)
200
+ ax1.set_xticks([])
201
+ ax1.set_yticks([])
202
+ st.pyplot(fig1)
203
+
204
+ fig1, ax1 = plt.figure(), plt.gca()
205
+ ax1.bar(range(10),pred[2])
206
+ ax1.axhline(y=sigma,linestyle='dashed',c='r')
207
+ ax1.set_xlabel("Classe Labels")
208
+ ax1.set_ylabel("SoftMax Probabilities")
209
+ ax1.set_title("Class Scores with Threshold")
210
+ ax1.set_xticks([i for i in range(10)])
211
+ st.pyplot(fig1)
212
+
213
+ out_labels = labels[pred[2]>sigma]
214
+ if len(out_labels)==0:
215
+ out_labels = ["None"]
216
+ out_labels = ",".join(out_labels)
217
+ st.write("Ouput Labels : "+out_labels)
218
+ st.write("True Label : Bag")
219
+
220
+
221
+
222
 
223
 
224
 
cerr.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25367f293f8667889b42da0c304630cee93f1dc3f6d5faebcb3211489c500613
3
+ size 4128
final_images.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa9347527bfb8edfa7a17e3e07e6aa6f8c40a22cfb0b7fdc29b2b3b6c01066d9
3
+ size 9536
final_pred.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dab1a251533ac80f3125474ac74ab3ea6088be558d8bffab4774ee3a6e7f304
3
+ size 248