Commit 2d26f82 (parent: 345d42d), committed by BucketHeadP65

ROC Curve

Files changed:
- README.md        +62 -29
- app.py            +1  -1
- requirements.txt  +2  -1
- roc_curve.py    +131 -71
- tests.py          +0 -17
README.md
CHANGED
@@ -1,50 +1,83 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-tags:
-- evaluate
-- metric
-description: "TODO: add a description here"
+title: ROC Curve
+emoji: 📉
+colorFrom: yellow
+colorTo: green
 sdk: gradio
-sdk_version: 3.0
+sdk_version: 3.17.0
 app_file: app.py
 pinned: false
+tags:
+- evaluate
+- metric
+description: >-
+  Compute Receiver operating characteristic (ROC).
+  Note: this implementation is restricted to the binary classification task.
 ---
 
-# Metric Card for
-
-***Module Card Instructions:*** *Fill out the following subsections. Feel free to take a look at existing metric cards if you'd like examples.*
+# Metric Card for ROC Curve
 
 ## Metric Description
-*Give a brief overview of this metric, including what task(s) it is usually used for, if any.*
+Compute Receiver operating characteristic (ROC).
+Note: this implementation is restricted to the binary classification task.
 
-*Provide simplest possible example for using the metric*
+## How to Use
+At minimum, this metric requires prediction scores and references as inputs.
 
+```python
+>>> roc_metric = evaluate.load("BucketHeadP65/roc_curve")
+>>> results = roc_metric.compute(references=[1, 0, 1, 1, 0], prediction_scores=[0.1, 0.4, 0.6, 0.7, 0.1])
+>>> print(results)
+{'roc_curve': (array([0. , 0. , 0. , 0.5, 1. ]), array([0.        , 0.33333333, 0.66666667, 0.66666667, 1.        ]), array([1.69999999, 0.69999999, 0.60000002, 0.40000001, 0.1       ]))}
+```
 
-*List all input arguments in the format below*
-- **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*
+### Inputs
+- **prediction_scores** (`list` of `float`): Target scores; these can be probability estimates of the positive class, confidence values, or a non-thresholded measure of decisions (as returned by `decision_function` on some classifiers).
+- **references** (`list` of `int`): Ground truth labels.
+- **pos_label** (`int` or `str`, defaults to `None`): The label of the positive class. If the labels are not either {-1, 1} or {0, 1}, then `pos_label` should be given explicitly.
+- **sample_weight** (`list` of `float`, defaults to `None`): Sample weights.
+- **drop_intermediate** (`bool`, defaults to `True`): Whether to drop some suboptimal thresholds which would not appear on a plotted ROC curve. This is useful in order to create lighter ROC curves.
 
-*State the range of possible values that the metric's output can take, as well as what in that range is considered good. For example: "This metric can take on any value between 0 and 100, inclusive. Higher scores are better."*
+### Output Values
+- **fpr** (`ndarray`): Increasing false positive rates such that element `i` is the false positive rate of predictions with score >= `thresholds[i]`.
+- **tpr** (`ndarray`): Increasing true positive rates such that element `i` is the true positive rate of predictions with score >= `thresholds[i]`.
+- **thresholds** (`ndarray`): Decreasing thresholds on the decision function used to compute `fpr` and `tpr`. `thresholds[0]` represents no instances being predicted and is arbitrarily set to `max(y_score) + 1`.
 
-###
+Output Example(s):
+```python
+{'roc_curve': (array([0. , 0. , 0. , 0.5, 1. ]), array([0.        , 0.33333333, 0.66666667, 0.66666667, 1.        ]), array([1.69999999, 0.69999999, 0.60000002, 0.40000001, 0.1       ]))}
+```
+This metric outputs a dictionary containing the ROC curve as a tuple of `fpr`, `tpr`, and `thresholds`.
 
+## Citation(s)
+```bibtex
+@article{scikit-learn,
+  title={Scikit-learn: Machine Learning in {P}ython},
+  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
+          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
+          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
+          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
+  journal={Journal of Machine Learning Research},
+  volume={12},
+  pages={2825--2830},
+  year={2011}
+}
+```
 ## Further References
+Wikipedia entry for the Receiver operating characteristic:
+<https://en.wikipedia.org/wiki/Receiver_operating_characteristic>
+(Wikipedia and other references may use a different convention for axes.)
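To make the card's output section concrete, here is a minimal sketch (not part of this commit) of consuming the `(fpr, tpr, thresholds)` tuple the metric returns, assuming `numpy` is installed; `roc_metric` is just a local variable name:

```python
import numpy as np
import evaluate

# Load the metric from the Hub and compute the curve on toy data.
roc_metric = evaluate.load("BucketHeadP65/roc_curve")
results = roc_metric.compute(
    references=[1, 0, 1, 1, 0],
    prediction_scores=[0.1, 0.4, 0.6, 0.7, 0.1],
)

# Unpack the tuple stored under the "roc_curve" key.
fpr, tpr, thresholds = results["roc_curve"]

# The area under the curve follows from trapezoidal integration of tpr over fpr.
print(f"AUC = {np.trapz(tpr, fpr):.3f}")
```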
app.py
CHANGED
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget
 
 
 module = evaluate.load("BucketHeadP65/roc_curve")
-launch_gradio_widget(module)
+launch_gradio_widget(module)
requirements.txt
CHANGED
@@ -1 +1,2 @@
-git+https://github.com/huggingface/evaluate@
+git+https://github.com/huggingface/evaluate@75c09ff6cd599744d0641e4ef1d4cfedbd35eec9
+scikit-learn
roc_curve.py
CHANGED
@@ -1,95 +1,155 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""TODO: Add a description here."""
+"""ROC Curve metric."""
 
-import evaluate
 import datasets
+import evaluate
+from sklearn.metrics import roc_curve
 
-_CITATION = """\
-@InProceedings{huggingface:module,
-title = {A great new module},
-authors={huggingface, Inc.},
-year={2020}
-}
-"""
-
-# TODO: Add description of the module here
-_DESCRIPTION = """\
-This new module is designed to solve this great ML task and is crafted with a lot of care.
+_DESCRIPTION = """
+Compute Receiver operating characteristic (ROC).
+Note: this implementation is restricted to the binary classification task.
 """
 
-# TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates how good are predictions given some references, using certain scores
 Args:
+    y_true : ndarray of shape (n_samples,)
+        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
+        pos_label should be explicitly given.
+
+    y_score : ndarray of shape (n_samples,)
+        Target scores, can either be probability estimates of the positive
+        class, confidence values, or non-thresholded measure of decisions
+        (as returned by "decision_function" on some classifiers).
+
+    pos_label : int or str, default=None
+        The label of the positive class.
+        When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
+        ``pos_label`` is set to 1, otherwise an error will be raised.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    drop_intermediate : bool, default=True
+        Whether to drop some suboptimal thresholds which would not appear
+        on a plotted ROC curve. This is useful in order to create lighter
+        ROC curves.
+
+        .. versionadded:: 0.17
+           parameter *drop_intermediate*.
+
 Returns:
+    fpr : ndarray of shape (>2,)
+        Increasing false positive rates such that element i is the false
+        positive rate of predictions with score >= `thresholds[i]`.
+
+    tpr : ndarray of shape (>2,)
+        Increasing true positive rates such that element `i` is the true
+        positive rate of predictions with score >= `thresholds[i]`.
+
+    thresholds : ndarray of shape (n_thresholds,)
+        Decreasing thresholds on the decision function used to compute
+        fpr and tpr. `thresholds[0]` represents no instances being predicted
+        and is arbitrarily set to `max(y_score) + 1`.
+
+See Also:
+    RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
+        (ROC) curve given an estimator and some data.
+    RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
+        (ROC) curve given the true and predicted values.
+    det_curve : Compute error rates for different probability thresholds.
+    roc_auc_score : Compute the area under the ROC curve.
+
+Notes:
+    Since the thresholds are sorted from low to high values, they
+    are reversed upon returning them to ensure they correspond to both ``fpr``
+    and ``tpr``, which are sorted in reversed order during their calculation.
+
+References:
+    .. [1] `Wikipedia entry for the Receiver operating characteristic
+           <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
+
+    .. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
+           Letters, 2006, 27(8):861-874.
+
 Examples:
-    Examples should be written in doctest format, and should illustrate how
-    to use the function.
-
-    >>>
-    >>>
+    >>> import numpy as np
+    >>> from sklearn import metrics
+    >>> y = np.array([1, 1, 2, 2])
+    >>> scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
+    >>> fpr
+    array([0. , 0. , 0.5, 0.5, 1. ])
+    >>> tpr
+    array([0. , 0.5, 0.5, 1. , 1. ])
+    >>> thresholds
+    array([1.8 , 0.8 , 0.4 , 0.35, 0.1 ])
 """
 
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
+_CITATION = """
+@article{scikit-learn,
+  title={Scikit-learn: Machine Learning in {P}ython},
+  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
+          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
+          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
+          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
+  journal={Journal of Machine Learning Research},
+  volume={12},
+  pages={2825--2830},
+  year={2011}
+}
+"""
 
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class RocCurve(evaluate.Metric):
-    """TODO: Short description of my evaluation module."""
-
     def _info(self):
-        # TODO: Specifies the evaluate.EvaluationModuleInfo object
         return evaluate.MetricInfo(
-            # This is the description that will appear on the modules page.
-            module_type="metric",
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "prediction_scores": datasets.Sequence(datasets.Value("float")),
+                    "references": datasets.Value("int32"),
+                }
+                if self.config_name == "multiclass"
+                else {
+                    "references": datasets.Sequence(datasets.Value("int32")),
+                    "prediction_scores": datasets.Sequence(datasets.Value("float")),
+                }
+                if self.config_name == "multilabel"
+                else {
+                    "references": datasets.Value("int32"),
+                    "prediction_scores": datasets.Value("float"),
+                }
+            ),
+            reference_urls=[
+                "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html"
+            ],
         )
 
-    def
+    def _compute(
+        self,
+        references,
+        prediction_scores,
+        *,
+        pos_label=None,
+        sample_weight=None,
+        drop_intermediate=True,
+    ):
         return {
-            "
+            "roc_curve": roc_curve(
+                y_true=references,
+                y_score=prediction_scores,
+                pos_label=pos_label,
+                sample_weight=sample_weight,
+                drop_intermediate=drop_intermediate,
+            )
+        }
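Since `_compute` above is a thin wrapper that only renames arguments, a quick sanity check (a sketch, not part of this commit) is that the module's output matches `sklearn.metrics.roc_curve` called directly:

```python
import numpy as np
import evaluate
from sklearn.metrics import roc_curve

module = evaluate.load("BucketHeadP65/roc_curve")
refs = [1, 0, 1, 1, 0]
scores = [0.1, 0.4, 0.6, 0.7, 0.1]

# Both calls should agree exactly, element-wise over (fpr, tpr, thresholds).
ours = module.compute(references=refs, prediction_scores=scores)["roc_curve"]
theirs = roc_curve(y_true=refs, y_score=scores)
assert all(np.allclose(a, b) for a, b in zip(ours, theirs))
```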
tests.py
DELETED
@@ -1,17 +0,0 @@
-test_cases = [
-    {
-        "predictions": [0, 0],
-        "references": [1, 1],
-        "result": {"metric_score": 0}
-    },
-    {
-        "predictions": [1, 1],
-        "references": [1, 1],
-        "result": {"metric_score": 1}
-    },
-    {
-        "predictions": [1, 0],
-        "references": [1, 1],
-        "result": {"metric_score": 0.5}
-    }
-]
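The deleted template cases assumed a scalar `metric_score`, which a curve-valued metric cannot produce, so their removal makes sense. A replacement test might instead check structural properties of the output; the following is a hypothetical sketch, not part of this commit:

```python
import numpy as np
import evaluate

def test_roc_curve_output_structure():
    # Hypothetical replacement test: check structure rather than a scalar score.
    module = evaluate.load("BucketHeadP65/roc_curve")
    out = module.compute(
        references=[1, 0, 1, 1, 0],
        prediction_scores=[0.1, 0.4, 0.6, 0.7, 0.1],
    )
    fpr, tpr, thresholds = out["roc_curve"]
    assert len(fpr) == len(tpr) == len(thresholds)
    assert np.all(np.diff(fpr) >= 0)        # fpr is non-decreasing
    assert np.all(np.diff(thresholds) < 0)  # thresholds strictly decrease
```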