Spaces:
Sleeping
Sleeping
Duplicate from aseifert/ExplaiNER
Browse filesCo-authored-by: Alexander Seifert <[email protected]>
- .gitattributes +27 -0
- .gitignore +169 -0
- Makefile +10 -0
- README.md +78 -0
- html/index.md +116 -0
- html/screenshot.jpg +0 -0
- requirements.txt +16 -0
- src/__init__.py +0 -0
- src/app.py +114 -0
- src/data.py +228 -0
- src/load.py +101 -0
- src/model.py +33 -0
- src/subpages/__init__.py +14 -0
- src/subpages/attention.py +160 -0
- src/subpages/debug.py +27 -0
- src/subpages/emoji-en-US.json +0 -0
- src/subpages/faiss.py +58 -0
- src/subpages/find_duplicates.py +52 -0
- src/subpages/hidden_states.py +194 -0
- src/subpages/home.py +163 -0
- src/subpages/inspect.py +41 -0
- src/subpages/losses.py +67 -0
- src/subpages/lossy_samples.py +115 -0
- src/subpages/metrics.py +104 -0
- src/subpages/misclassified.py +82 -0
- src/subpages/page.py +50 -0
- src/subpages/probing.py +57 -0
- src/subpages/random_samples.py +50 -0
- src/subpages/raw_data.py +57 -0
- src/utils.py +255 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Created by https://www.gitignore.io/api/python,osx,linux
|
3 |
+
# Edit at https://www.gitignore.io/?templates=python,osx,linux
|
4 |
+
|
5 |
+
### Linux ###
|
6 |
+
*~
|
7 |
+
|
8 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
9 |
+
.fuse_hidden*
|
10 |
+
|
11 |
+
# KDE directory preferences
|
12 |
+
.directory
|
13 |
+
|
14 |
+
# Linux trash folder which might appear on any partition or disk
|
15 |
+
.Trash-*
|
16 |
+
|
17 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
18 |
+
.nfs*
|
19 |
+
|
20 |
+
### OSX ###
|
21 |
+
# General
|
22 |
+
.DS_Store
|
23 |
+
.AppleDouble
|
24 |
+
.LSOverride
|
25 |
+
|
26 |
+
# Icon must end with two \r
|
27 |
+
Icon
|
28 |
+
|
29 |
+
# Thumbnails
|
30 |
+
._*
|
31 |
+
|
32 |
+
# Files that might appear in the root of a volume
|
33 |
+
.DocumentRevisions-V100
|
34 |
+
.fseventsd
|
35 |
+
.Spotlight-V100
|
36 |
+
.TemporaryItems
|
37 |
+
.Trashes
|
38 |
+
.VolumeIcon.icns
|
39 |
+
.com.apple.timemachine.donotpresent
|
40 |
+
|
41 |
+
# Directories potentially created on remote AFP share
|
42 |
+
.AppleDB
|
43 |
+
.AppleDesktop
|
44 |
+
Network Trash Folder
|
45 |
+
Temporary Items
|
46 |
+
.apdisk
|
47 |
+
|
48 |
+
### Python ###
|
49 |
+
# Byte-compiled / optimized / DLL files
|
50 |
+
__pycache__/
|
51 |
+
*.py[cod]
|
52 |
+
*$py.class
|
53 |
+
|
54 |
+
# C extensions
|
55 |
+
*.so
|
56 |
+
|
57 |
+
# Distribution / packaging
|
58 |
+
.Python
|
59 |
+
build/
|
60 |
+
develop-eggs/
|
61 |
+
dist/
|
62 |
+
downloads/
|
63 |
+
eggs/
|
64 |
+
.eggs/
|
65 |
+
lib/
|
66 |
+
lib64/
|
67 |
+
parts/
|
68 |
+
sdist/
|
69 |
+
var/
|
70 |
+
wheels/
|
71 |
+
pip-wheel-metadata/
|
72 |
+
share/python-wheels/
|
73 |
+
*.egg-info/
|
74 |
+
.installed.cfg
|
75 |
+
*.egg
|
76 |
+
MANIFEST
|
77 |
+
|
78 |
+
# PyInstaller
|
79 |
+
# Usually these files are written by a python script from a template
|
80 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
81 |
+
*.manifest
|
82 |
+
*.spec
|
83 |
+
|
84 |
+
# Installer logs
|
85 |
+
pip-log.txt
|
86 |
+
pip-delete-this-directory.txt
|
87 |
+
|
88 |
+
# Unit test / coverage reports
|
89 |
+
htmlcov/
|
90 |
+
.tox/
|
91 |
+
.nox/
|
92 |
+
.coverage
|
93 |
+
.coverage.*
|
94 |
+
.cache
|
95 |
+
nosetests.xml
|
96 |
+
coverage.xml
|
97 |
+
*.cover
|
98 |
+
.hypothesis/
|
99 |
+
.pytest_cache/
|
100 |
+
|
101 |
+
# Translations
|
102 |
+
*.mo
|
103 |
+
*.pot
|
104 |
+
|
105 |
+
# Scrapy stuff:
|
106 |
+
.scrapy
|
107 |
+
|
108 |
+
# Sphinx documentation
|
109 |
+
docs/_build/
|
110 |
+
|
111 |
+
# PyBuilder
|
112 |
+
target/
|
113 |
+
|
114 |
+
# pyenv
|
115 |
+
.python-version
|
116 |
+
|
117 |
+
# pipenv
|
118 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
119 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
120 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
121 |
+
# install all needed dependencies.
|
122 |
+
#Pipfile.lock
|
123 |
+
|
124 |
+
# celery beat schedule file
|
125 |
+
celerybeat-schedule
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Spyder project settings
|
131 |
+
.spyderproject
|
132 |
+
.spyproject
|
133 |
+
|
134 |
+
# Rope project settings
|
135 |
+
.ropeproject
|
136 |
+
|
137 |
+
# Mr Developer
|
138 |
+
.mr.developer.cfg
|
139 |
+
.project
|
140 |
+
.pydevproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# End of https://www.gitignore.io/api/python,osx,linux
|
154 |
+
|
155 |
+
.idea/
|
156 |
+
.ipynb_checkpoints/
|
157 |
+
node_modules/
|
158 |
+
data/books/
|
159 |
+
docx/cache/
|
160 |
+
docx/data/
|
161 |
+
save_dir/
|
162 |
+
cache_dir/
|
163 |
+
outputs/
|
164 |
+
models/
|
165 |
+
runs/
|
166 |
+
.vscode/
|
167 |
+
doc/
|
168 |
+
html/*.html
|
169 |
+
vis2.zip
|
Makefile
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
doc:
|
2 |
+
pdoc --docformat google src -o doc
|
3 |
+
|
4 |
+
vis2: doc
|
5 |
+
pandoc html/index.md -s -o html/index.html
|
6 |
+
rm -rf src/__pycache__ && rm -rf src/subpages/__pycache__
|
7 |
+
zip -r vis2.zip doc html src Makefile presentation.pdf requirements.txt
|
8 |
+
|
9 |
+
run:
|
10 |
+
python -m streamlit run src/app.py
|
README.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: ExplaiNER
|
3 |
+
emoji: 🏷️
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: indigo
|
6 |
+
python_version: 3.9
|
7 |
+
sdk: streamlit
|
8 |
+
sdk_version: 1.10.0
|
9 |
+
app_file: src/app.py
|
10 |
+
pinned: true
|
11 |
+
duplicated_from: aseifert/ExplaiNER
|
12 |
+
---
|
13 |
+
|
14 |
+
# 🏷️ ExplaiNER: Error Analysis for NER models & datasets
|
15 |
+
|
16 |
+
Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements.
|
17 |
+
|
18 |
+
## Sections
|
19 |
+
|
20 |
+
|
21 |
+
### Activations
|
22 |
+
|
23 |
+
A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
|
24 |
+
|
25 |
+
|
26 |
+
### Embeddings
|
27 |
+
|
28 |
+
For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
|
29 |
+
|
30 |
+
|
31 |
+
### Probing
|
32 |
+
|
33 |
+
A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
|
34 |
+
|
35 |
+
|
36 |
+
### Metrics
|
37 |
+
|
38 |
+
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
|
39 |
+
|
40 |
+
|
41 |
+
### Misclassified
|
42 |
+
|
43 |
+
This page contains all misclassified examples and allows filtering by specific error types.
|
44 |
+
|
45 |
+
|
46 |
+
### Loss by Token/Label
|
47 |
+
|
48 |
+
Show count, mean and median loss per token and label.
|
49 |
+
|
50 |
+
|
51 |
+
### Samples by Loss
|
52 |
+
|
53 |
+
Show every example sorted by loss (descending) for close inspection.
|
54 |
+
|
55 |
+
|
56 |
+
### Random Samples
|
57 |
+
|
58 |
+
Show random samples. Simple method, but it often turns up interesting things.
|
59 |
+
|
60 |
+
|
61 |
+
### Find Duplicates
|
62 |
+
|
63 |
+
Find potential duplicates in the data using cosine similarity.
|
64 |
+
|
65 |
+
|
66 |
+
### Inspect
|
67 |
+
|
68 |
+
Inspect your whole dataset, either unfiltered or by id.
|
69 |
+
|
70 |
+
|
71 |
+
### Raw data
|
72 |
+
|
73 |
+
See the data as seen by your model.
|
74 |
+
|
75 |
+
|
76 |
+
### Debug
|
77 |
+
|
78 |
+
Debug info.
|
html/index.md
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: "🏷️ ExplaiNER"
|
3 |
+
subtitle: "Error Analysis for NER models & datasets"
|
4 |
+
---
|
5 |
+
|
6 |
+
<div style="text-align: center">
|
7 |
+
<img src="screenshot.jpg" alt="drawing" width="480px"/>
|
8 |
+
</div>
|
9 |
+
|
10 |
+
_Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements._
|
11 |
+
|
12 |
+
[Documentation](../doc/index.html) | [Slides](../presentation.pdf) | [Github](https://github.com/aseifert/ExplaiNER)
|
13 |
+
|
14 |
+
|
15 |
+
## Getting started
|
16 |
+
|
17 |
+
```bash
|
18 |
+
# Install requirements
|
19 |
+
pip install -r requirements.txt # you'll need Python 3.9+
|
20 |
+
|
21 |
+
# Run
|
22 |
+
make run
|
23 |
+
```
|
24 |
+
|
25 |
+
## Description
|
26 |
+
|
27 |
+
Some interesting **visualization techniques** contained in this project:
|
28 |
+
|
29 |
+
* customizable visualization of neural network activation, based on the embedding layer and the feed-forward layers of the selected transformer model. ([Alammar 2021](https://aclanthology.org/2021.acl-demo.30/))
|
30 |
+
* customizable similarity map of a 2d projection of the model's final layer's hidden states, using various algorithms (a bit like the [Tensorflow Embedding Projector](https://projector.tensorflow.org/))
|
31 |
+
* inline HTML representation of samples with token-level prediction + labels (my own; see below under 'Samples by loss' for more info)
|
32 |
+
|
33 |
+
|
34 |
+
**Libraries** important to this project:
|
35 |
+
|
36 |
+
* `streamlit` for demoing (custom multi-page feature hacked in, also using session state)
|
37 |
+
* `plotly` and `matplotlib` for charting
|
38 |
+
* `transformers` for providing the models, and `datasets` for, well, the datasets
|
39 |
+
* a forked, slightly modified version of [`ecco`](https://github.com/jalammar/ecco) for visualizing the neural net activations
|
40 |
+
* `sentence_transformers` for finding potential duplicates
|
41 |
+
* `scikit-learn` for TruncatedSVD & PCA, `umap-learn` for UMAP
|
42 |
+
|
43 |
+
|
44 |
+
## Application Sections
|
45 |
+
|
46 |
+
|
47 |
+
Activations
|
48 |
+
|
49 |
+
> A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
|
50 |
+
|
51 |
+
|
52 |
+
Hidden States
|
53 |
+
|
54 |
+
> For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
|
55 |
+
>
|
56 |
+
> Using these projections you can visually identify data points that end up in the wrong neighborhood, indicating prediction/labeling errors.
|
57 |
+
|
58 |
+
|
59 |
+
Probing
|
60 |
+
|
61 |
+
> A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
|
62 |
+
|
63 |
+
|
64 |
+
Metrics
|
65 |
+
|
66 |
+
> The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
|
67 |
+
>
|
68 |
+
> With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone.
|
69 |
+
|
70 |
+
|
71 |
+
Misclassified
|
72 |
+
|
73 |
+
> This page contains all misclassified examples and allows filtering by specific error types. Helps you get an understanding of the types of errors your model makes.
|
74 |
+
|
75 |
+
|
76 |
+
Loss by Token/Label
|
77 |
+
|
78 |
+
> Show count, mean and median loss per token and label.
|
79 |
+
>
|
80 |
+
> Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues.
|
81 |
+
|
82 |
+
|
83 |
+
Samples by Loss
|
84 |
+
|
85 |
+
> Show every example sorted by loss (descending) for close inspection.
|
86 |
+
>
|
87 |
+
> Apart from a (token-based) dataframe view, there's also an HTML representation of the samples, which is very information-dense but really helpful, once you got used to reading it:
|
88 |
+
>
|
89 |
+
> Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class.
|
90 |
+
>
|
91 |
+
> For short texts, the dataframe view can be sufficient, but for longer texts the HTML view tends to be more useful.
|
92 |
+
|
93 |
+
|
94 |
+
Random Samples
|
95 |
+
|
96 |
+
> Show random samples. Simple method, but it often turns up interesting things.
|
97 |
+
|
98 |
+
|
99 |
+
Find Duplicates
|
100 |
+
|
101 |
+
> Find potential duplicates in the data using cosine similarity.
|
102 |
+
|
103 |
+
|
104 |
+
Inspect
|
105 |
+
|
106 |
+
> Inspect your whole dataset, either unfiltered or by id.
|
107 |
+
|
108 |
+
|
109 |
+
Raw data
|
110 |
+
|
111 |
+
> See the data as seen by your model.
|
112 |
+
|
113 |
+
|
114 |
+
Debug
|
115 |
+
|
116 |
+
> Debug info.
|
html/screenshot.jpg
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://download.pytorch.org/whl/cpu/torch-1.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl
|
2 |
+
streamlit
|
3 |
+
pandas
|
4 |
+
scikit-learn
|
5 |
+
plotly
|
6 |
+
sentence-transformers
|
7 |
+
transformers
|
8 |
+
tokenizers
|
9 |
+
datasets
|
10 |
+
numpy
|
11 |
+
matplotlib
|
12 |
+
seqeval
|
13 |
+
streamlit-aggrid
|
14 |
+
streamlit_option_menu
|
15 |
+
pdoc
|
16 |
+
git+https://github.com/aseifert/ecco.git@streamlit
|
src/__init__.py
ADDED
File without changes
|
src/app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""The App module is the main entry point for the application.
|
2 |
+
|
3 |
+
Run `streamlit run app.py` to start the app.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import streamlit as st
|
8 |
+
from streamlit_option_menu import option_menu
|
9 |
+
|
10 |
+
from src.load import load_context
|
11 |
+
from src.subpages import (
|
12 |
+
DebugPage,
|
13 |
+
FindDuplicatesPage,
|
14 |
+
HomePage,
|
15 |
+
LossesPage,
|
16 |
+
LossySamplesPage,
|
17 |
+
MetricsPage,
|
18 |
+
MisclassifiedPage,
|
19 |
+
Page,
|
20 |
+
ProbingPage,
|
21 |
+
RandomSamplesPage,
|
22 |
+
RawDataPage,
|
23 |
+
)
|
24 |
+
from src.subpages.attention import AttentionPage
|
25 |
+
from src.subpages.hidden_states import HiddenStatesPage
|
26 |
+
from src.subpages.inspect import InspectPage
|
27 |
+
from src.utils import classmap
|
28 |
+
|
29 |
+
sts = st.sidebar
|
30 |
+
st.set_page_config(
|
31 |
+
layout="wide",
|
32 |
+
page_title="Error Analysis",
|
33 |
+
page_icon="🏷️",
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def _show_menu(pages: list[Page]) -> int:
|
38 |
+
with st.sidebar:
|
39 |
+
page_names = [p.name for p in pages]
|
40 |
+
page_icons = [p.icon for p in pages]
|
41 |
+
selected_menu_item = st.session_state.active_page = option_menu(
|
42 |
+
menu_title="ExplaiNER",
|
43 |
+
options=page_names,
|
44 |
+
icons=page_icons,
|
45 |
+
menu_icon="layout-wtf",
|
46 |
+
default_index=0,
|
47 |
+
)
|
48 |
+
return page_names.index(selected_menu_item)
|
49 |
+
assert False
|
50 |
+
|
51 |
+
|
52 |
+
def _initialize_session_state(pages: list[Page]):
|
53 |
+
if "active_page" not in st.session_state:
|
54 |
+
for page in pages:
|
55 |
+
st.session_state.update(**page._get_widget_defaults())
|
56 |
+
st.session_state.update(st.session_state)
|
57 |
+
|
58 |
+
|
59 |
+
def _write_color_legend(context):
|
60 |
+
def style(x):
|
61 |
+
return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
|
62 |
+
|
63 |
+
labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
|
64 |
+
colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
|
65 |
+
|
66 |
+
color_legend_df = pd.DataFrame(
|
67 |
+
[classmap[l] for l in labels], columns=["label"], index=labels
|
68 |
+
).T
|
69 |
+
st.sidebar.write(
|
70 |
+
color_legend_df.T.style.apply(style, axis=0).set_properties(
|
71 |
+
**{"color": "white", "text-align": "center"}
|
72 |
+
)
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def main():
|
77 |
+
"""The main entry point for the application."""
|
78 |
+
pages: list[Page] = [
|
79 |
+
HomePage(),
|
80 |
+
AttentionPage(),
|
81 |
+
HiddenStatesPage(),
|
82 |
+
ProbingPage(),
|
83 |
+
MetricsPage(),
|
84 |
+
LossySamplesPage(),
|
85 |
+
LossesPage(),
|
86 |
+
MisclassifiedPage(),
|
87 |
+
RandomSamplesPage(),
|
88 |
+
FindDuplicatesPage(),
|
89 |
+
InspectPage(),
|
90 |
+
RawDataPage(),
|
91 |
+
DebugPage(),
|
92 |
+
]
|
93 |
+
|
94 |
+
_initialize_session_state(pages)
|
95 |
+
|
96 |
+
selected_page_idx = _show_menu(pages)
|
97 |
+
selected_page = pages[selected_page_idx]
|
98 |
+
|
99 |
+
if isinstance(selected_page, HomePage):
|
100 |
+
selected_page.render()
|
101 |
+
return
|
102 |
+
|
103 |
+
if "model_name" not in st.session_state:
|
104 |
+
# this can happen if someone loads another page directly (without going through home)
|
105 |
+
st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
|
106 |
+
return
|
107 |
+
|
108 |
+
context = load_context(**st.session_state)
|
109 |
+
_write_color_legend(context)
|
110 |
+
selected_page.render(context)
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
main()
|
src/data.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset, DatasetDict, load_dataset # type: ignore
|
7 |
+
from torch.nn.functional import cross_entropy
|
8 |
+
from transformers import DataCollatorForTokenClassification # type: ignore
|
9 |
+
|
10 |
+
from src.utils import device, tokenizer_hash_funcs
|
11 |
+
|
12 |
+
|
13 |
+
@st.cache(allow_output_mutation=True)
|
14 |
+
def get_data(
|
15 |
+
ds_name: str, config_name: str, split_name: str, split_sample_size: int, randomize_sample: bool
|
16 |
+
) -> Dataset:
|
17 |
+
"""Loads a Dataset from the HuggingFace hub (if not already loaded).
|
18 |
+
|
19 |
+
Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
|
20 |
+
|
21 |
+
Args:
|
22 |
+
ds_name (str): Path or name of the dataset.
|
23 |
+
config_name (str): Name of the dataset configuration.
|
24 |
+
split_name (str): Which split of the data to load.
|
25 |
+
split_sample_size (int): The number of examples to load from the split.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Dataset: A Dataset object.
|
29 |
+
"""
|
30 |
+
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(
|
31 |
+
seed=0 if randomize_sample else None
|
32 |
+
) # type: ignore
|
33 |
+
split = ds[split_name].select(range(split_sample_size))
|
34 |
+
return split
|
35 |
+
|
36 |
+
|
37 |
+
@st.cache(
|
38 |
+
allow_output_mutation=True,
|
39 |
+
hash_funcs=tokenizer_hash_funcs,
|
40 |
+
)
|
41 |
+
def get_collator(tokenizer) -> DataCollatorForTokenClassification:
|
42 |
+
"""Returns a DataCollator that will dynamically pad the inputs received, as well as the labels.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
DataCollatorForTokenClassification: The DataCollatorForTokenClassification object.
|
49 |
+
"""
|
50 |
+
return DataCollatorForTokenClassification(tokenizer)
|
51 |
+
|
52 |
+
|
53 |
+
def create_word_ids_from_input_ids(tokenizer, input_ids: list[int]) -> list[int]:
|
54 |
+
"""Takes a list of input_ids and return corresponding word_ids
|
55 |
+
|
56 |
+
Args:
|
57 |
+
tokenizer: The tokenizer that was used to obtain the input ids.
|
58 |
+
input_ids (list[int]): List of token ids.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
list[int]: Word ids corresponding to the input ids.
|
62 |
+
"""
|
63 |
+
word_ids = []
|
64 |
+
wid = -1
|
65 |
+
tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
|
66 |
+
|
67 |
+
for i, tok in enumerate(tokens):
|
68 |
+
if tok in tokenizer.all_special_tokens:
|
69 |
+
word_ids.append(-1)
|
70 |
+
continue
|
71 |
+
|
72 |
+
if not tokens[i - 1].endswith("@@") and tokens[i - 1] != "<unk>":
|
73 |
+
wid += 1
|
74 |
+
|
75 |
+
word_ids.append(wid)
|
76 |
+
|
77 |
+
assert len(word_ids) == len(input_ids)
|
78 |
+
return word_ids
|
79 |
+
|
80 |
+
|
81 |
+
def tokenize(batch, tokenizer) -> dict:
|
82 |
+
"""Tokenizes a batch of examples.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
batch: The examples to tokenize
|
86 |
+
tokenizer: The tokenizer to use
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
dict: The tokenized batch
|
90 |
+
"""
|
91 |
+
tokenized_inputs = tokenizer(batch["tokens"], truncation=True, is_split_into_words=True)
|
92 |
+
labels = []
|
93 |
+
wids = []
|
94 |
+
|
95 |
+
for idx, label in enumerate(batch["ner_tags"]):
|
96 |
+
try:
|
97 |
+
word_ids = tokenized_inputs.word_ids(batch_index=idx)
|
98 |
+
except ValueError:
|
99 |
+
word_ids = create_word_ids_from_input_ids(
|
100 |
+
tokenizer, tokenized_inputs["input_ids"][idx]
|
101 |
+
)
|
102 |
+
previous_word_idx = None
|
103 |
+
label_ids = []
|
104 |
+
for word_idx in word_ids:
|
105 |
+
if word_idx == -1 or word_idx is None or word_idx == previous_word_idx:
|
106 |
+
label_ids.append(-100)
|
107 |
+
else:
|
108 |
+
label_ids.append(label[word_idx])
|
109 |
+
previous_word_idx = word_idx
|
110 |
+
wids.append(word_ids)
|
111 |
+
labels.append(label_ids)
|
112 |
+
tokenized_inputs["word_ids"] = wids
|
113 |
+
tokenized_inputs["labels"] = labels
|
114 |
+
return tokenized_inputs
|
115 |
+
|
116 |
+
|
117 |
+
def stringify_ner_tags(batch: dict, tags) -> dict:
|
118 |
+
"""Stringifies a dataset batch's NER tags."""
|
119 |
+
return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
|
120 |
+
|
121 |
+
|
122 |
+
def encode_dataset(split: Dataset, tokenizer):
|
123 |
+
"""Encodes a dataset split.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
split (Dataset): A Dataset object.
|
127 |
+
tokenizer: A PreTrainedTokenizer object.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Dataset: A Dataset object with the encoded inputs.
|
131 |
+
"""
|
132 |
+
|
133 |
+
tags = split.features["ner_tags"].feature
|
134 |
+
split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
|
135 |
+
remove_columns = split.column_names
|
136 |
+
ids = split["id"]
|
137 |
+
split = split.map(
|
138 |
+
partial(tokenize, tokenizer=tokenizer),
|
139 |
+
batched=True,
|
140 |
+
remove_columns=remove_columns,
|
141 |
+
)
|
142 |
+
word_ids = [[id if id is not None else -1 for id in wids] for wids in split["word_ids"]]
|
143 |
+
return split.remove_columns(["word_ids"]), word_ids, ids
|
144 |
+
|
145 |
+
|
146 |
+
def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
147 |
+
"""Runs the forward pass for a batch of examples.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
batch: The batch to process
|
151 |
+
model: The model to process the batch with
|
152 |
+
collator: A data collator
|
153 |
+
num_classes (int): Number of classes
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
dict: a dictionary containing `losses`, `preds` and `hidden_states`
|
157 |
+
"""
|
158 |
+
|
159 |
+
# Convert dict of lists to list of dicts suitable for data collator
|
160 |
+
features = [dict(zip(batch, t)) for t in zip(*batch.values())]
|
161 |
+
|
162 |
+
# Pad inputs and labels and put all tensors on device
|
163 |
+
batch = collator(features)
|
164 |
+
input_ids = batch["input_ids"].to(device)
|
165 |
+
attention_mask = batch["attention_mask"].to(device)
|
166 |
+
labels = batch["labels"].to(device)
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
# Pass data through model
|
170 |
+
output = model(input_ids, attention_mask, output_hidden_states=True)
|
171 |
+
# logit.size: [batch_size, sequence_length, classes]
|
172 |
+
|
173 |
+
# Predict class with largest logit value on classes axis
|
174 |
+
preds = torch.argmax(output.logits, axis=-1).cpu().numpy() # type: ignore
|
175 |
+
|
176 |
+
# Calculate loss per token after flattening batch dimension with view
|
177 |
+
loss = cross_entropy(
|
178 |
+
output.logits.view(-1, num_classes), labels.view(-1), reduction="none"
|
179 |
+
)
|
180 |
+
|
181 |
+
# Unflatten batch dimension and convert to numpy array
|
182 |
+
loss = loss.view(len(input_ids), -1).cpu().numpy()
|
183 |
+
hidden_states = output.hidden_states[-1].cpu().numpy()
|
184 |
+
|
185 |
+
# logits = output.logits.view(len(input_ids), -1).cpu().numpy()
|
186 |
+
|
187 |
+
return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
|
188 |
+
|
189 |
+
|
190 |
+
def predict(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
|
191 |
+
"""Generates predictions for a given dataset split and returns the results as a dataframe.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
split_encoded (Dataset): The dataset to process
|
195 |
+
model: The model to process the dataset with
|
196 |
+
tokenizer: The tokenizer to process the dataset with
|
197 |
+
collator: The data collator to use
|
198 |
+
tags: The tags used in the dataset
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
pd.DataFrame: A dataframe containing token-level predictions.
|
202 |
+
"""
|
203 |
+
|
204 |
+
split_encoded = split_encoded.map(
|
205 |
+
partial(
|
206 |
+
forward_pass_with_label,
|
207 |
+
model=model,
|
208 |
+
collator=collator,
|
209 |
+
num_classes=tags.num_classes,
|
210 |
+
),
|
211 |
+
batched=True,
|
212 |
+
batch_size=8,
|
213 |
+
)
|
214 |
+
df: pd.DataFrame = split_encoded.to_pandas() # type: ignore
|
215 |
+
|
216 |
+
df["tokens"] = df["input_ids"].apply(
|
217 |
+
lambda x: tokenizer.convert_ids_to_tokens(x) # type: ignore
|
218 |
+
)
|
219 |
+
df["labels"] = df["labels"].apply(
|
220 |
+
lambda x: ["IGN" if i == -100 else tags.int2str(int(i)) for i in x]
|
221 |
+
)
|
222 |
+
df["preds"] = df["preds"].apply(lambda x: [model.config.id2label[i] for i in x])
|
223 |
+
df["preds"] = df.apply(lambda x: x["preds"][: len(x["input_ids"])], axis=1)
|
224 |
+
df["losses"] = df.apply(lambda x: x["losses"][: len(x["input_ids"])], axis=1)
|
225 |
+
df["hidden_states"] = df.apply(lambda x: x["hidden_states"][: len(x["input_ids"])], axis=1)
|
226 |
+
df["total_loss"] = df["losses"].apply(sum)
|
227 |
+
|
228 |
+
return df
|
src/load.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
from datasets import Dataset # type: ignore
|
6 |
+
|
7 |
+
from src.data import encode_dataset, get_collator, get_data, predict
|
8 |
+
from src.model import get_encoder, get_model, get_tokenizer
|
9 |
+
from src.subpages import Context
|
10 |
+
from src.utils import align_sample, device, explode_df
|
11 |
+
|
12 |
+
_TOKENIZER_NAME = (
|
13 |
+
"xlm-roberta-base",
|
14 |
+
"gagan3012/bert-tiny-finetuned-ner",
|
15 |
+
"distilbert-base-german-cased",
|
16 |
+
)[0]
|
17 |
+
|
18 |
+
|
19 |
+
def _load_models_and_tokenizer(
|
20 |
+
encoder_model_name: str,
|
21 |
+
model_name: str,
|
22 |
+
tokenizer_name: Optional[str],
|
23 |
+
device: str = "cpu",
|
24 |
+
):
|
25 |
+
sentence_encoder = get_encoder(encoder_model_name, device=device)
|
26 |
+
tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
|
27 |
+
labels = "O B-COMMA".split() if "comma" in model_name else None
|
28 |
+
model = get_model(model_name, labels=labels)
|
29 |
+
return sentence_encoder, model, tokenizer
|
30 |
+
|
31 |
+
|
32 |
+
@st.cache(allow_output_mutation=True)
|
33 |
+
def load_context(
|
34 |
+
encoder_model_name: str,
|
35 |
+
model_name: str,
|
36 |
+
ds_name: str,
|
37 |
+
ds_config_name: str,
|
38 |
+
ds_split_name: str,
|
39 |
+
split_sample_size: int,
|
40 |
+
randomize_sample: bool,
|
41 |
+
**kw_args,
|
42 |
+
) -> Context:
|
43 |
+
"""Utility method loading (almost) everything we need for the application.
|
44 |
+
This exists just because we want to cache the results of this function.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
encoder_model_name (str): Name of the sentence encoder to load.
|
48 |
+
model_name (str): Name of the NER model to load.
|
49 |
+
ds_name (str): Dataset name or path.
|
50 |
+
ds_config_name (str): Dataset config name.
|
51 |
+
ds_split_name (str): Dataset split name.
|
52 |
+
split_sample_size (int): Number of examples to load from the split.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
Context: An object containing everything we need for the application.
|
56 |
+
"""
|
57 |
+
|
58 |
+
sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
|
59 |
+
encoder_model_name=encoder_model_name,
|
60 |
+
model_name=model_name,
|
61 |
+
tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
|
62 |
+
device=str(device),
|
63 |
+
)
|
64 |
+
collator = get_collator(tokenizer)
|
65 |
+
|
66 |
+
# load data related stuff
|
67 |
+
split: Dataset = get_data(
|
68 |
+
ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample
|
69 |
+
)
|
70 |
+
tags = split.features["ner_tags"].feature
|
71 |
+
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
72 |
+
|
73 |
+
# transform into dataframe
|
74 |
+
df = predict(split_encoded, model, tokenizer, collator, tags)
|
75 |
+
df["word_ids"] = word_ids
|
76 |
+
df["ids"] = ids
|
77 |
+
|
78 |
+
# explode, clean, merge
|
79 |
+
df_tokens = explode_df(df)
|
80 |
+
df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
|
81 |
+
df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
|
82 |
+
df_tokens_merged = explode_df(df_merged)
|
83 |
+
|
84 |
+
return Context(
|
85 |
+
**{
|
86 |
+
"model": model,
|
87 |
+
"tokenizer": tokenizer,
|
88 |
+
"sentence_encoder": sentence_encoder,
|
89 |
+
"df": df,
|
90 |
+
"df_tokens": df_tokens,
|
91 |
+
"df_tokens_cleaned": df_tokens_cleaned,
|
92 |
+
"df_tokens_merged": df_tokens_merged,
|
93 |
+
"tags": tags,
|
94 |
+
"labels": tags.names,
|
95 |
+
"split_sample_size": split_sample_size,
|
96 |
+
"ds_name": ds_name,
|
97 |
+
"ds_config_name": ds_config_name,
|
98 |
+
"ds_split_name": ds_split_name,
|
99 |
+
"split": split,
|
100 |
+
}
|
101 |
+
)
|
src/model.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from transformers import AutoModelForTokenClassification # type: ignore
|
4 |
+
from transformers import AutoTokenizer # type: ignore
|
5 |
+
|
6 |
+
|
7 |
+
@st.experimental_singleton()
|
8 |
+
def get_model(model_name: str, labels=None):
|
9 |
+
if labels is None:
|
10 |
+
return AutoModelForTokenClassification.from_pretrained(
|
11 |
+
model_name,
|
12 |
+
output_attentions=True,
|
13 |
+
) # type: ignore
|
14 |
+
else:
|
15 |
+
id2label = {idx: tag for idx, tag in enumerate(labels)}
|
16 |
+
label2id = {tag: idx for idx, tag in enumerate(labels)}
|
17 |
+
return AutoModelForTokenClassification.from_pretrained(
|
18 |
+
model_name,
|
19 |
+
output_attentions=True,
|
20 |
+
num_labels=len(labels),
|
21 |
+
id2label=id2label,
|
22 |
+
label2id=label2id,
|
23 |
+
) # type: ignore
|
24 |
+
|
25 |
+
|
26 |
+
@st.experimental_singleton()
|
27 |
+
def get_encoder(model_name: str, device: str = "cpu"):
|
28 |
+
return SentenceTransformer(model_name, device=device)
|
29 |
+
|
30 |
+
|
31 |
+
@st.experimental_singleton()
|
32 |
+
def get_tokenizer(tokenizer_name: str):
|
33 |
+
return AutoTokenizer.from_pretrained(tokenizer_name)
|
src/subpages/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.subpages.attention import AttentionPage
|
2 |
+
from src.subpages.debug import DebugPage
|
3 |
+
from src.subpages.find_duplicates import FindDuplicatesPage
|
4 |
+
from src.subpages.hidden_states import HiddenStatesPage
|
5 |
+
from src.subpages.home import HomePage
|
6 |
+
from src.subpages.inspect import InspectPage
|
7 |
+
from src.subpages.losses import LossesPage
|
8 |
+
from src.subpages.lossy_samples import LossySamplesPage
|
9 |
+
from src.subpages.metrics import MetricsPage
|
10 |
+
from src.subpages.misclassified import MisclassifiedPage
|
11 |
+
from src.subpages.page import Context, Page
|
12 |
+
from src.subpages.probing import ProbingPage
|
13 |
+
from src.subpages.random_samples import RandomSamplesPage
|
14 |
+
from src.subpages.raw_data import RawDataPage
|
src/subpages/attention.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
|
3 |
+
"""
|
4 |
+
import ecco
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit.components.v1 import html
|
7 |
+
|
8 |
+
from src.subpages.page import Context, Page # type: ignore
|
9 |
+
|
10 |
+
_SETUP_HTML = """
|
11 |
+
<script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
|
12 |
+
<script>
|
13 |
+
var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
|
14 |
+
//var ecco_url = 'http://localhost:8000/'
|
15 |
+
|
16 |
+
if (window.ecco === undefined) window.ecco = {}
|
17 |
+
|
18 |
+
// Setup the paths of the script we'll be using
|
19 |
+
requirejs.config({
|
20 |
+
urlArgs: "bust=" + (new Date()).getTime(),
|
21 |
+
nodeRequire: require,
|
22 |
+
paths: {
|
23 |
+
d3: "https://d3js.org/d3.v6.min", // This is only for use in setup.html and basic.html
|
24 |
+
"d3-array": "https://d3js.org/d3-array.v2.min",
|
25 |
+
jquery: "https://code.jquery.com/jquery-3.5.1.min",
|
26 |
+
ecco: ecco_url + 'js/0.0.6/ecco-bundle.min',
|
27 |
+
xregexp: 'https://cdnjs.cloudflare.com/ajax/libs/xregexp/3.2.0/xregexp-all.min'
|
28 |
+
}
|
29 |
+
});
|
30 |
+
|
31 |
+
// Add the css file
|
32 |
+
//requirejs(['d3'],
|
33 |
+
// function (d3) {
|
34 |
+
// d3.select('#css').attr('href', ecco_url + 'html/styles.css')
|
35 |
+
// })
|
36 |
+
|
37 |
+
console.log('Ecco initialize!!')
|
38 |
+
|
39 |
+
// returns a 'basic' object. basic.init() selects the html div we'll be
|
40 |
+
// rendering the html into, adds styles.css to the document.
|
41 |
+
define('basic', ['d3'],
|
42 |
+
function (d3) {
|
43 |
+
return {
|
44 |
+
init: function (viz_id = null) {
|
45 |
+
if (viz_id == null) {
|
46 |
+
viz_id = "viz_" + Math.round(Math.random() * 10000000)
|
47 |
+
}
|
48 |
+
// Select the div rendered below, change its id
|
49 |
+
const div = d3.select('#basic').attr('id', viz_id),
|
50 |
+
div_parent = d3.select('#' + viz_id).node().parentNode
|
51 |
+
|
52 |
+
// Link to CSS file
|
53 |
+
d3.select(div_parent).insert('link')
|
54 |
+
.attr('rel', 'stylesheet')
|
55 |
+
.attr('type', 'text/css')
|
56 |
+
.attr('href', ecco_url + 'html/0.0.2/styles.css')
|
57 |
+
|
58 |
+
return viz_id
|
59 |
+
}
|
60 |
+
}
|
61 |
+
}, function (err) {
|
62 |
+
console.log(err);
|
63 |
+
}
|
64 |
+
)
|
65 |
+
</script>
|
66 |
+
|
67 |
+
<head>
|
68 |
+
<link id='css' rel="stylesheet" type="text/css">
|
69 |
+
</head>
|
70 |
+
<div id="basic"></div>
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
@st.cache(allow_output_mutation=True)
|
75 |
+
def _load_ecco_model():
|
76 |
+
model_config = {
|
77 |
+
"embedding": "embeddings.word_embeddings",
|
78 |
+
"type": "mlm",
|
79 |
+
"activations": [r"ffn\.lin1"],
|
80 |
+
"token_prefix": "",
|
81 |
+
"partial_token_prefix": "##",
|
82 |
+
}
|
83 |
+
return ecco.from_pretrained(
|
84 |
+
"elastic/distilbert-base-uncased-finetuned-conll03-english",
|
85 |
+
model_config=model_config,
|
86 |
+
activations=True,
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
class AttentionPage(Page):
|
91 |
+
name = "Activations"
|
92 |
+
icon = "activity"
|
93 |
+
|
94 |
+
def _get_widget_defaults(self):
|
95 |
+
return {
|
96 |
+
"act_n_components": 8,
|
97 |
+
"act_default_text": """Now I ask you: what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!""",
|
98 |
+
"act_from_layer": 0,
|
99 |
+
"act_to_layer": 5,
|
100 |
+
}
|
101 |
+
|
102 |
+
def render(self, context: Context):
|
103 |
+
st.title(self.name)
|
104 |
+
|
105 |
+
with st.expander("ℹ️", expanded=True):
|
106 |
+
st.write(
|
107 |
+
"A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
|
108 |
+
)
|
109 |
+
|
110 |
+
lm = _load_ecco_model()
|
111 |
+
|
112 |
+
col1, _, col2 = st.columns([1.5, 0.5, 4])
|
113 |
+
with col1:
|
114 |
+
st.subheader("Settings")
|
115 |
+
n_components = st.slider(
|
116 |
+
"#components",
|
117 |
+
key="act_n_components",
|
118 |
+
min_value=2,
|
119 |
+
max_value=10,
|
120 |
+
step=1,
|
121 |
+
)
|
122 |
+
from_layer = st.slider(
|
123 |
+
"from layer",
|
124 |
+
key="act_from_layer",
|
125 |
+
value=0,
|
126 |
+
min_value=0,
|
127 |
+
max_value=len(lm.model.transformer.layer) - 1,
|
128 |
+
step=1,
|
129 |
+
)
|
130 |
+
to_layer = (
|
131 |
+
st.slider(
|
132 |
+
"to layer",
|
133 |
+
key="act_to_layer",
|
134 |
+
value=0,
|
135 |
+
min_value=0,
|
136 |
+
max_value=len(lm.model.transformer.layer) - 1,
|
137 |
+
step=1,
|
138 |
+
)
|
139 |
+
+ 1
|
140 |
+
)
|
141 |
+
|
142 |
+
if to_layer <= from_layer:
|
143 |
+
st.error("to_layer must be >= from_layer")
|
144 |
+
st.stop()
|
145 |
+
|
146 |
+
with col2:
|
147 |
+
st.subheader("–")
|
148 |
+
text = st.text_area("Text", key="act_default_text", height=240)
|
149 |
+
|
150 |
+
inputs = lm.tokenizer([text], return_tensors="pt")
|
151 |
+
output = lm(inputs)
|
152 |
+
nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
|
153 |
+
data = nmf.explore(returnData=True)
|
154 |
+
_JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
|
155 |
+
const viz_id = basic.init()
|
156 |
+
ecco.interactiveTokensAndFactorSparklines(viz_id, {data}, {{ 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}} }} }})
|
157 |
+
}}, function (err) {{
|
158 |
+
console.log(err);
|
159 |
+
}})</script>"""
|
160 |
+
html(_SETUP_HTML + _JS_TEMPLATE, height=800, scrolling=True)
|
src/subpages/debug.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pip._internal.operations import freeze
|
3 |
+
|
4 |
+
from src.subpages.page import Context, Page
|
5 |
+
|
6 |
+
|
7 |
+
class DebugPage(Page):
|
8 |
+
name = "Debug"
|
9 |
+
icon = "bug"
|
10 |
+
|
11 |
+
def render(self, context: Context):
|
12 |
+
st.title(self.name)
|
13 |
+
# with st.expander("💡", expanded=True):
|
14 |
+
# st.write("Some debug info.")
|
15 |
+
|
16 |
+
st.subheader("Installed Packages")
|
17 |
+
# get output of pip freeze from system
|
18 |
+
with st.expander("pip freeze"):
|
19 |
+
st.code("\n".join(freeze.freeze()))
|
20 |
+
|
21 |
+
st.subheader("Streamlit Session State")
|
22 |
+
st.json(st.session_state)
|
23 |
+
st.subheader("Tokenizer")
|
24 |
+
st.code(context.tokenizer)
|
25 |
+
st.subheader("Model")
|
26 |
+
st.code(context.model.config)
|
27 |
+
st.code(context.model)
|
src/subpages/emoji-en-US.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/subpages/faiss.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from datasets import Dataset
|
3 |
+
|
4 |
+
from src.subpages.page import Context, Page # type: ignore
|
5 |
+
from src.utils import device, explode_df, htmlify_labeled_example, tag_text
|
6 |
+
|
7 |
+
|
8 |
+
class FaissPage(Page):
|
9 |
+
name = "Bla"
|
10 |
+
icon = "x-octagon"
|
11 |
+
|
12 |
+
def render(self, context: Context):
|
13 |
+
dd = Dataset.from_pandas(context.df_tokens_merged, preserve_index=False) # type: ignore
|
14 |
+
|
15 |
+
dd.add_faiss_index(column="hidden_states", index_name="token_index")
|
16 |
+
token_id, text = (
|
17 |
+
6,
|
18 |
+
"Die Wissenschaft ist eine wichtige Grundlage für die Entwicklung von neuen Technologien.",
|
19 |
+
)
|
20 |
+
token_id, text = (
|
21 |
+
15,
|
22 |
+
"Außer der unbewussten Beeinflussung eines Resultats gibt es auch noch andere Motive die das reine strahlende Licht der Wissenschaft etwas zu trüben vermögen.",
|
23 |
+
)
|
24 |
+
token_id, text = (
|
25 |
+
3,
|
26 |
+
"Mit mehr Instrumenten einer besseren präziseren Datenbasis ist auch ein viel besseres smarteres Risikomanagement möglich.",
|
27 |
+
)
|
28 |
+
token_id, text = (
|
29 |
+
7,
|
30 |
+
"Es gilt die akademische Viertelstunde das heißt Beginn ist fünfzehn Minuten später.",
|
31 |
+
)
|
32 |
+
token_id, text = (
|
33 |
+
7,
|
34 |
+
"Damit einher geht übrigens auch dass Marcella Collocinis Tochter keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat.",
|
35 |
+
)
|
36 |
+
token_id, text = (
|
37 |
+
16,
|
38 |
+
"After Steve Jobs met with Bill Gates of Microsoft back in 1993, they went to Cupertino and made the deal.",
|
39 |
+
)
|
40 |
+
|
41 |
+
tagged = tag_text(text, context.tokenizer, context.model, device)
|
42 |
+
hidden_states = tagged["hidden_states"]
|
43 |
+
# tagged.drop("hidden_states", inplace=True, axis=1)
|
44 |
+
# hidden_states_vec = svd.transform([hidden_states[token_id]])[0].astype(np.float32)
|
45 |
+
hidden_states_vec = hidden_states[token_id]
|
46 |
+
tagged = tagged.astype(str)
|
47 |
+
tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
|
48 |
+
tagged["check"] = tagged["probs"].apply(
|
49 |
+
lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
|
50 |
+
)
|
51 |
+
st.dataframe(tagged.drop("hidden_states", axis=1).T)
|
52 |
+
results = dd.get_nearest_examples("token_index", hidden_states_vec, k=10)
|
53 |
+
for i, (dist, idx, token) in enumerate(
|
54 |
+
zip(results.scores, results.examples["ids"], results.examples["tokens"])
|
55 |
+
):
|
56 |
+
st.code(f"{dist:.3f} {token}")
|
57 |
+
sample = context.df_tokens_merged.query(f"ids == '{idx}'")
|
58 |
+
st.write(f"[{i};{idx}] " + htmlify_labeled_example(sample), unsafe_allow_html=True)
|
src/subpages/find_duplicates.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Find potential duplicates in the data using cosine similarity."""
|
2 |
+
import streamlit as st
|
3 |
+
from sentence_transformers.util import cos_sim
|
4 |
+
|
5 |
+
from src.subpages.page import Context, Page
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache()
|
9 |
+
def get_sims(texts: list[str], sentence_encoder):
|
10 |
+
embeddings = sentence_encoder.encode(texts, batch_size=8, convert_to_numpy=True)
|
11 |
+
return cos_sim(embeddings, embeddings)
|
12 |
+
|
13 |
+
|
14 |
+
class FindDuplicatesPage(Page):
|
15 |
+
name = "Find Duplicates"
|
16 |
+
icon = "fingerprint"
|
17 |
+
|
18 |
+
def _get_widget_defaults(self):
|
19 |
+
return {
|
20 |
+
"cutoff": 0.95,
|
21 |
+
}
|
22 |
+
|
23 |
+
def render(self, context: Context):
|
24 |
+
st.title("Find Duplicates")
|
25 |
+
with st.expander("💡", expanded=True):
|
26 |
+
st.write("Find potential duplicates in the data using cosine similarity.")
|
27 |
+
|
28 |
+
cutoff = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, key="cutoff")
|
29 |
+
# split.add_faiss_index(column="embeddings", index_name="sent_index")
|
30 |
+
# st.write("Index is ready")
|
31 |
+
# sentence_encoder.encode(["hello world"], batch_size=8)
|
32 |
+
# st.write(split["tokens"][0])
|
33 |
+
texts = [" ".join(ts) for ts in context.split["tokens"]]
|
34 |
+
sims = get_sims(texts, context.sentence_encoder)
|
35 |
+
|
36 |
+
candidates = []
|
37 |
+
for i in range(len(sims)):
|
38 |
+
for j in range(i + 1, len(sims)):
|
39 |
+
if sims[i][j] >= cutoff:
|
40 |
+
candidates.append((sims[i][j], i, j))
|
41 |
+
candidates.sort(reverse=False)
|
42 |
+
|
43 |
+
for (sim, i, j) in candidates[:100]:
|
44 |
+
st.markdown(f"**Possible duplicate ({i}, {j}, sim: {sim:.3f}):**")
|
45 |
+
st.markdown("* " + " ".join(context.split["tokens"][i]))
|
46 |
+
st.markdown("* " + " ".join(context.split["tokens"][j]))
|
47 |
+
|
48 |
+
# st.write("queries")
|
49 |
+
# results = split.get_nearest_examples("sent_index", np.array(split["embeddings"][0], dtype=np.float32), k=2)
|
50 |
+
# results = split.get_nearest_examples_batch("sent_index", queries, k=2)
|
51 |
+
# st.write(results.total_examples[0]["id"][1])
|
52 |
+
# st.write(results.total_examples[0])
|
src/subpages/hidden_states.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
import plotly.express as px
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
import streamlit as st
|
8 |
+
|
9 |
+
from src.subpages.page import Context, Page
|
10 |
+
|
11 |
+
|
12 |
+
@st.cache
|
13 |
+
def reduce_dim_svd(X, n_iter: int, random_state=42):
|
14 |
+
"""Dimensionality reduction using truncated SVD (aka LSA).
|
15 |
+
|
16 |
+
This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
X: Training data
|
20 |
+
n_iter (int): Desired dimensionality of output data. Must be strictly less than the number of features.
|
21 |
+
random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
25 |
+
"""
|
26 |
+
from sklearn.decomposition import TruncatedSVD
|
27 |
+
|
28 |
+
svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
|
29 |
+
return svd.fit_transform(X)
|
30 |
+
|
31 |
+
|
32 |
+
@st.cache
|
33 |
+
def reduce_dim_pca(X, random_state=42):
|
34 |
+
"""Principal component analysis (PCA).
|
35 |
+
|
36 |
+
Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
X: Training data
|
40 |
+
random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
44 |
+
"""
|
45 |
+
from sklearn.decomposition import PCA
|
46 |
+
|
47 |
+
return PCA(n_components=2, random_state=random_state).fit_transform(X)
|
48 |
+
|
49 |
+
|
50 |
+
@st.cache
|
51 |
+
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
|
52 |
+
"""Uniform Manifold Approximation and Projection
|
53 |
+
|
54 |
+
Finds a low dimensional embedding of the data that approximates an underlying manifold.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
X: Training data
|
58 |
+
n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
|
59 |
+
min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
|
60 |
+
metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
64 |
+
"""
|
65 |
+
from umap import UMAP
|
66 |
+
|
67 |
+
return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
|
68 |
+
|
69 |
+
|
70 |
+
class HiddenStatesPage(Page):
|
71 |
+
name = "Hidden States"
|
72 |
+
icon = "grid-3x3"
|
73 |
+
|
74 |
+
def _get_widget_defaults(self):
|
75 |
+
return {
|
76 |
+
"n_tokens": 1_000,
|
77 |
+
"svd_n_iter": 5,
|
78 |
+
"svd_random_state": 42,
|
79 |
+
"umap_n_neighbors": 15,
|
80 |
+
"umap_metric": "euclidean",
|
81 |
+
"umap_min_dist": 0.1,
|
82 |
+
}
|
83 |
+
|
84 |
+
def render(self, context: Context):
|
85 |
+
st.title("Embeddings")
|
86 |
+
|
87 |
+
with st.expander("💡", expanded=True):
|
88 |
+
st.write(
|
89 |
+
"For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
|
90 |
+
)
|
91 |
+
|
92 |
+
col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
|
93 |
+
df = context.df_tokens_merged.copy()
|
94 |
+
dim_algo = "SVD"
|
95 |
+
n_tokens = 100
|
96 |
+
|
97 |
+
with col1:
|
98 |
+
st.subheader("Settings")
|
99 |
+
n_tokens = st.slider(
|
100 |
+
"#tokens",
|
101 |
+
key="n_tokens",
|
102 |
+
min_value=100,
|
103 |
+
max_value=len(df["tokens"].unique()),
|
104 |
+
step=100,
|
105 |
+
)
|
106 |
+
|
107 |
+
dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
|
108 |
+
if dim_algo == "SVD":
|
109 |
+
svd_n_iter = st.slider(
|
110 |
+
"#iterations",
|
111 |
+
key="svd_n_iter",
|
112 |
+
min_value=1,
|
113 |
+
max_value=10,
|
114 |
+
step=1,
|
115 |
+
)
|
116 |
+
elif dim_algo == "UMAP":
|
117 |
+
umap_n_neighbors = st.slider(
|
118 |
+
"#neighbors",
|
119 |
+
key="umap_n_neighbors",
|
120 |
+
min_value=2,
|
121 |
+
max_value=100,
|
122 |
+
step=1,
|
123 |
+
)
|
124 |
+
umap_min_dist = st.number_input(
|
125 |
+
"Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
|
126 |
+
)
|
127 |
+
umap_metric = st.selectbox(
|
128 |
+
"Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
pass
|
132 |
+
|
133 |
+
with col2:
|
134 |
+
sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
|
135 |
+
|
136 |
+
X = np.array(df["hidden_states"].tolist())
|
137 |
+
transformed_hidden_states = None
|
138 |
+
if dim_algo == "SVD":
|
139 |
+
transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
|
140 |
+
elif dim_algo == "PCA":
|
141 |
+
transformed_hidden_states = reduce_dim_pca(X)
|
142 |
+
elif dim_algo == "UMAP":
|
143 |
+
transformed_hidden_states = reduce_dim_umap(
|
144 |
+
X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
|
145 |
+
)
|
146 |
+
|
147 |
+
assert isinstance(transformed_hidden_states, np.ndarray)
|
148 |
+
df["x"] = transformed_hidden_states[:, 0]
|
149 |
+
df["y"] = transformed_hidden_states[:, 1]
|
150 |
+
df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
|
151 |
+
df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
|
152 |
+
df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
|
153 |
+
df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
|
154 |
+
df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
|
155 |
+
df["disagreements"] = df["labels"] != df["preds"]
|
156 |
+
|
157 |
+
subset = df[:n_tokens]
|
158 |
+
disagreements_trace = go.Scatter(
|
159 |
+
x=subset[subset["disagreements"]]["x"],
|
160 |
+
y=subset[subset["disagreements"]]["y"],
|
161 |
+
mode="markers",
|
162 |
+
marker=dict(
|
163 |
+
size=6,
|
164 |
+
color="rgba(0,0,0,0)",
|
165 |
+
line=dict(width=1),
|
166 |
+
),
|
167 |
+
hoverinfo="skip",
|
168 |
+
)
|
169 |
+
|
170 |
+
st.subheader("Projection Results")
|
171 |
+
|
172 |
+
fig = px.scatter(
|
173 |
+
subset,
|
174 |
+
x="x",
|
175 |
+
y="y",
|
176 |
+
color="labels",
|
177 |
+
hover_data=["ids", "preds", "sent0", "sent1", "sent2", "sent3", "sent4"],
|
178 |
+
hover_name="tokens",
|
179 |
+
title="Colored by label",
|
180 |
+
)
|
181 |
+
fig.add_trace(disagreements_trace)
|
182 |
+
st.plotly_chart(fig)
|
183 |
+
|
184 |
+
fig = px.scatter(
|
185 |
+
subset,
|
186 |
+
x="x",
|
187 |
+
y="y",
|
188 |
+
color="preds",
|
189 |
+
hover_data=["ids", "labels", "sent0", "sent1", "sent2", "sent3", "sent4"],
|
190 |
+
hover_name="tokens",
|
191 |
+
title="Colored by prediction",
|
192 |
+
)
|
193 |
+
fig.add_trace(disagreements_trace)
|
194 |
+
st.plotly_chart(fig)
|
src/subpages/home.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
from src.data import get_data
|
8 |
+
from src.subpages.page import Context, Page
|
9 |
+
from src.utils import PROJ, classmap, color_map_color
|
10 |
+
|
11 |
+
_SENTENCE_ENCODER_MODEL = (
|
12 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
13 |
+
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
14 |
+
)[0]
|
15 |
+
_MODEL_NAME = (
|
16 |
+
"elastic/distilbert-base-uncased-finetuned-conll03-english",
|
17 |
+
"gagan3012/bert-tiny-finetuned-ner",
|
18 |
+
"socialmediaie/bertweet-base_wnut17_ner",
|
19 |
+
"sberbank-ai/bert-base-NER-reptile-5-datasets",
|
20 |
+
"aseifert/comma-xlm-roberta-base",
|
21 |
+
"dslim/bert-base-NER",
|
22 |
+
"aseifert/distilbert-base-german-cased-comma-derstandard",
|
23 |
+
)[0]
|
24 |
+
_DATASET_NAME = (
|
25 |
+
"conll2003",
|
26 |
+
"wnut_17",
|
27 |
+
"aseifert/comma",
|
28 |
+
)[0]
|
29 |
+
_CONFIG_NAME = (
|
30 |
+
"conll2003",
|
31 |
+
"wnut_17",
|
32 |
+
"seifertverlag",
|
33 |
+
)[0]
|
34 |
+
|
35 |
+
|
36 |
+
class HomePage(Page):
|
37 |
+
name = "Home / Setup"
|
38 |
+
icon = "house"
|
39 |
+
|
40 |
+
def _get_widget_defaults(self):
|
41 |
+
return {
|
42 |
+
"encoder_model_name": _SENTENCE_ENCODER_MODEL,
|
43 |
+
"model_name": _MODEL_NAME,
|
44 |
+
"ds_name": _DATASET_NAME,
|
45 |
+
"ds_split_name": "validation",
|
46 |
+
"ds_config_name": _CONFIG_NAME,
|
47 |
+
"split_sample_size": 512,
|
48 |
+
"randomize_sample": True,
|
49 |
+
}
|
50 |
+
|
51 |
+
def render(self, context: Optional[Context] = None):
|
52 |
+
st.title("ExplaiNER")
|
53 |
+
|
54 |
+
with st.expander("💡", expanded=True):
|
55 |
+
st.write(
|
56 |
+
"**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
|
57 |
+
)
|
58 |
+
st.write(
|
59 |
+
"**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
|
60 |
+
)
|
61 |
+
st.write(
|
62 |
+
"_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
|
63 |
+
)
|
64 |
+
|
65 |
+
col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])
|
66 |
+
|
67 |
+
with col1:
|
68 |
+
random_form_key = f"settings-{random.randint(0, 100000)}"
|
69 |
+
# FIXME: for some reason I'm getting the following error if I don't randomize the key:
|
70 |
+
"""
|
71 |
+
2022-05-05 20:37:16.507 Traceback (most recent call last):
|
72 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/scriptrunner/script_runner.py", line 443, in _run_script
|
73 |
+
exec(code, module.__dict__)
|
74 |
+
File "/Users/zoro/code/error-analysis/main.py", line 162, in <module>
|
75 |
+
main()
|
76 |
+
File "/Users/zoro/code/error-analysis/main.py", line 102, in main
|
77 |
+
show_setup()
|
78 |
+
File "/Users/zoro/code/error-analysis/section/setup.py", line 68, in show_setup
|
79 |
+
st.form_submit_button("Load Model & Data")
|
80 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 240, in form_submit_button
|
81 |
+
return self._form_submit_button(
|
82 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 260, in _form_submit_button
|
83 |
+
return self.dg._button(
|
84 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/button.py", line 304, in _button
|
85 |
+
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
|
86 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/utils.py", line 74, in check_session_state_rules
|
87 |
+
raise StreamlitAPIException(
|
88 |
+
streamlit.errors.StreamlitAPIException: Values for st.button, st.download_button, st.file_uploader, and st.form cannot be set using st.session_state.
|
89 |
+
"""
|
90 |
+
with st.form(key=random_form_key):
|
91 |
+
st.subheader("Model & Data Selection")
|
92 |
+
st.text_input(
|
93 |
+
label="NER Model:",
|
94 |
+
key="model_name",
|
95 |
+
help="Path or name of the model to use",
|
96 |
+
)
|
97 |
+
st.text_input(
|
98 |
+
label="Encoder Model:",
|
99 |
+
key="encoder_model_name",
|
100 |
+
help="Path or name of the encoder to use for duplicate detection",
|
101 |
+
)
|
102 |
+
ds_name = st.text_input(
|
103 |
+
label="Dataset:",
|
104 |
+
key="ds_name",
|
105 |
+
help="Path or name of the dataset to use",
|
106 |
+
)
|
107 |
+
ds_config_name = st.text_input(
|
108 |
+
label="Config (optional):",
|
109 |
+
key="ds_config_name",
|
110 |
+
)
|
111 |
+
ds_split_name = st.selectbox(
|
112 |
+
label="Split:",
|
113 |
+
options=["train", "validation", "test"],
|
114 |
+
key="ds_split_name",
|
115 |
+
)
|
116 |
+
split_sample_size = st.number_input(
|
117 |
+
"Sample size:",
|
118 |
+
step=16,
|
119 |
+
key="split_sample_size",
|
120 |
+
help="Sample size for the split, speeds up processing inside streamlit",
|
121 |
+
)
|
122 |
+
randomize_sample = st.checkbox(
|
123 |
+
"Randomize sample",
|
124 |
+
key="randomize_sample",
|
125 |
+
help="Whether to randomize the sample",
|
126 |
+
)
|
127 |
+
# breakpoint()
|
128 |
+
# st.form_submit_button("Submit")
|
129 |
+
st.form_submit_button("Load Model & Data")
|
130 |
+
|
131 |
+
split = get_data(
|
132 |
+
ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample # type: ignore
|
133 |
+
)
|
134 |
+
labels = list(
|
135 |
+
set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
|
136 |
+
)
|
137 |
+
|
138 |
+
with col2a:
|
139 |
+
st.subheader("Classes")
|
140 |
+
st.write("**Color**")
|
141 |
+
colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
|
142 |
+
for label in labels:
|
143 |
+
if f"color_{label}" not in st.session_state:
|
144 |
+
st.session_state[f"color_{label}"] = colors[label]
|
145 |
+
st.color_picker(label, key=f"color_{label}")
|
146 |
+
with col2b:
|
147 |
+
st.subheader("—")
|
148 |
+
st.write("**Icon**")
|
149 |
+
emojis = list(json.load(open(PROJ / "subpages/emoji-en-US.json")).keys())
|
150 |
+
for label in labels:
|
151 |
+
if f"icon_{label}" not in st.session_state:
|
152 |
+
st.session_state[f"icon_{label}"] = classmap[label]
|
153 |
+
st.selectbox(label, key=f"icon_{label}", options=emojis)
|
154 |
+
classmap[label] = st.session_state[f"icon_{label}"]
|
155 |
+
|
156 |
+
# if st.button("Reset to defaults"):
|
157 |
+
# st.session_state.update(**get_home_page_defaults())
|
158 |
+
# # time.sleep 2 secs
|
159 |
+
# import time
|
160 |
+
# time.sleep(1)
|
161 |
+
|
162 |
+
# # st.legacy_caching.clear_cache()
|
163 |
+
# st.experimental_rerun()
|
src/subpages/inspect.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Inspect your whole dataset, either unfiltered or by id."""
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from src.subpages.page import Context, Page
|
5 |
+
from src.utils import aggrid_interactive_table, colorize_classes
|
6 |
+
|
7 |
+
|
8 |
+
class InspectPage(Page):
|
9 |
+
name = "Inspect"
|
10 |
+
icon = "search"
|
11 |
+
|
12 |
+
def render(self, context: Context):
|
13 |
+
st.title(self.name)
|
14 |
+
with st.expander("💡", expanded=True):
|
15 |
+
st.write("Inspect your whole dataset, either unfiltered or by id.")
|
16 |
+
|
17 |
+
df = context.df_tokens
|
18 |
+
cols = (
|
19 |
+
"ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
|
20 |
+
)
|
21 |
+
if "token_type_ids" not in df.columns:
|
22 |
+
cols.remove("token_type_ids")
|
23 |
+
df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
|
24 |
+
|
25 |
+
if st.checkbox("Filter by id", value=True):
|
26 |
+
ids = list(sorted(map(int, df.ids.unique())))
|
27 |
+
next_id = st.session_state.get("next_id", 0)
|
28 |
+
|
29 |
+
example_id = st.selectbox("Select an example", ids, index=next_id)
|
30 |
+
df = df[df.ids == str(example_id)][1:-1]
|
31 |
+
# st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses")) # type: ignore
|
32 |
+
st.dataframe(colorize_classes(df.round(3).astype(str)))
|
33 |
+
|
34 |
+
# if st.button("➡️ Next example"):
|
35 |
+
# st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
|
36 |
+
# st.experimental_rerun()
|
37 |
+
# if st.button("⬅️ Previous example"):
|
38 |
+
# st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
|
39 |
+
# st.experimental_rerun()
|
40 |
+
else:
|
41 |
+
aggrid_interactive_table(df.round(3))
|
src/subpages/losses.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Show count, mean and median loss per token and label."""
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from src.subpages.page import Context, Page
|
5 |
+
from src.utils import AgGrid, aggrid_interactive_table
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache
|
9 |
+
def get_loss_by_token(df_tokens):
|
10 |
+
return (
|
11 |
+
df_tokens.groupby("tokens")[["losses"]]
|
12 |
+
.agg(["count", "mean", "median", "sum"])
|
13 |
+
.droplevel(level=0, axis=1) # Get rid of multi-level columns
|
14 |
+
.sort_values(by="sum", ascending=False)
|
15 |
+
.reset_index()
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
@st.cache
|
20 |
+
def get_loss_by_label(df_tokens):
|
21 |
+
return (
|
22 |
+
df_tokens.groupby("labels")[["losses"]]
|
23 |
+
.agg(["count", "mean", "median", "sum"])
|
24 |
+
.droplevel(level=0, axis=1)
|
25 |
+
.sort_values(by="mean", ascending=False)
|
26 |
+
.reset_index()
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class LossesPage(Page):
|
31 |
+
name = "Loss by Token/Label"
|
32 |
+
icon = "sort-alpha-down"
|
33 |
+
|
34 |
+
def render(self, context: Context):
|
35 |
+
st.title(self.name)
|
36 |
+
with st.expander("💡", expanded=True):
|
37 |
+
st.write("Show count, mean and median loss per token and label.")
|
38 |
+
st.write(
|
39 |
+
"Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
|
40 |
+
)
|
41 |
+
|
42 |
+
col1, _, col2 = st.columns([8, 1, 6])
|
43 |
+
|
44 |
+
with col1:
|
45 |
+
st.subheader("💬 Loss by Token")
|
46 |
+
|
47 |
+
st.session_state["_merge_tokens"] = st.checkbox(
|
48 |
+
"Merge tokens", value=True, key="merge_tokens"
|
49 |
+
)
|
50 |
+
loss_by_token = (
|
51 |
+
get_loss_by_token(context.df_tokens_merged)
|
52 |
+
if st.session_state["merge_tokens"]
|
53 |
+
else get_loss_by_token(context.df_tokens_cleaned)
|
54 |
+
)
|
55 |
+
aggrid_interactive_table(loss_by_token.round(3))
|
56 |
+
# st.subheader("🏷️ Loss by Label")
|
57 |
+
# loss_by_label = get_loss_by_label(df_tokens_cleaned)
|
58 |
+
# st.dataframe(loss_by_label)
|
59 |
+
|
60 |
+
st.write(
|
61 |
+
"_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
|
62 |
+
)
|
63 |
+
|
64 |
+
with col2:
|
65 |
+
st.subheader("🏷️ Loss by Label")
|
66 |
+
loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
|
67 |
+
AgGrid(loss_by_label.round(3), height=200)
|
src/subpages/lossy_samples.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Show every example sorted by loss (descending) for close inspection."""
|
2 |
+
import pandas as pd
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from src.subpages.page import Context, Page
|
6 |
+
from src.utils import (
|
7 |
+
colorize_classes,
|
8 |
+
get_bg_color,
|
9 |
+
get_fg_color,
|
10 |
+
htmlify_labeled_example,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class LossySamplesPage(Page):
|
15 |
+
name = "Samples by Loss"
|
16 |
+
icon = "sort-numeric-down-alt"
|
17 |
+
|
18 |
+
def _get_widget_defaults(self):
|
19 |
+
return {
|
20 |
+
"skip_correct": True,
|
21 |
+
"samples_by_loss_show_df": True,
|
22 |
+
}
|
23 |
+
|
24 |
+
def render(self, context: Context):
|
25 |
+
st.title(self.name)
|
26 |
+
with st.expander("💡", expanded=True):
|
27 |
+
st.write("Show every example sorted by loss (descending) for close inspection.")
|
28 |
+
st.write(
|
29 |
+
"The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
|
30 |
+
)
|
31 |
+
st.write(
|
32 |
+
"The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
|
33 |
+
)
|
34 |
+
st.write(
|
35 |
+
"The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
|
36 |
+
)
|
37 |
+
|
38 |
+
st.subheader("💥 Samples ⬇loss")
|
39 |
+
skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
|
40 |
+
show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
|
41 |
+
|
42 |
+
st.write(
|
43 |
+
"""<style>
|
44 |
+
thead {
|
45 |
+
display: none;
|
46 |
+
}
|
47 |
+
td {
|
48 |
+
white-space: nowrap;
|
49 |
+
padding: 0 5px !important;
|
50 |
+
}
|
51 |
+
</style>""",
|
52 |
+
unsafe_allow_html=True,
|
53 |
+
)
|
54 |
+
|
55 |
+
top_indices = (
|
56 |
+
context.df.sort_values(by="total_loss", ascending=False)
|
57 |
+
.query("total_loss > 0.5")
|
58 |
+
.index
|
59 |
+
)
|
60 |
+
|
61 |
+
cnt = 0
|
62 |
+
for idx in top_indices:
|
63 |
+
sample = context.df_tokens_merged.loc[idx]
|
64 |
+
|
65 |
+
if isinstance(sample, pd.Series):
|
66 |
+
continue
|
67 |
+
|
68 |
+
if skip_correct and sum(sample.labels != sample.preds) == 0:
|
69 |
+
continue
|
70 |
+
|
71 |
+
if show_df:
|
72 |
+
|
73 |
+
def colorize_col(col):
|
74 |
+
if col.name == "labels" or col.name == "preds":
|
75 |
+
bgs = []
|
76 |
+
fgs = []
|
77 |
+
ops = []
|
78 |
+
for v in col.values:
|
79 |
+
bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
|
80 |
+
fgs.append(get_fg_color(bgs[-1]))
|
81 |
+
ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
|
82 |
+
return [
|
83 |
+
f"background-color: {bg}; color: {fg}; opacity: {op};"
|
84 |
+
for bg, fg, op in zip(bgs, fgs, ops)
|
85 |
+
]
|
86 |
+
return [""] * len(col)
|
87 |
+
|
88 |
+
df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
|
89 |
+
losses_slice = pd.IndexSlice["losses", :]
|
90 |
+
# x = df.T.astype(str)
|
91 |
+
# st.dataframe(x)
|
92 |
+
# st.dataframe(x.loc[losses_slice])
|
93 |
+
styler = (
|
94 |
+
df.T.style.apply(colorize_col, axis=1)
|
95 |
+
.bar(subset=losses_slice, axis=1)
|
96 |
+
.format(precision=3)
|
97 |
+
)
|
98 |
+
# styler.data = styler.data.astype(str)
|
99 |
+
st.write(styler.to_html(), unsafe_allow_html=True)
|
100 |
+
st.write("")
|
101 |
+
# st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses')) # type: ignore
|
102 |
+
# st.write(
|
103 |
+
# colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
|
104 |
+
# )
|
105 |
+
|
106 |
+
col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
|
107 |
+
|
108 |
+
cnt += 1
|
109 |
+
counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
|
110 |
+
loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
|
111 |
+
col1.write(f"{counter}{loss}", unsafe_allow_html=True)
|
112 |
+
col1.write("")
|
113 |
+
|
114 |
+
col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
|
115 |
+
# st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)
|
src/subpages/metrics.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
|
3 |
+
"""
|
4 |
+
import re
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import plotly.express as px
|
10 |
+
import streamlit as st
|
11 |
+
from seqeval.metrics import classification_report
|
12 |
+
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
|
13 |
+
|
14 |
+
from src.subpages.page import Context, Page
|
15 |
+
|
16 |
+
|
17 |
+
def _get_evaluation(df):
|
18 |
+
y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
|
19 |
+
y_pred = df.apply(
|
20 |
+
lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
|
21 |
+
axis=1,
|
22 |
+
)
|
23 |
+
report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3) # type: ignore
|
24 |
+
return report.replace(
|
25 |
+
"precision recall f1-score support",
|
26 |
+
"=" * 12 + " precision recall f1-score support",
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
|
31 |
+
cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
|
32 |
+
if zero_diagonal:
|
33 |
+
np.fill_diagonal(cm, 0)
|
34 |
+
|
35 |
+
# st.write(plt.rcParams["font.size"])
|
36 |
+
# plt.rcParams.update({'font.size': 10.0})
|
37 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
38 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
|
39 |
+
fmt = "d" if normalize is None else ".3f"
|
40 |
+
disp.plot(
|
41 |
+
cmap="Blues",
|
42 |
+
include_values=True,
|
43 |
+
xticks_rotation="vertical",
|
44 |
+
values_format=fmt,
|
45 |
+
ax=ax,
|
46 |
+
colorbar=False,
|
47 |
+
)
|
48 |
+
return fig
|
49 |
+
|
50 |
+
|
51 |
+
class MetricsPage(Page):
|
52 |
+
name = "Metrics"
|
53 |
+
icon = "graph-up-arrow"
|
54 |
+
|
55 |
+
def _get_widget_defaults(self):
|
56 |
+
return {
|
57 |
+
"normalize": True,
|
58 |
+
"zero_diagonal": False,
|
59 |
+
}
|
60 |
+
|
61 |
+
def render(self, context: Context):
|
62 |
+
st.title(self.name)
|
63 |
+
with st.expander("💡", expanded=True):
|
64 |
+
st.write(
|
65 |
+
"The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
|
66 |
+
)
|
67 |
+
st.write(
|
68 |
+
"With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
|
69 |
+
)
|
70 |
+
|
71 |
+
eval_results = _get_evaluation(context.df)
|
72 |
+
if len(eval_results.splitlines()) < 8:
|
73 |
+
col1, _, col2 = st.columns([8, 1, 1])
|
74 |
+
else:
|
75 |
+
col1 = col2 = st
|
76 |
+
|
77 |
+
col1.subheader("🎯 Evaluation Results")
|
78 |
+
col1.code(eval_results)
|
79 |
+
|
80 |
+
results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
|
81 |
+
data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
|
82 |
+
df = pd.DataFrame(data, columns="class support f1".split())
|
83 |
+
fig = px.scatter(
|
84 |
+
df,
|
85 |
+
x="support",
|
86 |
+
y="f1",
|
87 |
+
range_y=(0, 1.05),
|
88 |
+
color="class",
|
89 |
+
)
|
90 |
+
# fig.update_layout(title_text="asdf", title_yanchor="bottom")
|
91 |
+
col1.plotly_chart(fig)
|
92 |
+
|
93 |
+
col2.subheader("🔠 Confusion Matrix")
|
94 |
+
normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
|
95 |
+
zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
|
96 |
+
col2.pyplot(
|
97 |
+
plot_confusion_matrix(
|
98 |
+
y_true=context.df_tokens_cleaned["labels"],
|
99 |
+
y_preds=context.df_tokens_cleaned["preds"],
|
100 |
+
labels=context.labels,
|
101 |
+
normalize=normalize,
|
102 |
+
zero_diagonal=zero_diagonal,
|
103 |
+
),
|
104 |
+
)
|
src/subpages/misclassified.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This page contains all misclassified examples and allows filtering by specific error types."""
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import streamlit as st
|
6 |
+
from sklearn.metrics import confusion_matrix
|
7 |
+
|
8 |
+
from src.subpages.page import Context, Page
|
9 |
+
from src.utils import htmlify_labeled_example
|
10 |
+
|
11 |
+
|
12 |
+
class MisclassifiedPage(Page):
|
13 |
+
name = "Misclassified"
|
14 |
+
icon = "x-octagon"
|
15 |
+
|
16 |
+
def render(self, context: Context):
|
17 |
+
st.title(self.name)
|
18 |
+
with st.expander("💡", expanded=True):
|
19 |
+
st.write(
|
20 |
+
"This page contains all misclassified examples and allows filtering by specific error types."
|
21 |
+
)
|
22 |
+
|
23 |
+
misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
|
24 |
+
misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
|
25 |
+
cm = confusion_matrix(
|
26 |
+
misclassified_samples.labels,
|
27 |
+
misclassified_samples.preds,
|
28 |
+
labels=context.labels,
|
29 |
+
)
|
30 |
+
|
31 |
+
# st.pyplot(
|
32 |
+
# plot_confusion_matrix(
|
33 |
+
# y_preds=misclassified_samples["preds"],
|
34 |
+
# y_true=misclassified_samples["labels"],
|
35 |
+
# labels=labels,
|
36 |
+
# normalize=None,
|
37 |
+
# zero_diagonal=True,
|
38 |
+
# ),
|
39 |
+
# )
|
40 |
+
df = pd.DataFrame(cm, index=context.labels, columns=context.labels).astype(str)
|
41 |
+
import numpy as np
|
42 |
+
|
43 |
+
np.fill_diagonal(df.values, "")
|
44 |
+
st.dataframe(df.applymap(lambda x: x if x != "0" else ""))
|
45 |
+
# import matplotlib.pyplot as plt
|
46 |
+
# st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html())
|
47 |
+
# selection = aggrid_interactive_table(df)
|
48 |
+
|
49 |
+
# st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
|
50 |
+
|
51 |
+
confusions = defaultdict(int)
|
52 |
+
for i, row in enumerate(cm):
|
53 |
+
for j, _ in enumerate(row):
|
54 |
+
if i == j or cm[i][j] == 0:
|
55 |
+
continue
|
56 |
+
confusions[(context.labels[i], context.labels[j])] += cm[i][j]
|
57 |
+
|
58 |
+
def format_func(item):
|
59 |
+
return (
|
60 |
+
f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
|
61 |
+
)
|
62 |
+
|
63 |
+
conf = st.radio(
|
64 |
+
"Filter by Class Confusion",
|
65 |
+
options=list(zip(confusions.keys(), confusions.values())),
|
66 |
+
format_func=format_func,
|
67 |
+
)
|
68 |
+
|
69 |
+
# st.write(
|
70 |
+
# f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`"
|
71 |
+
# )
|
72 |
+
|
73 |
+
filtered_indices = misclassified_samples.query(
|
74 |
+
f"labels == '{conf[0][0]}' and preds == '{conf[0][1]}'"
|
75 |
+
).index
|
76 |
+
for i, idx in enumerate(filtered_indices):
|
77 |
+
sample = context.df_tokens_merged.loc[idx]
|
78 |
+
st.write(
|
79 |
+
htmlify_labeled_example(sample),
|
80 |
+
unsafe_allow_html=True,
|
81 |
+
)
|
82 |
+
st.write("---")
|
src/subpages/page.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from datasets import Dataset # type: ignore
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from transformers import AutoModelForSequenceClassification # type: ignore
|
8 |
+
from transformers import AutoTokenizer # type: ignore
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class Context:
|
13 |
+
"""This object facilitates passing around the application's state between different pages."""
|
14 |
+
|
15 |
+
model: AutoModelForSequenceClassification
|
16 |
+
tokenizer: AutoTokenizer
|
17 |
+
sentence_encoder: SentenceTransformer
|
18 |
+
tags: Any
|
19 |
+
df: pd.DataFrame
|
20 |
+
df_tokens: pd.DataFrame
|
21 |
+
df_tokens_cleaned: pd.DataFrame
|
22 |
+
df_tokens_merged: pd.DataFrame
|
23 |
+
split_sample_size: int
|
24 |
+
ds_name: str
|
25 |
+
ds_config_name: str
|
26 |
+
ds_split_name: str
|
27 |
+
split: Dataset
|
28 |
+
labels: list[str]
|
29 |
+
|
30 |
+
|
31 |
+
class Page:
|
32 |
+
"""This class encapsulates the logic for a single page of the application."""
|
33 |
+
|
34 |
+
name: str
|
35 |
+
"""The page's name that will be used in the sidebar menu."""
|
36 |
+
|
37 |
+
icon: str
|
38 |
+
"""The page's icon that will be used in the sidebar menu."""
|
39 |
+
|
40 |
+
def _get_widget_defaults(self):
|
41 |
+
"""This function holds the default settings for all widgets contained on this page.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
dict: A dictionary of widget defaults, where the keys are the widget names and the values are the default.
|
45 |
+
"""
|
46 |
+
return {}
|
47 |
+
|
48 |
+
def render(self, context):
|
49 |
+
"""This function renders the page."""
|
50 |
+
...
|
src/subpages/probing.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
|
3 |
+
"""
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
from src.subpages.page import Context, Page
|
7 |
+
from src.utils import device, tag_text
|
8 |
+
|
9 |
+
_DEFAULT_SENTENCES = """
|
10 |
+
Damit hatte er auf ihr letztes , völlig schiefgelaufenes Geschäftsessen angespielt .
|
11 |
+
Damit einher geht übrigens auch , dass Marcella , Collocinis Tochter , keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat .
|
12 |
+
Nach dem Bell ’ schen Theorem , einer Physik jenseits der Quanten , ist die Welt , die wir für real halten , nicht objektivierbar .
|
13 |
+
Dazu muss man wiederum wissen , dass die Aussagekraft von Tests , neben der Sensitivität und Spezifität , ganz entscheidend von der Vortestwahrscheinlichkeit abhängt .
|
14 |
+
Haben Sie sich schon eingelebt ? « erkundigte er sich .
|
15 |
+
Das Auto ein Totalschaden , mein Beifahrer ein weinender Jammerlappen .
|
16 |
+
Seltsam , wunderte sie sich , dass das Stück nach mehr als eineinhalb Jahrhunderten noch so gut in Schuss ist .
|
17 |
+
Oder auf den Strich gehen , Strümpfe stricken , Geld hamstern .
|
18 |
+
Und Allah ist Allumfassend Allwissend .
|
19 |
+
Und Pedro Moacir redete weiter : » Verzicht , Pater Antonio , Verzicht , zu großer Schmerz über Verzicht , Sehnsucht , die sich nicht erfüllt , die sich nicht erfüllen kann , das sind Qualen , die ein Verstummen nach sich ziehen können , oder Härte .
|
20 |
+
Mama-San ging mittlerweile fast ausnahmslos nur mit Wei an ihrer Seite aus dem Haus , kaum je mit einem der Mädchen und niemals allein.
|
21 |
+
""".strip()
|
22 |
+
_DEFAULT_SENTENCES = """
|
23 |
+
Elon Musk’s Berghain humiliation — I know the feeling
|
24 |
+
Musk was also seen at a local spot called Sisyphos celebrating entrepreneur Adeo Ressi's birthday, according to The Times.
|
25 |
+
""".strip()
|
26 |
+
|
27 |
+
|
28 |
+
class ProbingPage(Page):
|
29 |
+
name = "Probing"
|
30 |
+
icon = "fonts"
|
31 |
+
|
32 |
+
def _get_widget_defaults(self):
|
33 |
+
return {"probing_textarea": _DEFAULT_SENTENCES}
|
34 |
+
|
35 |
+
def render(self, context: Context):
|
36 |
+
st.title("🔠 Interactive Probing")
|
37 |
+
|
38 |
+
with st.expander("💡", expanded=True):
|
39 |
+
st.write(
|
40 |
+
"A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors."
|
41 |
+
)
|
42 |
+
|
43 |
+
sentences = st.text_area("Sentences", height=200, key="probing_textarea")
|
44 |
+
if not sentences.strip():
|
45 |
+
return
|
46 |
+
sentences = [sentence.strip() for sentence in sentences.splitlines()]
|
47 |
+
|
48 |
+
for sent in sentences:
|
49 |
+
sent = sent.replace(",", "").replace(" ", " ")
|
50 |
+
with st.expander(sent):
|
51 |
+
tagged = tag_text(sent, context.tokenizer, context.model, device)
|
52 |
+
tagged = tagged.astype(str)
|
53 |
+
tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
|
54 |
+
tagged["check"] = tagged["probs"].apply(
|
55 |
+
lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
|
56 |
+
)
|
57 |
+
st.dataframe(tagged.drop("hidden_states", axis=1).T)
|
src/subpages/random_samples.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Show random samples. Simple method, but it often turns up interesting things."""
|
2 |
+
import pandas as pd
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from src.subpages.page import Context, Page
|
6 |
+
from src.utils import htmlify_labeled_example
|
7 |
+
|
8 |
+
|
9 |
+
class RandomSamplesPage(Page):
|
10 |
+
name = "Random Samples"
|
11 |
+
icon = "shuffle"
|
12 |
+
|
13 |
+
def _get_widget_defaults(self):
|
14 |
+
return {
|
15 |
+
"random_sample_size_min": 128,
|
16 |
+
}
|
17 |
+
|
18 |
+
def render(self, context: Context):
|
19 |
+
st.title("🎲 Random Samples")
|
20 |
+
with st.expander("💡", expanded=True):
|
21 |
+
st.write(
|
22 |
+
"Show random samples. Simple method, but it often turns up interesting things."
|
23 |
+
)
|
24 |
+
|
25 |
+
random_sample_size = st.number_input(
|
26 |
+
"Random sample size:",
|
27 |
+
value=min(st.session_state.random_sample_size_min, context.split_sample_size),
|
28 |
+
step=16,
|
29 |
+
key="random_sample_size",
|
30 |
+
)
|
31 |
+
|
32 |
+
if st.button("🎲 Resample"):
|
33 |
+
st.experimental_rerun()
|
34 |
+
|
35 |
+
random_indices = context.df.sample(int(random_sample_size)).index
|
36 |
+
samples = context.df_tokens_merged.loc[random_indices]
|
37 |
+
|
38 |
+
for i, idx in enumerate(random_indices):
|
39 |
+
sample = samples.loc[idx]
|
40 |
+
|
41 |
+
if isinstance(sample, pd.Series):
|
42 |
+
continue
|
43 |
+
|
44 |
+
col1, _, col2 = st.columns([0.08, 0.025, 0.8])
|
45 |
+
|
46 |
+
counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: wh^; padding: 0 5px'>[{i+1} | {idx}]</span>"
|
47 |
+
loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
|
48 |
+
col1.write(f"{counter}{loss}", unsafe_allow_html=True)
|
49 |
+
col1.write("")
|
50 |
+
col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
|
src/subpages/raw_data.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""See the data as seen by your model."""
|
2 |
+
import pandas as pd
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from src.subpages.page import Context, Page
|
6 |
+
from src.utils import aggrid_interactive_table
|
7 |
+
|
8 |
+
|
9 |
+
@st.cache
|
10 |
+
def convert_df(df):
|
11 |
+
return df.to_csv().encode("utf-8")
|
12 |
+
|
13 |
+
|
14 |
+
class RawDataPage(Page):
|
15 |
+
name = "Raw data"
|
16 |
+
icon = "qr-code"
|
17 |
+
|
18 |
+
def render(self, context: Context):
|
19 |
+
st.title(self.name)
|
20 |
+
with st.expander("💡", expanded=True):
|
21 |
+
st.write("See the data as seen by your model.")
|
22 |
+
|
23 |
+
st.subheader("Dataset")
|
24 |
+
st.code(
|
25 |
+
f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
|
26 |
+
)
|
27 |
+
|
28 |
+
st.write("**Data after processing and inference**")
|
29 |
+
|
30 |
+
processed_df = (
|
31 |
+
context.df_tokens.drop("hidden_states", axis=1).drop("attention_mask", axis=1).round(3)
|
32 |
+
)
|
33 |
+
cols = (
|
34 |
+
"ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
|
35 |
+
)
|
36 |
+
if "token_type_ids" not in processed_df.columns:
|
37 |
+
cols.remove("token_type_ids")
|
38 |
+
processed_df = processed_df[cols]
|
39 |
+
aggrid_interactive_table(processed_df)
|
40 |
+
processed_df_csv = convert_df(processed_df)
|
41 |
+
st.download_button(
|
42 |
+
"Download csv",
|
43 |
+
processed_df_csv,
|
44 |
+
"processed_data.csv",
|
45 |
+
"text/csv",
|
46 |
+
)
|
47 |
+
|
48 |
+
st.write("**Raw data (exploded by tokens)**")
|
49 |
+
raw_data_df = context.split.to_pandas().apply(pd.Series.explode) # type: ignore
|
50 |
+
aggrid_interactive_table(raw_data_df)
|
51 |
+
raw_data_df_csv = convert_df(raw_data_df)
|
52 |
+
st.download_button(
|
53 |
+
"Download csv",
|
54 |
+
raw_data_df_csv,
|
55 |
+
"raw_data.csv",
|
56 |
+
"text/csv",
|
57 |
+
)
|
src/utils.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import matplotlib as matplotlib
|
4 |
+
import matplotlib.cm as cm
|
5 |
+
import pandas as pd
|
6 |
+
import streamlit as st
|
7 |
+
import tokenizers
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
|
11 |
+
|
12 |
+
PROJ = Path(__file__).parent
|
13 |
+
|
14 |
+
tokenizer_hash_funcs = {
|
15 |
+
tokenizers.Tokenizer: lambda _: None,
|
16 |
+
tokenizers.AddedToken: lambda _: None,
|
17 |
+
}
|
18 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
|
19 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
|
21 |
+
classmap = {
|
22 |
+
"O": "O",
|
23 |
+
"PER": "🙎",
|
24 |
+
"person": "🙎",
|
25 |
+
"LOC": "🌎",
|
26 |
+
"location": "🌎",
|
27 |
+
"ORG": "🏤",
|
28 |
+
"corporation": "🏤",
|
29 |
+
"product": "📱",
|
30 |
+
"creative": "🎷",
|
31 |
+
"MISC": "🎷",
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
|
36 |
+
"""Creates an st-aggrid interactive table based on a dataframe.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
df (pd.DataFrame]): Source dataframe
|
40 |
+
Returns:
|
41 |
+
dict: The selected row
|
42 |
+
"""
|
43 |
+
options = GridOptionsBuilder.from_dataframe(
|
44 |
+
df, enableRowGroup=True, enableValue=True, enablePivot=True
|
45 |
+
)
|
46 |
+
|
47 |
+
options.configure_side_bar()
|
48 |
+
# options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))
|
49 |
+
|
50 |
+
options.configure_selection("single")
|
51 |
+
selection = AgGrid(
|
52 |
+
df,
|
53 |
+
enable_enterprise_modules=True,
|
54 |
+
gridOptions=options.build(),
|
55 |
+
theme="light",
|
56 |
+
update_mode=GridUpdateMode.NO_UPDATE,
|
57 |
+
allow_unsafe_jscode=True,
|
58 |
+
)
|
59 |
+
|
60 |
+
return selection
|
61 |
+
|
62 |
+
|
63 |
+
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
|
64 |
+
"""Takes a dataframe and explodes all the fields."""
|
65 |
+
|
66 |
+
df_tokens = df.apply(pd.Series.explode)
|
67 |
+
if "losses" in df.columns:
|
68 |
+
df_tokens["losses"] = df_tokens["losses"].astype(float)
|
69 |
+
return df_tokens # type: ignore
|
70 |
+
|
71 |
+
|
72 |
+
def align_sample(row: pd.Series):
|
73 |
+
"""Uses word_ids to align all lists in a sample."""
|
74 |
+
|
75 |
+
columns = row.axes[0].to_list()
|
76 |
+
indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
|
77 |
+
|
78 |
+
out = {}
|
79 |
+
|
80 |
+
tokens = []
|
81 |
+
for i, tok in enumerate(row.tokens):
|
82 |
+
if row.word_ids[i] == -1:
|
83 |
+
continue
|
84 |
+
|
85 |
+
if row.word_ids[i] != row.word_ids[i - 1]:
|
86 |
+
tokens.append(tok.lstrip("▁").lstrip("##").rstrip("@@"))
|
87 |
+
else:
|
88 |
+
tokens[-1] += tok.lstrip("▁").lstrip("##").rstrip("@@")
|
89 |
+
out["tokens"] = tokens
|
90 |
+
|
91 |
+
if "preds" in columns:
|
92 |
+
out["preds"] = [row.preds[i] for i in indices]
|
93 |
+
|
94 |
+
if "labels" in columns:
|
95 |
+
out["labels"] = [row.labels[i] for i in indices]
|
96 |
+
|
97 |
+
if "losses" in columns:
|
98 |
+
out["losses"] = [row.losses[i] for i in indices]
|
99 |
+
|
100 |
+
if "probs" in columns:
|
101 |
+
out["probs"] = [row.probs[i] for i in indices]
|
102 |
+
|
103 |
+
if "hidden_states" in columns:
|
104 |
+
out["hidden_states"] = [row.hidden_states[i] for i in indices]
|
105 |
+
|
106 |
+
if "ids" in columns:
|
107 |
+
out["ids"] = row.ids
|
108 |
+
|
109 |
+
assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
|
110 |
+
|
111 |
+
return out
|
112 |
+
|
113 |
+
|
114 |
+
@st.cache(
|
115 |
+
allow_output_mutation=True,
|
116 |
+
hash_funcs=tokenizer_hash_funcs,
|
117 |
+
)
|
118 |
+
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
|
119 |
+
"""Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
text (str): The text to be processed
|
123 |
+
tokenizer: Tokenizer to use
|
124 |
+
model (_type_): Model to use
|
125 |
+
device (torch.device): The device we want pytorch to use for its calcultaions.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
pd.DataFrame: A data frame holding the tagged text.
|
129 |
+
"""
|
130 |
+
|
131 |
+
tokens = tokenizer(text).tokens()
|
132 |
+
tokenized = tokenizer(text, return_tensors="pt")
|
133 |
+
word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
|
134 |
+
input_ids = tokenized.input_ids.to(device)
|
135 |
+
outputs = model(input_ids, output_hidden_states=True)
|
136 |
+
preds = torch.argmax(outputs.logits, dim=2)
|
137 |
+
preds = [model.config.id2label[p] for p in preds[0].cpu().numpy()]
|
138 |
+
hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
|
139 |
+
# hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)
|
140 |
+
|
141 |
+
probs = 1 // (
|
142 |
+
torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
|
143 |
+
)
|
144 |
+
|
145 |
+
df = pd.DataFrame(
|
146 |
+
[[tokens, word_ids, preds, probs, hidden_states]],
|
147 |
+
columns="tokens word_ids preds probs hidden_states".split(),
|
148 |
+
)
|
149 |
+
merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
|
150 |
+
return explode_df(merged_df).reset_index().drop(columns=["index"])
|
151 |
+
|
152 |
+
|
153 |
+
def get_bg_color(label: str):
|
154 |
+
"""Retrieves a label's color from the session state."""
|
155 |
+
return st.session_state[f"color_{label}"]
|
156 |
+
|
157 |
+
|
158 |
+
def get_fg_color(bg_color_hex: str) -> str:
|
159 |
+
"""Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.
|
160 |
+
|
161 |
+
Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/
|
162 |
+
|
163 |
+
Args:
|
164 |
+
bg_color_hex (str): The background color given as a HEX stirng.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
str: Either "black" or "white".
|
168 |
+
"""
|
169 |
+
r = int(bg_color_hex[1:3], 16)
|
170 |
+
g = int(bg_color_hex[3:5], 16)
|
171 |
+
b = int(bg_color_hex[5:7], 16)
|
172 |
+
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
|
173 |
+
return "black" if (yiq >= 128) else "white"
|
174 |
+
|
175 |
+
|
176 |
+
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
|
177 |
+
"""Colorizes the errors in the dataframe."""
|
178 |
+
|
179 |
+
def colorize_row(row):
|
180 |
+
return [
|
181 |
+
"background-color: "
|
182 |
+
+ ("white" if (row["labels"] == "IGN" or (row["preds"] == row["labels"])) else "pink")
|
183 |
+
+ ";"
|
184 |
+
] * len(row)
|
185 |
+
|
186 |
+
def colorize_col(col):
|
187 |
+
if col.name == "labels" or col.name == "preds":
|
188 |
+
bgs = []
|
189 |
+
fgs = []
|
190 |
+
for v in col.values:
|
191 |
+
bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
|
192 |
+
fgs.append(get_fg_color(bgs[-1]))
|
193 |
+
return [f"background-color: {bg}; color: {fg};" for bg, fg in zip(bgs, fgs)]
|
194 |
+
return [""] * len(col)
|
195 |
+
|
196 |
+
df = df.reset_index().drop(columns=["index"]).T
|
197 |
+
return df # .style.apply(colorize_col, axis=0)
|
198 |
+
|
199 |
+
|
200 |
+
def htmlify_labeled_example(example: pd.DataFrame) -> str:
|
201 |
+
"""Builds an HTML (string) representation of a single example.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
example (pd.DataFrame): The example to process.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
str: An HTML string representation of a single example.
|
208 |
+
"""
|
209 |
+
html = []
|
210 |
+
|
211 |
+
for _, row in example.iterrows():
|
212 |
+
pred = row.preds.split("-")[1] if "-" in row.preds else "O"
|
213 |
+
label = row.labels
|
214 |
+
label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
|
215 |
+
|
216 |
+
color = get_bg_color(row.preds.split("-")[1]) if "-" in row.preds else "#000000"
|
217 |
+
true_color = get_bg_color(row.labels.split("-")[1]) if "-" in row.labels else "#000000"
|
218 |
+
|
219 |
+
font_color = get_fg_color(color) if color else "white"
|
220 |
+
true_font_color = get_fg_color(true_color) if true_color else "white"
|
221 |
+
|
222 |
+
is_correct = row.preds == row.labels
|
223 |
+
loss_html = (
|
224 |
+
""
|
225 |
+
if float(row.losses) < 0.01
|
226 |
+
else f"<span style='background-color: yellow; color: font_color; padding: 0 5px;'>{row.losses:.3f}</span>"
|
227 |
+
)
|
228 |
+
loss_html = ""
|
229 |
+
|
230 |
+
if row.labels == row.preds == "O":
|
231 |
+
html.append(f"<span>{row.tokens}</span>")
|
232 |
+
elif row.labels == "IGN":
|
233 |
+
assert False
|
234 |
+
else:
|
235 |
+
opacity = "1" if not is_correct else "0.5"
|
236 |
+
correct = (
|
237 |
+
""
|
238 |
+
if is_correct
|
239 |
+
else f"<span title='{label}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{classmap[label_class]}</span>"
|
240 |
+
)
|
241 |
+
pred_icon = classmap[pred] if pred != "O" and row.preds[:2] != "I-" else ""
|
242 |
+
html.append(
|
243 |
+
f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title={row.preds}>{pred_icon + ' '}{row.tokens}</span>{correct}{loss_html}"
|
244 |
+
)
|
245 |
+
|
246 |
+
return " ".join(html)
|
247 |
+
|
248 |
+
|
249 |
+
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
|
250 |
+
"""Turns a value into a color using a color map."""
|
251 |
+
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
|
252 |
+
cmap = cm.get_cmap(cmap_name) # PiYG
|
253 |
+
rgba = cmap(norm(abs(value)))
|
254 |
+
color = matplotlib.colors.rgb2hex(rgba[:3])
|
255 |
+
return color
|