Spaces:
Runtime error
Runtime error
Addressing tfa deprecation
#1
by
MBA98
- opened
app.py
CHANGED
@@ -4,6 +4,600 @@ import numpy as np
|
|
4 |
import tensorflow as tf
|
5 |
import tensorflow_addons as tfa
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
IMAGE_SIZE = 128
|
8 |
NUM_CLASSES = 3
|
9 |
|
@@ -106,7 +700,7 @@ def load_vgg16_model():
|
|
106 |
vgg16_DR = tf.keras.models.load_model('./VGG16_32_128_CLAHE_4_40_trainableFalse.h5', custom_objects={
|
107 |
'WeightedKappaLoss': tfa.losses.WeightedKappaLoss,
|
108 |
'CohenKappa': tfa.metrics.CohenKappa,
|
109 |
-
'F1Score':
|
110 |
'MultiLabelConfusionMatrix': tfa.metrics.MultiLabelConfusionMatrix
|
111 |
})
|
112 |
return vgg16_DR
|
|
|
4 |
import tensorflow as tf
|
5 |
import tensorflow_addons as tfa
|
6 |
|
7 |
+
# Imported
|
8 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
# ==============================================================================
|
22 |
+
"""Implements Weighted kappa loss."""
|
23 |
+
|
24 |
+
from typing import Optional
|
25 |
+
|
26 |
+
import tensorflow as tf
|
27 |
+
from typeguard import typechecked
|
28 |
+
|
29 |
+
from tensorflow_addons.utils.types import Number
|
30 |
+
|
31 |
+
|
32 |
+
@tf.keras.utils.register_keras_serializable(package="Addons")
|
33 |
+
class WeightedKappaLoss(tf.keras.losses.Loss):
|
34 |
+
r"""Implements the Weighted Kappa loss function.
|
35 |
+
|
36 |
+
Weighted Kappa loss was introduced in the
|
37 |
+
[Weighted kappa loss function for multi-class classification
|
38 |
+
of ordinal data in deep learning]
|
39 |
+
(https://www.sciencedirect.com/science/article/abs/pii/S0167865517301666).
|
40 |
+
Weighted Kappa is widely used in Ordinal Classification Problems.
|
41 |
+
The loss value lies in $ [-\infty, \log 2] $, where $ \log 2 $
|
42 |
+
means the random prediction.
|
43 |
+
|
44 |
+
Usage:
|
45 |
+
|
46 |
+
>>> kappa_loss = tfa.losses.WeightedKappaLoss(num_classes=4)
|
47 |
+
>>> y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
|
48 |
+
... [1, 0, 0, 0], [0, 0, 0, 1]])
|
49 |
+
>>> y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
|
50 |
+
... [0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
|
51 |
+
>>> loss = kappa_loss(y_true, y_pred)
|
52 |
+
>>> loss
|
53 |
+
<tf.Tensor: shape=(), dtype=float32, numpy=-1.1611925>
|
54 |
+
|
55 |
+
Usage with `tf.keras` API:
|
56 |
+
|
57 |
+
>>> model = tf.keras.Model()
|
58 |
+
>>> model.compile('sgd', loss=tfa.losses.WeightedKappaLoss(num_classes=4))
|
59 |
+
|
60 |
+
<... outputs should be softmax results
|
61 |
+
if you want to weight the samples, just multiply the outputs
|
62 |
+
by the sample weight ...>
|
63 |
+
|
64 |
+
"""
|
65 |
+
|
66 |
+
@typechecked
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
num_classes: int,
|
70 |
+
weightage: Optional[str] = "quadratic",
|
71 |
+
name: Optional[str] = "cohen_kappa_loss",
|
72 |
+
epsilon: Optional[Number] = 1e-6,
|
73 |
+
reduction: str = tf.keras.losses.Reduction.NONE,
|
74 |
+
):
|
75 |
+
r"""Creates a `WeightedKappaLoss` instance.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
num_classes: Number of unique classes in your dataset.
|
79 |
+
weightage: (Optional) Weighting to be considered for calculating
|
80 |
+
kappa statistics. A valid value is one of
|
81 |
+
['linear', 'quadratic']. Defaults to 'quadratic'.
|
82 |
+
name: (Optional) String name of the metric instance.
|
83 |
+
epsilon: (Optional) increment to avoid log zero,
|
84 |
+
so the loss will be $ \log(1 - k + \epsilon) $, where $ k $ lies
|
85 |
+
in $ [-1, 1] $. Defaults to 1e-6.
|
86 |
+
Raises:
|
87 |
+
ValueError: If the value passed for `weightage` is invalid
|
88 |
+
i.e. not any one of ['linear', 'quadratic']
|
89 |
+
"""
|
90 |
+
|
91 |
+
super().__init__(name=name, reduction=reduction)
|
92 |
+
|
93 |
+
if weightage not in ("linear", "quadratic"):
|
94 |
+
raise ValueError("Unknown kappa weighting type.")
|
95 |
+
|
96 |
+
self.weightage = weightage
|
97 |
+
self.num_classes = num_classes
|
98 |
+
self.epsilon = epsilon or tf.keras.backend.epsilon()
|
99 |
+
label_vec = tf.range(num_classes, dtype=tf.keras.backend.floatx())
|
100 |
+
self.row_label_vec = tf.reshape(label_vec, [1, num_classes])
|
101 |
+
self.col_label_vec = tf.reshape(label_vec, [num_classes, 1])
|
102 |
+
col_mat = tf.tile(self.col_label_vec, [1, num_classes])
|
103 |
+
row_mat = tf.tile(self.row_label_vec, [num_classes, 1])
|
104 |
+
if weightage == "linear":
|
105 |
+
self.weight_mat = tf.abs(col_mat - row_mat)
|
106 |
+
else:
|
107 |
+
self.weight_mat = (col_mat - row_mat) ** 2
|
108 |
+
|
109 |
+
def call(self, y_true, y_pred):
|
110 |
+
y_true = tf.cast(y_true, dtype=self.col_label_vec.dtype)
|
111 |
+
y_pred = tf.cast(y_pred, dtype=self.weight_mat.dtype)
|
112 |
+
batch_size = tf.shape(y_true)[0]
|
113 |
+
cat_labels = tf.matmul(y_true, self.col_label_vec)
|
114 |
+
cat_label_mat = tf.tile(cat_labels, [1, self.num_classes])
|
115 |
+
row_label_mat = tf.tile(self.row_label_vec, [batch_size, 1])
|
116 |
+
if self.weightage == "linear":
|
117 |
+
weight = tf.abs(cat_label_mat - row_label_mat)
|
118 |
+
else:
|
119 |
+
weight = (cat_label_mat - row_label_mat) ** 2
|
120 |
+
numerator = tf.reduce_sum(weight * y_pred)
|
121 |
+
label_dist = tf.reduce_sum(y_true, axis=0, keepdims=True)
|
122 |
+
pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
|
123 |
+
w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
|
124 |
+
denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
|
125 |
+
denominator /= tf.cast(batch_size, dtype=denominator.dtype)
|
126 |
+
loss = tf.math.divide_no_nan(numerator, denominator)
|
127 |
+
return tf.math.log(loss + self.epsilon)
|
128 |
+
|
129 |
+
def get_config(self):
|
130 |
+
config = {
|
131 |
+
"num_classes": self.num_classes,
|
132 |
+
"weightage": self.weightage,
|
133 |
+
"epsilon": self.epsilon,
|
134 |
+
}
|
135 |
+
base_config = super().get_config()
|
136 |
+
return {**base_config, **config}
|
137 |
+
|
138 |
+
|
139 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
140 |
+
#
|
141 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
142 |
+
# you may not use this file except in compliance with the License.
|
143 |
+
# You may obtain a copy of the License at
|
144 |
+
#
|
145 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
146 |
+
#
|
147 |
+
# Unless required by applicable law or agreed to in writing, software
|
148 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
149 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
150 |
+
# See the License for the specific language governing permissions and
|
151 |
+
# limitations under the License.
|
152 |
+
# ==============================================================================
|
153 |
+
"""Implements Cohen's Kappa."""
|
154 |
+
|
155 |
+
import tensorflow as tf
|
156 |
+
import numpy as np
|
157 |
+
import tensorflow.keras.backend as K
|
158 |
+
from tensorflow.keras.metrics import Metric
|
159 |
+
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
|
160 |
+
AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None]
|
161 |
+
|
162 |
+
from typeguard import typechecked
|
163 |
+
from typing import Optional
|
164 |
+
|
165 |
+
|
166 |
+
@tf.keras.utils.register_keras_serializable(package="Addons")
|
167 |
+
class CohenKappa(Metric):
|
168 |
+
"""Computes Kappa score between two raters.
|
169 |
+
|
170 |
+
The score lies in the range `[-1, 1]`. A score of -1 represents
|
171 |
+
complete disagreement between two raters whereas a score of 1
|
172 |
+
represents complete agreement between the two raters.
|
173 |
+
A score of 0 means agreement by chance.
|
174 |
+
|
175 |
+
Note: As of now, this implementation considers all labels
|
176 |
+
while calculating the Cohen's Kappa score.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
num_classes: Number of unique classes in your dataset.
|
180 |
+
weightage: (optional) Weighting to be considered for calculating
|
181 |
+
kappa statistics. A valid value is one of
|
182 |
+
[None, 'linear', 'quadratic']. Defaults to `None`
|
183 |
+
sparse_labels: (bool) Valid only for multi-class scenario.
|
184 |
+
If True, ground truth labels are expected to be integers
|
185 |
+
and not one-hot encoded.
|
186 |
+
regression: (bool) If set, that means the problem is being treated
|
187 |
+
as a regression problem where you are regressing the predictions.
|
188 |
+
**Note:** If you are regressing for the values, the the output layer
|
189 |
+
should contain a single unit.
|
190 |
+
name: (optional) String name of the metric instance
|
191 |
+
dtype: (optional) Data type of the metric result. Defaults to `None`.
|
192 |
+
|
193 |
+
Raises:
|
194 |
+
ValueError: If the value passed for `weightage` is invalid
|
195 |
+
i.e. not any one of [None, 'linear', 'quadratic'].
|
196 |
+
|
197 |
+
Usage:
|
198 |
+
|
199 |
+
>>> y_true = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
|
200 |
+
>>> y_pred = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)
|
201 |
+
>>> weights = np.array([1, 1, 2, 5, 10, 2, 3, 3], dtype=np.int32)
|
202 |
+
>>> metric = tfa.metrics.CohenKappa(num_classes=5, sparse_labels=True)
|
203 |
+
>>> metric.update_state(y_true , y_pred)
|
204 |
+
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
|
205 |
+
array([[0., 0., 0., 0., 0.],
|
206 |
+
[0., 2., 0., 0., 0.],
|
207 |
+
[0., 0., 0., 0., 1.],
|
208 |
+
[0., 0., 0., 1., 0.],
|
209 |
+
[0., 0., 1., 0., 3.]], dtype=float32)>
|
210 |
+
>>> result = metric.result()
|
211 |
+
>>> result.numpy()
|
212 |
+
0.61904764
|
213 |
+
>>> # To use this with weights, sample_weight argument can be used.
|
214 |
+
>>> metric = tfa.metrics.CohenKappa(num_classes=5, sparse_labels=True)
|
215 |
+
>>> metric.update_state(y_true , y_pred , sample_weight=weights)
|
216 |
+
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
|
217 |
+
array([[ 0., 0., 0., 0., 0.],
|
218 |
+
[ 0., 6., 0., 0., 0.],
|
219 |
+
[ 0., 0., 0., 0., 10.],
|
220 |
+
[ 0., 0., 0., 2., 0.],
|
221 |
+
[ 0., 0., 2., 0., 7.]], dtype=float32)>
|
222 |
+
>>> result = metric.result()
|
223 |
+
>>> result.numpy()
|
224 |
+
0.37209308
|
225 |
+
|
226 |
+
Usage with `tf.keras` API:
|
227 |
+
|
228 |
+
>>> inputs = tf.keras.Input(shape=(10,))
|
229 |
+
>>> x = tf.keras.layers.Dense(10)(inputs)
|
230 |
+
>>> outputs = tf.keras.layers.Dense(1)(x)
|
231 |
+
>>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
|
232 |
+
>>> model.compile('sgd', loss='mse', metrics=[tfa.metrics.CohenKappa(num_classes=3, sparse_labels=True)])
|
233 |
+
"""
|
234 |
+
|
235 |
+
@typechecked
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
num_classes: FloatTensorLike,
|
239 |
+
name: str = "cohen_kappa",
|
240 |
+
weightage: Optional[str] = None,
|
241 |
+
sparse_labels: bool = False,
|
242 |
+
regression: bool = False,
|
243 |
+
dtype: AcceptableDTypes = None,
|
244 |
+
):
|
245 |
+
"""Creates a `CohenKappa` instance."""
|
246 |
+
super().__init__(name=name, dtype=dtype)
|
247 |
+
|
248 |
+
if weightage not in (None, "linear", "quadratic"):
|
249 |
+
raise ValueError("Unknown kappa weighting type.")
|
250 |
+
|
251 |
+
if num_classes == 2:
|
252 |
+
self._update = self._update_binary_class_model
|
253 |
+
elif num_classes > 2:
|
254 |
+
self._update = self._update_multi_class_model
|
255 |
+
else:
|
256 |
+
raise ValueError(
|
257 |
+
"""Number of classes must be
|
258 |
+
greater than or euqal to two"""
|
259 |
+
)
|
260 |
+
|
261 |
+
self.weightage = weightage
|
262 |
+
self.num_classes = num_classes
|
263 |
+
self.regression = regression
|
264 |
+
self.sparse_labels = sparse_labels
|
265 |
+
self.conf_mtx = self.add_weight(
|
266 |
+
"conf_mtx",
|
267 |
+
shape=(self.num_classes, self.num_classes),
|
268 |
+
initializer=tf.keras.initializers.zeros,
|
269 |
+
dtype=tf.float32,
|
270 |
+
)
|
271 |
+
|
272 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
273 |
+
"""Accumulates the confusion matrix condition statistics.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
y_true: Labels assigned by the first annotator with shape
|
277 |
+
`[num_samples,]`.
|
278 |
+
y_pred: Labels assigned by the second annotator with shape
|
279 |
+
`[num_samples,]`. The kappa statistic is symmetric,
|
280 |
+
so swapping `y_true` and `y_pred` doesn't change the value.
|
281 |
+
sample_weight (optional): for weighting labels in confusion matrix
|
282 |
+
Defaults to `None`. The dtype for weights should be the same
|
283 |
+
as the dtype for confusion matrix. For more details,
|
284 |
+
please check `tf.math.confusion_matrix`.
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
Update op.
|
288 |
+
"""
|
289 |
+
return self._update(y_true, y_pred, sample_weight)
|
290 |
+
|
291 |
+
def _update_binary_class_model(self, y_true, y_pred, sample_weight=None):
|
292 |
+
y_true = tf.cast(y_true, dtype=tf.int64)
|
293 |
+
y_pred = tf.cast(y_pred, dtype=tf.float32)
|
294 |
+
y_pred = tf.cast(y_pred > 0.5, dtype=tf.int64)
|
295 |
+
return self._update_confusion_matrix(y_true, y_pred, sample_weight)
|
296 |
+
|
297 |
+
@tf.function
|
298 |
+
def _update_multi_class_model(self, y_true, y_pred, sample_weight=None):
|
299 |
+
v = tf.argmax(y_true, axis=1) if not self.sparse_labels else y_true
|
300 |
+
y_true = tf.cast(v, dtype=tf.int64)
|
301 |
+
|
302 |
+
y_pred = self._cast_ypred(y_pred)
|
303 |
+
|
304 |
+
return self._update_confusion_matrix(y_true, y_pred, sample_weight)
|
305 |
+
|
306 |
+
@tf.function
|
307 |
+
def _cast_ypred(self, y_pred):
|
308 |
+
if tf.rank(y_pred) > 1:
|
309 |
+
if not self.regression:
|
310 |
+
y_pred = tf.cast(tf.argmax(y_pred, axis=-1), dtype=tf.int64)
|
311 |
+
else:
|
312 |
+
y_pred = tf.math.round(tf.math.abs(y_pred))
|
313 |
+
y_pred = tf.cast(y_pred, dtype=tf.int64)
|
314 |
+
else:
|
315 |
+
y_pred = tf.cast(y_pred, dtype=tf.int64)
|
316 |
+
return y_pred
|
317 |
+
|
318 |
+
@tf.function
|
319 |
+
def _safe_squeeze(self, y):
|
320 |
+
y = tf.squeeze(y)
|
321 |
+
|
322 |
+
# Check for scalar result
|
323 |
+
if tf.rank(y) == 0:
|
324 |
+
y = tf.expand_dims(y, 0)
|
325 |
+
|
326 |
+
return y
|
327 |
+
|
328 |
+
def _update_confusion_matrix(self, y_true, y_pred, sample_weight):
|
329 |
+
y_true = self._safe_squeeze(y_true)
|
330 |
+
y_pred = self._safe_squeeze(y_pred)
|
331 |
+
|
332 |
+
new_conf_mtx = tf.math.confusion_matrix(
|
333 |
+
labels=y_true,
|
334 |
+
predictions=y_pred,
|
335 |
+
num_classes=self.num_classes,
|
336 |
+
weights=sample_weight,
|
337 |
+
dtype=tf.float32,
|
338 |
+
)
|
339 |
+
|
340 |
+
return self.conf_mtx.assign_add(new_conf_mtx)
|
341 |
+
|
342 |
+
def result(self):
|
343 |
+
nb_ratings = tf.shape(self.conf_mtx)[0]
|
344 |
+
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.float32)
|
345 |
+
|
346 |
+
# 2. Create a weight matrix
|
347 |
+
if self.weightage is None:
|
348 |
+
diagonal = tf.zeros([nb_ratings], dtype=tf.float32)
|
349 |
+
weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
|
350 |
+
else:
|
351 |
+
weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.float32)
|
352 |
+
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
|
353 |
+
|
354 |
+
if self.weightage == "linear":
|
355 |
+
weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx))
|
356 |
+
else:
|
357 |
+
weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2)
|
358 |
+
|
359 |
+
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
|
360 |
+
|
361 |
+
# 3. Get counts
|
362 |
+
actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1)
|
363 |
+
pred_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=0)
|
364 |
+
|
365 |
+
# 4. Get the outer product
|
366 |
+
out_prod = pred_ratings_hist[..., None] * actual_ratings_hist[None, ...]
|
367 |
+
|
368 |
+
# 5. Normalize the confusion matrix and outer product
|
369 |
+
conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx)
|
370 |
+
out_prod = out_prod / tf.reduce_sum(out_prod)
|
371 |
+
|
372 |
+
conf_mtx = tf.cast(conf_mtx, dtype=self.dtype)
|
373 |
+
out_prod = tf.cast(out_prod, dtype=self.dtype)
|
374 |
+
|
375 |
+
# 6. Calculate Kappa score
|
376 |
+
numerator = tf.reduce_sum(conf_mtx * weight_mtx)
|
377 |
+
denominator = tf.reduce_sum(out_prod * weight_mtx)
|
378 |
+
return tf.cond(
|
379 |
+
tf.math.is_nan(denominator),
|
380 |
+
true_fn=lambda: 0.0,
|
381 |
+
false_fn=lambda: 1 - (numerator / denominator),
|
382 |
+
)
|
383 |
+
|
384 |
+
def get_config(self):
|
385 |
+
"""Returns the serializable config of the metric."""
|
386 |
+
|
387 |
+
config = {
|
388 |
+
"num_classes": self.num_classes,
|
389 |
+
"weightage": self.weightage,
|
390 |
+
"sparse_labels": self.sparse_labels,
|
391 |
+
"regression": self.regression,
|
392 |
+
}
|
393 |
+
base_config = super().get_config()
|
394 |
+
return {**base_config, **config}
|
395 |
+
|
396 |
+
def reset_state(self):
|
397 |
+
"""Resets all of the metric state variables."""
|
398 |
+
|
399 |
+
for v in self.variables:
|
400 |
+
K.set_value(
|
401 |
+
v,
|
402 |
+
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
|
403 |
+
)
|
404 |
+
|
405 |
+
def reset_states(self):
|
406 |
+
# Backwards compatibility alias of `reset_state`. New classes should
|
407 |
+
# only implement `reset_state`.
|
408 |
+
# Required in Tensorflow < 2.5.0
|
409 |
+
return self.reset_state()
|
410 |
+
|
411 |
+
|
412 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
413 |
+
#
|
414 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
415 |
+
# you may not use this file except in compliance with the License.
|
416 |
+
# You may obtain a copy of the License at
|
417 |
+
#
|
418 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
419 |
+
#
|
420 |
+
# Unless required by applicable law or agreed to in writing, software
|
421 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
422 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
423 |
+
# See the License for the specific language governing permissions and
|
424 |
+
# limitations under the License.
|
425 |
+
# ==============================================================================
|
426 |
+
"""Implements Multi-label confusion matrix scores."""
|
427 |
+
|
428 |
+
import warnings
|
429 |
+
|
430 |
+
import tensorflow as tf
|
431 |
+
from tensorflow.keras import backend as K
|
432 |
+
from tensorflow.keras.metrics import Metric
|
433 |
+
import numpy as np
|
434 |
+
|
435 |
+
from typeguard import typechecked
|
436 |
+
|
437 |
+
|
438 |
+
class MultiLabelConfusionMatrix(Metric):
|
439 |
+
"""Computes Multi-label confusion matrix.
|
440 |
+
|
441 |
+
Class-wise confusion matrix is computed for the
|
442 |
+
evaluation of classification.
|
443 |
+
|
444 |
+
If multi-class input is provided, it will be treated
|
445 |
+
as multilabel data.
|
446 |
+
|
447 |
+
Consider classification problem with two classes
|
448 |
+
(i.e num_classes=2).
|
449 |
+
|
450 |
+
Resultant matrix `M` will be in the shape of `(num_classes, 2, 2)`.
|
451 |
+
|
452 |
+
Every class `i` has a dedicated matrix of shape `(2, 2)` that contains:
|
453 |
+
|
454 |
+
- true negatives for class `i` in `M(0,0)`
|
455 |
+
- false positives for class `i` in `M(0,1)`
|
456 |
+
- false negatives for class `i` in `M(1,0)`
|
457 |
+
- true positives for class `i` in `M(1,1)`
|
458 |
+
|
459 |
+
Args:
|
460 |
+
num_classes: `int`, the number of labels the prediction task can have.
|
461 |
+
name: (Optional) string name of the metric instance.
|
462 |
+
dtype: (Optional) data type of the metric result.
|
463 |
+
|
464 |
+
Usage:
|
465 |
+
|
466 |
+
>>> # multilabel confusion matrix
|
467 |
+
>>> y_true = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.int32)
|
468 |
+
>>> y_pred = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.int32)
|
469 |
+
>>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
|
470 |
+
>>> metric.update_state(y_true, y_pred)
|
471 |
+
>>> result = metric.result()
|
472 |
+
>>> result.numpy() #doctest: -DONT_ACCEPT_BLANKLINE
|
473 |
+
array([[[1., 0.],
|
474 |
+
[0., 1.]],
|
475 |
+
<BLANKLINE>
|
476 |
+
[[1., 0.],
|
477 |
+
[0., 1.]],
|
478 |
+
<BLANKLINE>
|
479 |
+
[[0., 1.],
|
480 |
+
[1., 0.]]], dtype=float32)
|
481 |
+
>>> # if multiclass input is provided
|
482 |
+
>>> y_true = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.int32)
|
483 |
+
>>> y_pred = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.int32)
|
484 |
+
>>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
|
485 |
+
>>> metric.update_state(y_true, y_pred)
|
486 |
+
>>> result = metric.result()
|
487 |
+
>>> result.numpy() #doctest: -DONT_ACCEPT_BLANKLINE
|
488 |
+
array([[[1., 0.],
|
489 |
+
[0., 1.]],
|
490 |
+
<BLANKLINE>
|
491 |
+
[[1., 0.],
|
492 |
+
[1., 0.]],
|
493 |
+
<BLANKLINE>
|
494 |
+
[[1., 1.],
|
495 |
+
[0., 0.]]], dtype=float32)
|
496 |
+
|
497 |
+
"""
|
498 |
+
|
499 |
+
@typechecked
|
500 |
+
def __init__(
|
501 |
+
self,
|
502 |
+
num_classes: FloatTensorLike,
|
503 |
+
name: str = "Multilabel_confusion_matrix",
|
504 |
+
dtype: AcceptableDTypes = None,
|
505 |
+
**kwargs,
|
506 |
+
):
|
507 |
+
super().__init__(name=name, dtype=dtype)
|
508 |
+
self.num_classes = num_classes
|
509 |
+
self.true_positives = self.add_weight(
|
510 |
+
"true_positives",
|
511 |
+
shape=[self.num_classes],
|
512 |
+
initializer="zeros",
|
513 |
+
dtype=self.dtype,
|
514 |
+
)
|
515 |
+
self.false_positives = self.add_weight(
|
516 |
+
"false_positives",
|
517 |
+
shape=[self.num_classes],
|
518 |
+
initializer="zeros",
|
519 |
+
dtype=self.dtype,
|
520 |
+
)
|
521 |
+
self.false_negatives = self.add_weight(
|
522 |
+
"false_negatives",
|
523 |
+
shape=[self.num_classes],
|
524 |
+
initializer="zeros",
|
525 |
+
dtype=self.dtype,
|
526 |
+
)
|
527 |
+
self.true_negatives = self.add_weight(
|
528 |
+
"true_negatives",
|
529 |
+
shape=[self.num_classes],
|
530 |
+
initializer="zeros",
|
531 |
+
dtype=self.dtype,
|
532 |
+
)
|
533 |
+
|
534 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
535 |
+
if sample_weight is not None:
|
536 |
+
warnings.warn(
|
537 |
+
"`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix "
|
538 |
+
"does not take `sample_weight` into account when computing the metric "
|
539 |
+
"value."
|
540 |
+
)
|
541 |
+
|
542 |
+
y_true = tf.cast(y_true, tf.int32)
|
543 |
+
y_pred = tf.cast(y_pred, tf.int32)
|
544 |
+
# true positive
|
545 |
+
true_positive = tf.math.count_nonzero(y_true * y_pred, 0)
|
546 |
+
# predictions sum
|
547 |
+
pred_sum = tf.math.count_nonzero(y_pred, 0)
|
548 |
+
# true labels sum
|
549 |
+
true_sum = tf.math.count_nonzero(y_true, 0)
|
550 |
+
false_positive = pred_sum - true_positive
|
551 |
+
false_negative = true_sum - true_positive
|
552 |
+
y_true_negative = tf.math.not_equal(y_true, 1)
|
553 |
+
y_pred_negative = tf.math.not_equal(y_pred, 1)
|
554 |
+
true_negative = tf.math.count_nonzero(
|
555 |
+
tf.math.logical_and(y_true_negative, y_pred_negative), axis=0
|
556 |
+
)
|
557 |
+
|
558 |
+
# true positive state update
|
559 |
+
self.true_positives.assign_add(tf.cast(true_positive, self.dtype))
|
560 |
+
# false positive state update
|
561 |
+
self.false_positives.assign_add(tf.cast(false_positive, self.dtype))
|
562 |
+
# false negative state update
|
563 |
+
self.false_negatives.assign_add(tf.cast(false_negative, self.dtype))
|
564 |
+
# true negative state update
|
565 |
+
self.true_negatives.assign_add(tf.cast(true_negative, self.dtype))
|
566 |
+
|
567 |
+
def result(self):
|
568 |
+
flat_confusion_matrix = tf.convert_to_tensor(
|
569 |
+
[
|
570 |
+
self.true_negatives,
|
571 |
+
self.false_positives,
|
572 |
+
self.false_negatives,
|
573 |
+
self.true_positives,
|
574 |
+
]
|
575 |
+
)
|
576 |
+
# reshape into 2*2 matrix
|
577 |
+
confusion_matrix = tf.reshape(tf.transpose(flat_confusion_matrix), [-1, 2, 2])
|
578 |
+
|
579 |
+
return confusion_matrix
|
580 |
+
|
581 |
+
def get_config(self):
|
582 |
+
"""Returns the serializable config of the metric."""
|
583 |
+
|
584 |
+
config = {
|
585 |
+
"num_classes": self.num_classes,
|
586 |
+
}
|
587 |
+
base_config = super().get_config()
|
588 |
+
return {**base_config, **config}
|
589 |
+
|
590 |
+
def reset_state(self):
|
591 |
+
reset_value = np.zeros(self.num_classes, dtype=np.int32)
|
592 |
+
K.batch_set_value([(v, reset_value) for v in self.variables])
|
593 |
+
|
594 |
+
def reset_states(self):
|
595 |
+
# Backwards compatibility alias of `reset_state`. New classes should
|
596 |
+
# only implement `reset_state`.
|
597 |
+
# Required in Tensorflow < 2.5.0
|
598 |
+
return self.reset_state()
|
599 |
+
#####
|
600 |
+
|
601 |
IMAGE_SIZE = 128
|
602 |
NUM_CLASSES = 3
|
603 |
|
|
|
700 |
vgg16_DR = tf.keras.models.load_model('./VGG16_32_128_CLAHE_4_40_trainableFalse.h5', custom_objects={
|
701 |
'WeightedKappaLoss': tfa.losses.WeightedKappaLoss,
|
702 |
'CohenKappa': tfa.metrics.CohenKappa,
|
703 |
+
'F1Score': tf.keras.metrics.F1Score,
|
704 |
'MultiLabelConfusionMatrix': tfa.metrics.MultiLabelConfusionMatrix
|
705 |
})
|
706 |
return vgg16_DR
|