Spaces:
Sleeping
Sleeping
Classification Conformal
Browse files- app.py +114 -14
- cerr.npy +3 -0
- final_images.npy +3 -0
- final_pred.npy +3 -0
app.py
CHANGED
@@ -8,24 +8,16 @@ from sklearn.linear_model import LinearRegression
|
|
8 |
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score,mean_absolute_percentage_error
|
9 |
from matplotlib import pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
-
import
|
|
|
12 |
|
13 |
load_reg_data = False
|
14 |
load_class_data = False
|
15 |
-
california_img=plt.imread("./california.png")
|
16 |
-
|
17 |
-
|
18 |
-
# plt.rcParams['axes.spines.right'] = False
|
19 |
-
# plt.rcParams['axes.spines.top'] = False
|
20 |
-
# plt.rcParams['axes.spines.left'] = False
|
21 |
-
# plt.rcParams['axes.spines.bottom'] = False
|
22 |
|
23 |
|
24 |
|
25 |
def conformal_Predict(cal_err,alpha = 0.8):
|
26 |
-
|
27 |
assert alpha != None, " Provide a value of alpha "
|
28 |
-
# assert cal_err!= [] or None, "Provide the caliberation errors"
|
29 |
idx = int(alpha*len(cal_err))
|
30 |
return cal_err[idx]
|
31 |
|
@@ -40,13 +32,17 @@ if __name__ == '__main__':
|
|
40 |
y_test = np.load("./Reg_Test_y.npy")
|
41 |
err_calib = np.load("./Reg_calib_err.npy")
|
42 |
y_pred = np.load("./Reg_y_pred.npy")
|
|
|
43 |
load_reg_data = True
|
44 |
|
45 |
|
46 |
if not(load_class_data):
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
|
|
|
50 |
|
51 |
st.title("Conformal Prediction")
|
52 |
intro_tab , reg_tab , class_tab = st.tabs(["Introduction","Regression", "Classification"])
|
@@ -117,8 +113,112 @@ if __name__ == '__main__':
|
|
117 |
|
118 |
|
119 |
with class_tab:
|
120 |
-
|
121 |
-
st.write("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
|
|
|
8 |
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score,mean_absolute_percentage_error
|
9 |
from matplotlib import pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
+
import torch
|
12 |
+
|
13 |
|
14 |
load_reg_data = False
|
15 |
load_class_data = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
|
19 |
def conformal_Predict(cal_err,alpha = 0.8):
|
|
|
20 |
assert alpha != None, " Provide a value of alpha "
|
|
|
21 |
idx = int(alpha*len(cal_err))
|
22 |
return cal_err[idx]
|
23 |
|
|
|
32 |
y_test = np.load("./Reg_Test_y.npy")
|
33 |
err_calib = np.load("./Reg_calib_err.npy")
|
34 |
y_pred = np.load("./Reg_y_pred.npy")
|
35 |
+
california_img=plt.imread("./california.png")
|
36 |
load_reg_data = True
|
37 |
|
38 |
|
39 |
if not(load_class_data):
|
40 |
+
img = np.load("./final_images.npy")
|
41 |
+
pred = np.load("./final_pred.npy")
|
42 |
+
cls_calib = np.load("./cerr.npy")
|
43 |
+
load_class_data = True
|
44 |
|
45 |
+
|
46 |
|
47 |
st.title("Conformal Prediction")
|
48 |
intro_tab , reg_tab , class_tab = st.tabs(["Introduction","Regression", "Classification"])
|
|
|
113 |
|
114 |
|
115 |
with class_tab:
|
116 |
+
|
117 |
+
st.write("For Classification we are using Fashion-MNIST dataset. Fashion-MNIST is a dataset of Zalando's article images. Zalando intends Fashion-MNIST to serve as a direct drop-in for benchmarking machine learning algorithms. Each example is assigned to one of the following labels: 0 T-shirt/top, 1 Trouser,2 Pullover, 3 Dress, 4 Coat, 5 Sandal, 6 Shirt, 7 Sneaker, 8 Bag, 9 Ankle boot")
|
118 |
+
|
119 |
+
st.write("---")
|
120 |
+
|
121 |
+
st.write("Lets assume you have a model trained for Object detection but you cant just rely on the softmax output for that model. This is where conformal prediction comes into play. We can use the alpha value to pick up a threshold. When softmax scores go beyond this threshold score then onlt that label is considered as the predicted class.")
|
122 |
+
|
123 |
+
st.write("The higher the value of alpha more the model is certain about its prediction")
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
c1,c2,c3 = st.columns(3)
|
128 |
+
with c2:
|
129 |
+
alpha1 = st.slider('Select a value of alpha for the Model',min_value=0.1,max_value=.99,value=0.5,step=0.05)
|
130 |
+
sigma = conformal_Predict(cls_calib,alpha1)
|
131 |
+
|
132 |
+
labels = np.array(['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'])
|
133 |
+
|
134 |
+
with st.container():
|
135 |
+
|
136 |
+
c1,col1, col2, col3,c2 = st.columns([0.2,0.3,0.3,0.3,0.2])
|
137 |
+
|
138 |
+
with col1:
|
139 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
140 |
+
ax1.imshow(torch.tensor(img[0]).permute(1,2,0),cmap='gray')
|
141 |
+
ax1.spines['top'].set_visible(False)
|
142 |
+
ax1.spines['bottom'].set_visible(False)
|
143 |
+
ax1.spines['right'].set_visible(False)
|
144 |
+
ax1.spines['left'].set_visible(False)
|
145 |
+
ax1.set_xticks([])
|
146 |
+
ax1.set_yticks([])
|
147 |
+
st.pyplot(fig1)
|
148 |
+
|
149 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
150 |
+
ax1.bar(range(10),pred[0])
|
151 |
+
ax1.axhline(y=sigma,linestyle='dashed',c='r')
|
152 |
+
ax1.set_xlabel("Classe Labels")
|
153 |
+
ax1.set_ylabel("SoftMax Probabilities")
|
154 |
+
ax1.set_title("Class Scores with Threshold")
|
155 |
+
ax1.set_xticks([i for i in range(10)])
|
156 |
+
st.pyplot(fig1)
|
157 |
+
|
158 |
+
out_labels = labels[pred[0]>sigma]
|
159 |
+
if len(out_labels)==0:
|
160 |
+
out_labels = ["None"]
|
161 |
+
out_labels = ",".join(out_labels)
|
162 |
+
st.write("Ouput Labels : "+out_labels)
|
163 |
+
st.write("True Label : Coat")
|
164 |
+
|
165 |
+
|
166 |
+
with col2:
|
167 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
168 |
+
ax1.imshow(torch.tensor(img[1]).permute(1,2,0),cmap='gray')
|
169 |
+
ax1.spines['top'].set_visible(False)
|
170 |
+
ax1.spines['bottom'].set_visible(False)
|
171 |
+
ax1.spines['right'].set_visible(False)
|
172 |
+
ax1.spines['left'].set_visible(False)
|
173 |
+
ax1.set_xticks([])
|
174 |
+
ax1.set_yticks([])
|
175 |
+
st.pyplot(fig1)
|
176 |
+
|
177 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
178 |
+
ax1.bar(range(10),pred[1])
|
179 |
+
ax1.axhline(y=sigma,linestyle='dashed',c='r')
|
180 |
+
ax1.set_xlabel("Classe Labels")
|
181 |
+
ax1.set_ylabel("SoftMax Probabilities")
|
182 |
+
ax1.set_title("Class Scores with Threshold")
|
183 |
+
ax1.set_xticks([i for i in range(10)])
|
184 |
+
st.pyplot(fig1)
|
185 |
+
|
186 |
+
out_labels = labels[pred[1]>sigma]
|
187 |
+
if len(out_labels)==0:
|
188 |
+
out_labels = ["None"]
|
189 |
+
out_labels = ",".join(out_labels)
|
190 |
+
st.write("Ouput Labels : "+out_labels)
|
191 |
+
st.write("True Label : Ankle Boot")
|
192 |
+
|
193 |
+
with col3:
|
194 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
195 |
+
ax1.imshow(torch.tensor(img[2]).permute(1,2,0),cmap='gray')
|
196 |
+
ax1.spines['top'].set_visible(False)
|
197 |
+
ax1.spines['bottom'].set_visible(False)
|
198 |
+
ax1.spines['right'].set_visible(False)
|
199 |
+
ax1.spines['left'].set_visible(False)
|
200 |
+
ax1.set_xticks([])
|
201 |
+
ax1.set_yticks([])
|
202 |
+
st.pyplot(fig1)
|
203 |
+
|
204 |
+
fig1, ax1 = plt.figure(), plt.gca()
|
205 |
+
ax1.bar(range(10),pred[2])
|
206 |
+
ax1.axhline(y=sigma,linestyle='dashed',c='r')
|
207 |
+
ax1.set_xlabel("Classe Labels")
|
208 |
+
ax1.set_ylabel("SoftMax Probabilities")
|
209 |
+
ax1.set_title("Class Scores with Threshold")
|
210 |
+
ax1.set_xticks([i for i in range(10)])
|
211 |
+
st.pyplot(fig1)
|
212 |
+
|
213 |
+
out_labels = labels[pred[2]>sigma]
|
214 |
+
if len(out_labels)==0:
|
215 |
+
out_labels = ["None"]
|
216 |
+
out_labels = ",".join(out_labels)
|
217 |
+
st.write("Ouput Labels : "+out_labels)
|
218 |
+
st.write("True Label : Bag")
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
|
223 |
|
224 |
|
cerr.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25367f293f8667889b42da0c304630cee93f1dc3f6d5faebcb3211489c500613
|
3 |
+
size 4128
|
final_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa9347527bfb8edfa7a17e3e07e6aa6f8c40a22cfb0b7fdc29b2b3b6c01066d9
|
3 |
+
size 9536
|
final_pred.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4dab1a251533ac80f3125474ac74ab3ea6088be558d8bffab4774ee3a6e7f304
|
3 |
+
size 248
|