ceyda aseifert committed on
Commit
2d4811a
·
0 Parent(s):

Duplicate from aseifert/ExplaiNER


Co-authored-by: Alexander Seifert <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
1
+
2
+ # Created by https://www.gitignore.io/api/python,osx,linux
3
+ # Edit at https://www.gitignore.io/?templates=python,osx,linux
4
+
5
+ ### Linux ###
6
+ *~
7
+
8
+ # temporary files which can be created if a process still has a handle open of a deleted file
9
+ .fuse_hidden*
10
+
11
+ # KDE directory preferences
12
+ .directory
13
+
14
+ # Linux trash folder which might appear on any partition or disk
15
+ .Trash-*
16
+
17
+ # .nfs files are created when an open file is removed but is still being accessed
18
+ .nfs*
19
+
20
+ ### OSX ###
21
+ # General
22
+ .DS_Store
23
+ .AppleDouble
24
+ .LSOverride
25
+
26
+ # Icon must end with two \r
27
+ Icon
28
+
29
+ # Thumbnails
30
+ ._*
31
+
32
+ # Files that might appear in the root of a volume
33
+ .DocumentRevisions-V100
34
+ .fseventsd
35
+ .Spotlight-V100
36
+ .TemporaryItems
37
+ .Trashes
38
+ .VolumeIcon.icns
39
+ .com.apple.timemachine.donotpresent
40
+
41
+ # Directories potentially created on remote AFP share
42
+ .AppleDB
43
+ .AppleDesktop
44
+ Network Trash Folder
45
+ Temporary Items
46
+ .apdisk
47
+
48
+ ### Python ###
49
+ # Byte-compiled / optimized / DLL files
50
+ __pycache__/
51
+ *.py[cod]
52
+ *$py.class
53
+
54
+ # C extensions
55
+ *.so
56
+
57
+ # Distribution / packaging
58
+ .Python
59
+ build/
60
+ develop-eggs/
61
+ dist/
62
+ downloads/
63
+ eggs/
64
+ .eggs/
65
+ lib/
66
+ lib64/
67
+ parts/
68
+ sdist/
69
+ var/
70
+ wheels/
71
+ pip-wheel-metadata/
72
+ share/python-wheels/
73
+ *.egg-info/
74
+ .installed.cfg
75
+ *.egg
76
+ MANIFEST
77
+
78
+ # PyInstaller
79
+ # Usually these files are written by a python script from a template
80
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
81
+ *.manifest
82
+ *.spec
83
+
84
+ # Installer logs
85
+ pip-log.txt
86
+ pip-delete-this-directory.txt
87
+
88
+ # Unit test / coverage reports
89
+ htmlcov/
90
+ .tox/
91
+ .nox/
92
+ .coverage
93
+ .coverage.*
94
+ .cache
95
+ nosetests.xml
96
+ coverage.xml
97
+ *.cover
98
+ .hypothesis/
99
+ .pytest_cache/
100
+
101
+ # Translations
102
+ *.mo
103
+ *.pot
104
+
105
+ # Scrapy stuff:
106
+ .scrapy
107
+
108
+ # Sphinx documentation
109
+ docs/_build/
110
+
111
+ # PyBuilder
112
+ target/
113
+
114
+ # pyenv
115
+ .python-version
116
+
117
+ # pipenv
118
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
119
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
120
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
121
+ # install all needed dependencies.
122
+ #Pipfile.lock
123
+
124
+ # celery beat schedule file
125
+ celerybeat-schedule
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Spyder project settings
131
+ .spyderproject
132
+ .spyproject
133
+
134
+ # Rope project settings
135
+ .ropeproject
136
+
137
+ # Mr Developer
138
+ .mr.developer.cfg
139
+ .project
140
+ .pydevproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # End of https://www.gitignore.io/api/python,osx,linux
154
+
155
+ .idea/
156
+ .ipynb_checkpoints/
157
+ node_modules/
158
+ data/books/
159
+ docx/cache/
160
+ docx/data/
161
+ save_dir/
162
+ cache_dir/
163
+ outputs/
164
+ models/
165
+ runs/
166
+ .vscode/
167
+ doc/
168
+ html/*.html
169
+ vis2.zip
Makefile ADDED
@@ -0,0 +1,10 @@
1
+ doc:
2
+ pdoc --docformat google src -o doc
3
+
4
+ vis2: doc
5
+ pandoc html/index.md -s -o html/index.html
6
+ rm -rf src/__pycache__ && rm -rf src/subpages/__pycache__
7
+ zip -r vis2.zip doc html src Makefile presentation.pdf requirements.txt
8
+
9
+ run:
10
+ python -m streamlit run src/app.py
README.md ADDED
@@ -0,0 +1,78 @@
1
+ ---
2
+ title: ExplaiNER
3
+ emoji: 🏷️
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ python_version: 3.9
7
+ sdk: streamlit
8
+ sdk_version: 1.10.0
9
+ app_file: src/app.py
10
+ pinned: true
11
+ duplicated_from: aseifert/ExplaiNER
12
+ ---
13
+
14
+ # 🏷️ ExplaiNER: Error Analysis for NER models & datasets
15
+
16
+ Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements.
17
+
18
+ ## Sections
19
+
20
+
21
+ ### Activations
22
+
23
+ A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
24
+
25
+
26
+ ### Embeddings
27
+
28
+ For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
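A minimal sketch of the projection idea (not the app's exact code): take each token's final-layer hidden state from a token-classification model and map it to 2-D with PCA. The model name below is the app's default checkpoint; the example sentence is just a placeholder.

```python
import torch
from sklearn.decomposition import PCA
from transformers import AutoModelForTokenClassification, AutoTokenizer

name = "elastic/distilbert-base-uncased-finetuned-conll03-english"  # the app's default model
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForTokenClassification.from_pretrained(name)

enc = tokenizer("Angela Merkel visited Paris .", return_tensors="pt")
with torch.no_grad():
    out = model(**enc, output_hidden_states=True)

hidden = out.hidden_states[-1][0]                           # (seq_len, hidden_dim), final layer
points = PCA(n_components=2).fit_transform(hidden.numpy())  # project each token to 2-D
preds = out.logits.argmax(-1)[0]

tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
for tok, (x, y), p in zip(tokens, points, preds):
    print(f"{tok:12} ({x:+.2f}, {y:+.2f})  {model.config.id2label[int(p)]}")
```

The app pools hidden states over many sentences and also offers truncated SVD and UMAP as alternative projections.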
29
+
30
+
31
+ ### Probing
32
+
33
+ A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
34
+
35
+
36
+ ### Metrics
37
+
38
+ The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
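A rough, token-level sketch of the "zero out the diagonal" trick with scikit-learn; the label and prediction lists below are stand-ins for the flattened token columns the app computes:

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

labels = ["O", "B-PER", "I-PER", "O", "B-LOC", "O"]
preds  = ["O", "B-PER", "O",     "O", "B-ORG", "O"]
classes = sorted(set(labels) | set(preds))

# precision / recall / f-score per class
print(classification_report(labels, preds, zero_division=0))

cm = confusion_matrix(labels, preds, labels=classes)  # raw counts
np.fill_diagonal(cm, 0)                               # zero the diagonal: only errors remain
for name, row in zip(classes, cm):
    print(f"{name:8}", row)
```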
39
+
40
+
41
+ ### Misclassified
42
+
43
+ This page contains all misclassified examples and allows filtering by specific error types.
44
+
45
+
46
+ ### Loss by Token/Label
47
+
48
+ Show count, mean and median loss per token and label.
49
+
50
+
51
+ ### Samples by Loss
52
+
53
+ Show every example sorted by loss (descending) for close inspection.
54
+
55
+
56
+ ### Random Samples
57
+
58
+ Show random samples. Simple method, but it often turns up interesting things.
59
+
60
+
61
+ ### Find Duplicates
62
+
63
+ Find potential duplicates in the data using cosine similarity.
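The underlying idea, sketched with `sentence-transformers` (the encoder name is the app's default; the texts are placeholders taken from CoNLL-2003):

```python
from sentence_transformers import SentenceTransformer, util

encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
texts = [
    "EU rejects German call to boycott British lamb .",
    "EU rejects German call to boycott British lamb.",
    "Peter Blackburn",
]
emb = encoder.encode(texts, convert_to_tensor=True)
sims = util.cos_sim(emb, emb)  # pairwise cosine similarity matrix

cutoff = 0.95
for i in range(len(texts)):
    for j in range(i + 1, len(texts)):
        if float(sims[i][j]) >= cutoff:
            print(f"possible duplicate ({float(sims[i][j]):.3f}): {texts[i]!r} / {texts[j]!r}")
```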
64
+
65
+
66
+ ### Inspect
67
+
68
+ Inspect your whole dataset, either unfiltered or by id.
69
+
70
+
71
+ ### Raw data
72
+
73
+ See the data as seen by your model.
74
+
75
+
76
+ ### Debug
77
+
78
+ Debug info.
html/index.md ADDED
@@ -0,0 +1,116 @@
1
+ ---
2
+ title: "🏷️ ExplaiNER"
3
+ subtitle: "Error Analysis for NER models & datasets"
4
+ ---
5
+
6
+ <div style="text-align: center">
7
+ <img src="screenshot.jpg" alt="drawing" width="480px"/>
8
+ </div>
9
+
10
+ _Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements._
11
+
12
+ [Documentation](../doc/index.html) | [Slides](../presentation.pdf) | [Github](https://github.com/aseifert/ExplaiNER)
13
+
14
+
15
+ ## Getting started
16
+
17
+ ```bash
18
+ # Install requirements
19
+ pip install -r requirements.txt # you'll need Python 3.9+
20
+
21
+ # Run
22
+ make run
23
+ ```
24
+
25
+ ## Description
26
+
27
+ Some interesting **visualization techniques** contained in this project:
28
+
29
+ * customizable visualization of neural network activation, based on the embedding layer and the feed-forward layers of the selected transformer model. ([Alammar 2021](https://aclanthology.org/2021.acl-demo.30/))
30
+ * customizable similarity map of a 2d projection of the model's final layer's hidden states, using various algorithms (a bit like the [Tensorflow Embedding Projector](https://projector.tensorflow.org/))
31
+ * inline HTML representation of samples with token-level prediction + labels (my own; see below under 'Samples by loss' for more info)
32
+
33
+
34
+ **Libraries** important to this project:
35
+
36
+ * `streamlit` for demoing (custom multi-page feature hacked in, also using session state)
37
+ * `plotly` and `matplotlib` for charting
38
+ * `transformers` for providing the models, and `datasets` for, well, the datasets
39
+ * a forked, slightly modified version of [`ecco`](https://github.com/jalammar/ecco) for visualizing the neural net activations
40
+ * `sentence_transformers` for finding potential duplicates
41
+ * `scikit-learn` for TruncatedSVD & PCA, `umap-learn` for UMAP
42
+
43
+
44
+ ## Application Sections
45
+
46
+
47
+ Activations
48
+
49
+ > A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
50
+
51
+
52
+ Hidden States
53
+
54
+ > For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
55
+ >
56
+ > Using these projections you can visually identify data points that end up in the wrong neighborhood, indicating prediction/labeling errors.
57
+
58
+
59
+ Probing
60
+
61
+ > A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
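A hedged sketch of the same probing loop outside the app, using the `transformers` pipeline API with the default CoNLL-03 checkpoint; the 0.9 cutoff for flagging low-confidence tokens is an arbitrary choice:

```python
from transformers import pipeline

tagger = pipeline(
    "token-classification",
    model="elastic/distilbert-base-uncased-finetuned-conll03-english",
)

texts = "Angela Merkel visited Paris.\nApple opened a new store in Berlin.".split("\n")
for text in texts:
    for tok in tagger(text):
        flag = "⚠️" if tok["score"] < 0.9 else ""  # low-confidence predictions are worth inspecting first
        print(f"{tok['word']:12} {tok['entity']:8} {tok['score']:.3f} {flag}")
```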
62
+
63
+
64
+ Metrics
65
+
66
+ > The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
67
+ >
68
+ > With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone.
69
+
70
+
71
+ Misclassified
72
+
73
+ > This page contains all misclassified examples and allows filtering by specific error types. Helps you get an understanding of the types of errors your model makes.
74
+
75
+
76
+ Loss by Token/Label
77
+
78
+ > Show count, mean and median loss per token and label.
79
+ >
80
+ > Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues.
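The aggregation itself is a plain pandas groupby. A toy sketch, assuming a token-level dataframe with `tokens`, `labels` and `losses` columns like the one the app builds (the values are made up):

```python
import pandas as pd

df_tokens = pd.DataFrame(
    {
        "tokens": ["Angela", "Merkel", "visited", "Paris", ".", "Paris"],
        "labels": ["B-PER", "I-PER", "O", "B-LOC", "O", "O"],
        "losses": [0.02, 0.05, 0.01, 0.40, 0.01, 2.30],
    }
)

# count, mean and median loss per token and per label
by_token = df_tokens.groupby("tokens")["losses"].agg(["count", "mean", "median"])
by_label = df_tokens.groupby("labels")["losses"].agg(["count", "mean", "median"])
print(by_token.sort_values("mean", ascending=False))  # "Paris" shows a large mean/median gap
print(by_label.sort_values("mean", ascending=False))
```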
81
+
82
+
83
+ Samples by Loss
84
+
85
+ > Show every example sorted by loss (descending) for close inspection.
86
+ >
87
+ > Apart from a (token-based) dataframe view, there's also an HTML representation of the samples, which is very information-dense but really helpful once you get used to reading it:
88
+ >
89
+ > Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of tokens also showing the label's icon. If (and only if) the prediction is wrong, a small box after the entity (token) contains the correct target class, with a background color corresponding to that class.
90
+ >
91
+ > For short texts, the dataframe view can be sufficient, but for longer texts the HTML view tends to be more useful.
92
+
93
+
94
+ Random Samples
95
+
96
+ > Show random samples. Simple method, but it often turns up interesting things.
97
+
98
+
99
+ Find Duplicates
100
+
101
+ > Find potential duplicates in the data using cosine similarity.
102
+
103
+
104
+ Inspect
105
+
106
+ > Inspect your whole dataset, either unfiltered or by id.
107
+
108
+
109
+ Raw data
110
+
111
+ > See the data as seen by your model.
112
+
113
+
114
+ Debug
115
+
116
+ > Debug info.
html/screenshot.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ https://download.pytorch.org/whl/cpu/torch-1.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl
2
+ streamlit
3
+ pandas
4
+ scikit-learn
5
+ plotly
6
+ sentence-transformers
7
+ transformers
8
+ tokenizers
9
+ datasets
10
+ numpy
11
+ matplotlib
12
+ seqeval
13
+ streamlit-aggrid
14
+ streamlit_option_menu
15
+ pdoc
16
+ git+https://github.com/aseifert/ecco.git@streamlit
src/__init__.py ADDED
File without changes
src/app.py ADDED
@@ -0,0 +1,114 @@
1
+ """The App module is the main entry point for the application.
2
+
3
+ Run `streamlit run app.py` to start the app.
4
+ """
5
+
6
+ import pandas as pd
7
+ import streamlit as st
8
+ from streamlit_option_menu import option_menu
9
+
10
+ from src.load import load_context
11
+ from src.subpages import (
12
+ DebugPage,
13
+ FindDuplicatesPage,
14
+ HomePage,
15
+ LossesPage,
16
+ LossySamplesPage,
17
+ MetricsPage,
18
+ MisclassifiedPage,
19
+ Page,
20
+ ProbingPage,
21
+ RandomSamplesPage,
22
+ RawDataPage,
23
+ )
24
+ from src.subpages.attention import AttentionPage
25
+ from src.subpages.hidden_states import HiddenStatesPage
26
+ from src.subpages.inspect import InspectPage
27
+ from src.utils import classmap
28
+
29
+ sts = st.sidebar
30
+ st.set_page_config(
31
+ layout="wide",
32
+ page_title="Error Analysis",
33
+ page_icon="🏷️",
34
+ )
35
+
36
+
37
+ def _show_menu(pages: list[Page]) -> int:
38
+ with st.sidebar:
39
+ page_names = [p.name for p in pages]
40
+ page_icons = [p.icon for p in pages]
41
+ selected_menu_item = st.session_state.active_page = option_menu(
42
+ menu_title="ExplaiNER",
43
+ options=page_names,
44
+ icons=page_icons,
45
+ menu_icon="layout-wtf",
46
+ default_index=0,
47
+ )
48
+ return page_names.index(selected_menu_item)
49
+ assert False
50
+
51
+
52
+ def _initialize_session_state(pages: list[Page]):
53
+ if "active_page" not in st.session_state:
54
+ for page in pages:
55
+ st.session_state.update(**page._get_widget_defaults())
56
+ st.session_state.update(st.session_state)
57
+
58
+
59
+ def _write_color_legend(context):
60
+ def style(x):
61
+ return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
62
+
63
+ labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
64
+ colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
65
+
66
+ color_legend_df = pd.DataFrame(
67
+ [classmap[l] for l in labels], columns=["label"], index=labels
68
+ ).T
69
+ st.sidebar.write(
70
+ color_legend_df.T.style.apply(style, axis=0).set_properties(
71
+ **{"color": "white", "text-align": "center"}
72
+ )
73
+ )
74
+
75
+
76
+ def main():
77
+ """The main entry point for the application."""
78
+ pages: list[Page] = [
79
+ HomePage(),
80
+ AttentionPage(),
81
+ HiddenStatesPage(),
82
+ ProbingPage(),
83
+ MetricsPage(),
84
+ LossySamplesPage(),
85
+ LossesPage(),
86
+ MisclassifiedPage(),
87
+ RandomSamplesPage(),
88
+ FindDuplicatesPage(),
89
+ InspectPage(),
90
+ RawDataPage(),
91
+ DebugPage(),
92
+ ]
93
+
94
+ _initialize_session_state(pages)
95
+
96
+ selected_page_idx = _show_menu(pages)
97
+ selected_page = pages[selected_page_idx]
98
+
99
+ if isinstance(selected_page, HomePage):
100
+ selected_page.render()
101
+ return
102
+
103
+ if "model_name" not in st.session_state:
104
+ # this can happen if someone loads another page directly (without going through home)
105
+ st.error("Setup not complete. Please click on 'Home / Setup' in the left menu bar.")
106
+ return
107
+
108
+ context = load_context(**st.session_state)
109
+ _write_color_legend(context)
110
+ selected_page.render(context)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
src/data.py ADDED
@@ -0,0 +1,228 @@
1
+ from functools import partial
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ import torch
6
+ from datasets import Dataset, DatasetDict, load_dataset # type: ignore
7
+ from torch.nn.functional import cross_entropy
8
+ from transformers import DataCollatorForTokenClassification # type: ignore
9
+
10
+ from src.utils import device, tokenizer_hash_funcs
11
+
12
+
13
+ @st.cache(allow_output_mutation=True)
14
+ def get_data(
15
+ ds_name: str, config_name: str, split_name: str, split_sample_size: int, randomize_sample: bool
16
+ ) -> Dataset:
17
+ """Loads a Dataset from the HuggingFace hub (if not already loaded).
18
+
19
+ Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
20
+
21
+ Args:
22
+ ds_name (str): Path or name of the dataset.
23
+ config_name (str): Name of the dataset configuration.
24
+ split_name (str): Which split of the data to load.
25
+ split_sample_size (int): The number of examples to load from the split.
+ randomize_sample (bool): Whether to shuffle the split before taking the sample.
26
+
27
+ Returns:
28
+ Dataset: A Dataset object.
29
+ """
30
+ ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True) # type: ignore
+ if randomize_sample:
+ # fixed seed keeps the "random" sample reproducible across reruns
+ ds = ds.shuffle(seed=0)
+ split = ds[split_name].select(range(split_sample_size))
34
+ return split
35
+
36
+
37
+ @st.cache(
38
+ allow_output_mutation=True,
39
+ hash_funcs=tokenizer_hash_funcs,
40
+ )
41
+ def get_collator(tokenizer) -> DataCollatorForTokenClassification:
42
+ """Returns a DataCollator that will dynamically pad the inputs received, as well as the labels.
43
+
44
+ Args:
45
+ tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data.
46
+
47
+ Returns:
48
+ DataCollatorForTokenClassification: The DataCollatorForTokenClassification object.
49
+ """
50
+ return DataCollatorForTokenClassification(tokenizer)
51
+
52
+
53
+ def create_word_ids_from_input_ids(tokenizer, input_ids: list[int]) -> list[int]:
54
+ """Takes a list of input_ids and return corresponding word_ids
55
+
56
+ Args:
57
+ tokenizer: The tokenizer that was used to obtain the input ids.
58
+ input_ids (list[int]): List of token ids.
59
+
60
+ Returns:
61
+ list[int]: Word ids corresponding to the input ids.
62
+ """
63
+ word_ids = []
64
+ wid = -1
65
+ tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
66
+
67
+ for i, tok in enumerate(tokens):
68
+ if tok in tokenizer.all_special_tokens:
69
+ word_ids.append(-1)
70
+ continue
71
+
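+ # Heuristic (assumes "@@"-style BPE continuation markers): a new word starts whenever the previous token is not a continuation piece.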
72
+ if not tokens[i - 1].endswith("@@") and tokens[i - 1] != "<unk>":
73
+ wid += 1
74
+
75
+ word_ids.append(wid)
76
+
77
+ assert len(word_ids) == len(input_ids)
78
+ return word_ids
79
+
80
+
81
+ def tokenize(batch, tokenizer) -> dict:
82
+ """Tokenizes a batch of examples.
83
+
84
+ Args:
85
+ batch: The examples to tokenize
86
+ tokenizer: The tokenizer to use
87
+
88
+ Returns:
89
+ dict: The tokenized batch
90
+ """
91
+ tokenized_inputs = tokenizer(batch["tokens"], truncation=True, is_split_into_words=True)
92
+ labels = []
93
+ wids = []
94
+
95
+ for idx, label in enumerate(batch["ner_tags"]):
96
+ try:
97
+ word_ids = tokenized_inputs.word_ids(batch_index=idx)
98
+ except ValueError:
99
+ word_ids = create_word_ids_from_input_ids(
100
+ tokenizer, tokenized_inputs["input_ids"][idx]
101
+ )
102
+ previous_word_idx = None
103
+ label_ids = []
104
+ for word_idx in word_ids:
105
+ if word_idx == -1 or word_idx is None or word_idx == previous_word_idx:
106
+ label_ids.append(-100)
107
+ else:
108
+ label_ids.append(label[word_idx])
109
+ previous_word_idx = word_idx
110
+ wids.append(word_ids)
111
+ labels.append(label_ids)
112
+ tokenized_inputs["word_ids"] = wids
113
+ tokenized_inputs["labels"] = labels
114
+ return tokenized_inputs
115
+
116
+
117
+ def stringify_ner_tags(batch: dict, tags) -> dict:
118
+ """Stringifies a dataset batch's NER tags."""
119
+ return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
120
+
121
+
122
+ def encode_dataset(split: Dataset, tokenizer):
123
+ """Encodes a dataset split.
124
+
125
+ Args:
126
+ split (Dataset): A Dataset object.
127
+ tokenizer: A PreTrainedTokenizer object.
128
+
129
+ Returns:
130
+ Dataset: A Dataset object with the encoded inputs.
131
+ """
132
+
133
+ tags = split.features["ner_tags"].feature
134
+ split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
135
+ remove_columns = split.column_names
136
+ ids = split["id"]
137
+ split = split.map(
138
+ partial(tokenize, tokenizer=tokenizer),
139
+ batched=True,
140
+ remove_columns=remove_columns,
141
+ )
142
+ word_ids = [[id if id is not None else -1 for id in wids] for wids in split["word_ids"]]
143
+ return split.remove_columns(["word_ids"]), word_ids, ids
144
+
145
+
146
+ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
147
+ """Runs the forward pass for a batch of examples.
148
+
149
+ Args:
150
+ batch: The batch to process
151
+ model: The model to process the batch with
152
+ collator: A data collator
153
+ num_classes (int): Number of classes
154
+
155
+ Returns:
156
+ dict: a dictionary containing `losses`, `preds` and `hidden_states`
157
+ """
158
+
159
+ # Convert dict of lists to list of dicts suitable for data collator
160
+ features = [dict(zip(batch, t)) for t in zip(*batch.values())]
161
+
162
+ # Pad inputs and labels and put all tensors on device
163
+ batch = collator(features)
164
+ input_ids = batch["input_ids"].to(device)
165
+ attention_mask = batch["attention_mask"].to(device)
166
+ labels = batch["labels"].to(device)
167
+
168
+ with torch.no_grad():
169
+ # Pass data through model
170
+ output = model(input_ids, attention_mask, output_hidden_states=True)
171
+ # logit.size: [batch_size, sequence_length, classes]
172
+
173
+ # Predict class with largest logit value on classes axis
174
+ preds = torch.argmax(output.logits, axis=-1).cpu().numpy() # type: ignore
175
+
176
+ # Calculate loss per token after flattening batch dimension with view
177
+ loss = cross_entropy(
178
+ output.logits.view(-1, num_classes), labels.view(-1), reduction="none"
179
+ )
180
+
181
+ # Unflatten batch dimension and convert to numpy array
182
+ loss = loss.view(len(input_ids), -1).cpu().numpy()
183
+ hidden_states = output.hidden_states[-1].cpu().numpy()
184
+
185
+ # logits = output.logits.view(len(input_ids), -1).cpu().numpy()
186
+
187
+ return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
188
+
189
+
190
+ def predict(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
191
+ """Generates predictions for a given dataset split and returns the results as a dataframe.
192
+
193
+ Args:
194
+ split_encoded (Dataset): The dataset to process
195
+ model: The model to process the dataset with
196
+ tokenizer: The tokenizer to process the dataset with
197
+ collator: The data collator to use
198
+ tags: The tags used in the dataset
199
+
200
+ Returns:
201
+ pd.DataFrame: A dataframe containing token-level predictions.
202
+ """
203
+
204
+ split_encoded = split_encoded.map(
205
+ partial(
206
+ forward_pass_with_label,
207
+ model=model,
208
+ collator=collator,
209
+ num_classes=tags.num_classes,
210
+ ),
211
+ batched=True,
212
+ batch_size=8,
213
+ )
214
+ df: pd.DataFrame = split_encoded.to_pandas() # type: ignore
215
+
216
+ df["tokens"] = df["input_ids"].apply(
217
+ lambda x: tokenizer.convert_ids_to_tokens(x) # type: ignore
218
+ )
219
+ df["labels"] = df["labels"].apply(
220
+ lambda x: ["IGN" if i == -100 else tags.int2str(int(i)) for i in x]
221
+ )
222
+ df["preds"] = df["preds"].apply(lambda x: [model.config.id2label[i] for i in x])
223
+ df["preds"] = df.apply(lambda x: x["preds"][: len(x["input_ids"])], axis=1)
224
+ df["losses"] = df.apply(lambda x: x["losses"][: len(x["input_ids"])], axis=1)
225
+ df["hidden_states"] = df.apply(lambda x: x["hidden_states"][: len(x["input_ids"])], axis=1)
226
+ df["total_loss"] = df["losses"].apply(sum)
227
+
228
+ return df
src/load.py ADDED
@@ -0,0 +1,101 @@
1
+ from typing import Optional
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ from datasets import Dataset # type: ignore
6
+
7
+ from src.data import encode_dataset, get_collator, get_data, predict
8
+ from src.model import get_encoder, get_model, get_tokenizer
9
+ from src.subpages import Context
10
+ from src.utils import align_sample, device, explode_df
11
+
12
+ _TOKENIZER_NAME = (
13
+ "xlm-roberta-base",
14
+ "gagan3012/bert-tiny-finetuned-ner",
15
+ "distilbert-base-german-cased",
16
+ )[0]
17
+
18
+
19
+ def _load_models_and_tokenizer(
20
+ encoder_model_name: str,
21
+ model_name: str,
22
+ tokenizer_name: Optional[str],
23
+ device: str = "cpu",
24
+ ):
25
+ sentence_encoder = get_encoder(encoder_model_name, device=device)
26
+ tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
27
+ labels = "O B-COMMA".split() if "comma" in model_name else None
28
+ model = get_model(model_name, labels=labels)
29
+ return sentence_encoder, model, tokenizer
30
+
31
+
32
+ @st.cache(allow_output_mutation=True)
33
+ def load_context(
34
+ encoder_model_name: str,
35
+ model_name: str,
36
+ ds_name: str,
37
+ ds_config_name: str,
38
+ ds_split_name: str,
39
+ split_sample_size: int,
40
+ randomize_sample: bool,
41
+ **kw_args,
42
+ ) -> Context:
43
+ """Utility method loading (almost) everything we need for the application.
44
+ This exists just because we want to cache the results of this function.
45
+
46
+ Args:
47
+ encoder_model_name (str): Name of the sentence encoder to load.
48
+ model_name (str): Name of the NER model to load.
49
+ ds_name (str): Dataset name or path.
50
+ ds_config_name (str): Dataset config name.
51
+ ds_split_name (str): Dataset split name.
52
+ split_sample_size (int): Number of examples to load from the split.
+ randomize_sample (bool): Whether to shuffle the split before sampling.
53
+
54
+ Returns:
55
+ Context: An object containing everything we need for the application.
56
+ """
57
+
58
+ sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
59
+ encoder_model_name=encoder_model_name,
60
+ model_name=model_name,
61
+ tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
62
+ device=str(device),
63
+ )
64
+ collator = get_collator(tokenizer)
65
+
66
+ # load data related stuff
67
+ split: Dataset = get_data(
68
+ ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample
69
+ )
70
+ tags = split.features["ner_tags"].feature
71
+ split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
72
+
73
+ # transform into dataframe
74
+ df = predict(split_encoded, model, tokenizer, collator, tags)
75
+ df["word_ids"] = word_ids
76
+ df["ids"] = ids
77
+
78
+ # explode, clean, merge
79
+ df_tokens = explode_df(df)
80
+ df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
81
+ df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
82
+ df_tokens_merged = explode_df(df_merged)
83
+
84
+ return Context(
85
+ **{
86
+ "model": model,
87
+ "tokenizer": tokenizer,
88
+ "sentence_encoder": sentence_encoder,
89
+ "df": df,
90
+ "df_tokens": df_tokens,
91
+ "df_tokens_cleaned": df_tokens_cleaned,
92
+ "df_tokens_merged": df_tokens_merged,
93
+ "tags": tags,
94
+ "labels": tags.names,
95
+ "split_sample_size": split_sample_size,
96
+ "ds_name": ds_name,
97
+ "ds_config_name": ds_config_name,
98
+ "ds_split_name": ds_split_name,
99
+ "split": split,
100
+ }
101
+ )
src/model.py ADDED
@@ -0,0 +1,33 @@
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from transformers import AutoModelForTokenClassification # type: ignore
4
+ from transformers import AutoTokenizer # type: ignore
5
+
6
+
7
+ @st.experimental_singleton()
8
+ def get_model(model_name: str, labels=None):
9
+ if labels is None:
10
+ return AutoModelForTokenClassification.from_pretrained(
11
+ model_name,
12
+ output_attentions=True,
13
+ ) # type: ignore
14
+ else:
15
+ id2label = {idx: tag for idx, tag in enumerate(labels)}
16
+ label2id = {tag: idx for idx, tag in enumerate(labels)}
17
+ return AutoModelForTokenClassification.from_pretrained(
18
+ model_name,
19
+ output_attentions=True,
20
+ num_labels=len(labels),
21
+ id2label=id2label,
22
+ label2id=label2id,
23
+ ) # type: ignore
24
+
25
+
26
+ @st.experimental_singleton()
27
+ def get_encoder(model_name: str, device: str = "cpu"):
28
+ return SentenceTransformer(model_name, device=device)
29
+
30
+
31
+ @st.experimental_singleton()
32
+ def get_tokenizer(tokenizer_name: str):
33
+ return AutoTokenizer.from_pretrained(tokenizer_name)
src/subpages/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from src.subpages.attention import AttentionPage
2
+ from src.subpages.debug import DebugPage
3
+ from src.subpages.find_duplicates import FindDuplicatesPage
4
+ from src.subpages.hidden_states import HiddenStatesPage
5
+ from src.subpages.home import HomePage
6
+ from src.subpages.inspect import InspectPage
7
+ from src.subpages.losses import LossesPage
8
+ from src.subpages.lossy_samples import LossySamplesPage
9
+ from src.subpages.metrics import MetricsPage
10
+ from src.subpages.misclassified import MisclassifiedPage
11
+ from src.subpages.page import Context, Page
12
+ from src.subpages.probing import ProbingPage
13
+ from src.subpages.random_samples import RandomSamplesPage
14
+ from src.subpages.raw_data import RawDataPage
src/subpages/attention.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
3
+ """
4
+ import ecco
5
+ import streamlit as st
6
+ from streamlit.components.v1 import html
7
+
8
+ from src.subpages.page import Context, Page # type: ignore
9
+
10
+ _SETUP_HTML = """
11
+ <script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
12
+ <script>
13
+ var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
14
+ //var ecco_url = 'http://localhost:8000/'
15
+
16
+ if (window.ecco === undefined) window.ecco = {}
17
+
18
+ // Setup the paths of the script we'll be using
19
+ requirejs.config({
20
+ urlArgs: "bust=" + (new Date()).getTime(),
21
+ nodeRequire: require,
22
+ paths: {
23
+ d3: "https://d3js.org/d3.v6.min", // This is only for use in setup.html and basic.html
24
+ "d3-array": "https://d3js.org/d3-array.v2.min",
25
+ jquery: "https://code.jquery.com/jquery-3.5.1.min",
26
+ ecco: ecco_url + 'js/0.0.6/ecco-bundle.min',
27
+ xregexp: 'https://cdnjs.cloudflare.com/ajax/libs/xregexp/3.2.0/xregexp-all.min'
28
+ }
29
+ });
30
+
31
+ // Add the css file
32
+ //requirejs(['d3'],
33
+ // function (d3) {
34
+ // d3.select('#css').attr('href', ecco_url + 'html/styles.css')
35
+ // })
36
+
37
+ console.log('Ecco initialize!!')
38
+
39
+ // returns a 'basic' object. basic.init() selects the html div we'll be
40
+ // rendering the html into, adds styles.css to the document.
41
+ define('basic', ['d3'],
42
+ function (d3) {
43
+ return {
44
+ init: function (viz_id = null) {
45
+ if (viz_id == null) {
46
+ viz_id = "viz_" + Math.round(Math.random() * 10000000)
47
+ }
48
+ // Select the div rendered below, change its id
49
+ const div = d3.select('#basic').attr('id', viz_id),
50
+ div_parent = d3.select('#' + viz_id).node().parentNode
51
+
52
+ // Link to CSS file
53
+ d3.select(div_parent).insert('link')
54
+ .attr('rel', 'stylesheet')
55
+ .attr('type', 'text/css')
56
+ .attr('href', ecco_url + 'html/0.0.2/styles.css')
57
+
58
+ return viz_id
59
+ }
60
+ }
61
+ }, function (err) {
62
+ console.log(err);
63
+ }
64
+ )
65
+ </script>
66
+
67
+ <head>
68
+ <link id='css' rel="stylesheet" type="text/css">
69
+ </head>
70
+ <div id="basic"></div>
71
+ """
72
+
73
+
74
+ @st.cache(allow_output_mutation=True)
75
+ def _load_ecco_model():
76
+ model_config = {
77
+ "embedding": "embeddings.word_embeddings",
78
+ "type": "mlm",
79
+ "activations": [r"ffn\.lin1"],
80
+ "token_prefix": "",
81
+ "partial_token_prefix": "##",
82
+ }
83
+ return ecco.from_pretrained(
84
+ "elastic/distilbert-base-uncased-finetuned-conll03-english",
85
+ model_config=model_config,
86
+ activations=True,
87
+ )
88
+
89
+
90
+ class AttentionPage(Page):
91
+ name = "Activations"
92
+ icon = "activity"
93
+
94
+ def _get_widget_defaults(self):
95
+ return {
96
+ "act_n_components": 8,
97
+ "act_default_text": """Now I ask you: what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!""",
98
+ "act_from_layer": 0,
99
+ "act_to_layer": 5,
100
+ }
101
+
102
+ def render(self, context: Context):
103
+ st.title(self.name)
104
+
105
+ with st.expander("ℹ️", expanded=True):
106
+ st.write(
107
+ "A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
108
+ )
109
+
110
+ lm = _load_ecco_model()
111
+
112
+ col1, _, col2 = st.columns([1.5, 0.5, 4])
113
+ with col1:
114
+ st.subheader("Settings")
115
+ n_components = st.slider(
116
+ "#components",
117
+ key="act_n_components",
118
+ min_value=2,
119
+ max_value=10,
120
+ step=1,
121
+ )
122
+ from_layer = st.slider(
123
+ "from layer",
124
+ key="act_from_layer",
125
+ value=0,
126
+ min_value=0,
127
+ max_value=len(lm.model.transformer.layer) - 1,
128
+ step=1,
129
+ )
130
+ to_layer = (
131
+ st.slider(
132
+ "to layer",
133
+ key="act_to_layer",
134
+ value=0,
135
+ min_value=0,
136
+ max_value=len(lm.model.transformer.layer) - 1,
137
+ step=1,
138
+ )
139
+ + 1
140
+ )
141
+
142
+ if to_layer <= from_layer:
143
+ st.error("to_layer must be >= from_layer")
144
+ st.stop()
145
+
146
+ with col2:
147
+ st.subheader("–")
148
+ text = st.text_area("Text", key="act_default_text", height=240)
149
+
150
+ inputs = lm.tokenizer([text], return_tensors="pt")
151
+ output = lm(inputs)
152
+ nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
153
+ data = nmf.explore(returnData=True)
154
+ _JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
155
+ const viz_id = basic.init()
156
+ ecco.interactiveTokensAndFactorSparklines(viz_id, {data}, {{ 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}} }} }})
157
+ }}, function (err) {{
158
+ console.log(err);
159
+ }})</script>"""
160
+ html(_SETUP_HTML + _JS_TEMPLATE, height=800, scrolling=True)
src/subpages/debug.py ADDED
@@ -0,0 +1,27 @@
1
+ import streamlit as st
2
+ from pip._internal.operations import freeze
3
+
4
+ from src.subpages.page import Context, Page
5
+
6
+
7
+ class DebugPage(Page):
8
+ name = "Debug"
9
+ icon = "bug"
10
+
11
+ def render(self, context: Context):
12
+ st.title(self.name)
13
+ # with st.expander("💡", expanded=True):
14
+ # st.write("Some debug info.")
15
+
16
+ st.subheader("Installed Packages")
17
+ # get output of pip freeze from system
18
+ with st.expander("pip freeze"):
19
+ st.code("\n".join(freeze.freeze()))
20
+
21
+ st.subheader("Streamlit Session State")
22
+ st.json(st.session_state)
23
+ st.subheader("Tokenizer")
24
+ st.code(context.tokenizer)
25
+ st.subheader("Model")
26
+ st.code(context.model.config)
27
+ st.code(context.model)
src/subpages/emoji-en-US.json ADDED
The diff for this file is too large to render. See raw diff
 
src/subpages/faiss.py ADDED
@@ -0,0 +1,58 @@
1
+ import streamlit as st
2
+ from datasets import Dataset
3
+
4
+ from src.subpages.page import Context, Page # type: ignore
5
+ from src.utils import device, explode_df, htmlify_labeled_example, tag_text
6
+
7
+
8
+ class FaissPage(Page):
9
+ name = "Bla"
10
+ icon = "x-octagon"
11
+
12
+ def render(self, context: Context):
13
+ dd = Dataset.from_pandas(context.df_tokens_merged, preserve_index=False) # type: ignore
14
+
15
+ dd.add_faiss_index(column="hidden_states", index_name="token_index")
16
+ token_id, text = (
17
+ 6,
18
+ "Die Wissenschaft ist eine wichtige Grundlage für die Entwicklung von neuen Technologien.",
19
+ )
20
+ token_id, text = (
21
+ 15,
22
+ "Außer der unbewussten Beeinflussung eines Resultats gibt es auch noch andere Motive die das reine strahlende Licht der Wissenschaft etwas zu trüben vermögen.",
23
+ )
24
+ token_id, text = (
25
+ 3,
26
+ "Mit mehr Instrumenten einer besseren präziseren Datenbasis ist auch ein viel besseres smarteres Risikomanagement möglich.",
27
+ )
28
+ token_id, text = (
29
+ 7,
30
+ "Es gilt die akademische Viertelstunde das heißt Beginn ist fünfzehn Minuten später.",
31
+ )
32
+ token_id, text = (
33
+ 7,
34
+ "Damit einher geht übrigens auch dass Marcella Collocinis Tochter keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat.",
35
+ )
36
+ token_id, text = (
37
+ 16,
38
+ "After Steve Jobs met with Bill Gates of Microsoft back in 1993, they went to Cupertino and made the deal.",
39
+ )
40
+
41
+ tagged = tag_text(text, context.tokenizer, context.model, device)
42
+ hidden_states = tagged["hidden_states"]
43
+ # tagged.drop("hidden_states", inplace=True, axis=1)
44
+ # hidden_states_vec = svd.transform([hidden_states[token_id]])[0].astype(np.float32)
45
+ hidden_states_vec = hidden_states[token_id]
46
+ tagged = tagged.astype(str)
47
+ tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
48
+ tagged["check"] = tagged["probs"].apply(
49
+ lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
50
+ )
51
+ st.dataframe(tagged.drop("hidden_states", axis=1).T)
52
+ results = dd.get_nearest_examples("token_index", hidden_states_vec, k=10)
53
+ for i, (dist, idx, token) in enumerate(
54
+ zip(results.scores, results.examples["ids"], results.examples["tokens"])
55
+ ):
56
+ st.code(f"{dist:.3f} {token}")
57
+ sample = context.df_tokens_merged.query(f"ids == '{idx}'")
58
+ st.write(f"[{i};{idx}] " + htmlify_labeled_example(sample), unsafe_allow_html=True)
src/subpages/find_duplicates.py ADDED
@@ -0,0 +1,52 @@
1
+ """Find potential duplicates in the data using cosine similarity."""
2
+ import streamlit as st
3
+ from sentence_transformers.util import cos_sim
4
+
5
+ from src.subpages.page import Context, Page
6
+
7
+
8
+ @st.cache()
9
+ def get_sims(texts: list[str], sentence_encoder):
10
+ embeddings = sentence_encoder.encode(texts, batch_size=8, convert_to_numpy=True)
11
+ return cos_sim(embeddings, embeddings)
12
+
13
+
14
+ class FindDuplicatesPage(Page):
15
+ name = "Find Duplicates"
16
+ icon = "fingerprint"
17
+
18
+ def _get_widget_defaults(self):
19
+ return {
20
+ "cutoff": 0.95,
21
+ }
22
+
23
+ def render(self, context: Context):
24
+ st.title("Find Duplicates")
25
+ with st.expander("💡", expanded=True):
26
+ st.write("Find potential duplicates in the data using cosine similarity.")
27
+
28
+ cutoff = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, key="cutoff")
29
+ # split.add_faiss_index(column="embeddings", index_name="sent_index")
30
+ # st.write("Index is ready")
31
+ # sentence_encoder.encode(["hello world"], batch_size=8)
32
+ # st.write(split["tokens"][0])
33
+ texts = [" ".join(ts) for ts in context.split["tokens"]]
34
+ sims = get_sims(texts, context.sentence_encoder)
35
+
36
+ candidates = []
37
+ for i in range(len(sims)):
38
+ for j in range(i + 1, len(sims)):
39
+ if sims[i][j] >= cutoff:
40
+ candidates.append((sims[i][j], i, j))
41
+ candidates.sort(reverse=True) # show the most similar pairs first
42
+
43
+ for (sim, i, j) in candidates[:100]:
44
+ st.markdown(f"**Possible duplicate ({i}, {j}, sim: {sim:.3f}):**")
45
+ st.markdown("* " + " ".join(context.split["tokens"][i]))
46
+ st.markdown("* " + " ".join(context.split["tokens"][j]))
47
+
48
+ # st.write("queries")
49
+ # results = split.get_nearest_examples("sent_index", np.array(split["embeddings"][0], dtype=np.float32), k=2)
50
+ # results = split.get_nearest_examples_batch("sent_index", queries, k=2)
51
+ # st.write(results.total_examples[0]["id"][1])
52
+ # st.write(results.total_examples[0])
src/subpages/hidden_states.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
3
+ """
4
+ import numpy as np
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ import streamlit as st
8
+
9
+ from src.subpages.page import Context, Page
10
+
11
+
12
+ @st.cache
13
+ def reduce_dim_svd(X, n_iter: int, random_state=42):
14
+ """Dimensionality reduction using truncated SVD (aka LSA).
15
+
16
+ This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
17
+
18
+ Args:
19
+ X: Training data
20
+ n_iter (int): Number of iterations for the randomized SVD solver.
21
+ random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
22
+
23
+ Returns:
24
+ ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
25
+ """
26
+ from sklearn.decomposition import TruncatedSVD
27
+
28
+ svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
29
+ return svd.fit_transform(X)
30
+
31
+
32
+ @st.cache
33
+ def reduce_dim_pca(X, random_state=42):
34
+ """Principal component analysis (PCA).
35
+
36
+ Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
37
+
38
+ Args:
39
+ X: Training data
40
+ random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
41
+
42
+ Returns:
43
+ ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
44
+ """
45
+ from sklearn.decomposition import PCA
46
+
47
+ return PCA(n_components=2, random_state=random_state).fit_transform(X)
48
+
49
+
50
+ @st.cache
51
+ def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
52
+ """Uniform Manifold Approximation and Projection
53
+
54
+ Finds a low dimensional embedding of the data that approximates an underlying manifold.
55
+
56
+ Args:
57
+ X: Training data
58
+ n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
59
+ min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
60
+ metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
61
+
62
+ Returns:
63
+ ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
64
+ """
65
+ from umap import UMAP
66
+
67
+ return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
68
+
69
+
70
+ class HiddenStatesPage(Page):
71
+ name = "Hidden States"
72
+ icon = "grid-3x3"
73
+
74
+ def _get_widget_defaults(self):
75
+ return {
76
+ "n_tokens": 1_000,
77
+ "svd_n_iter": 5,
78
+ "svd_random_state": 42,
79
+ "umap_n_neighbors": 15,
80
+ "umap_metric": "euclidean",
81
+ "umap_min_dist": 0.1,
82
+ }
83
+
84
+ def render(self, context: Context):
85
+ st.title("Embeddings")
86
+
87
+ with st.expander("💡", expanded=True):
88
+ st.write(
89
+ "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
90
+ )
91
+
92
+ col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
93
+ df = context.df_tokens_merged.copy()
94
+ dim_algo = "SVD"
95
+ n_tokens = 100
96
+
97
+ with col1:
98
+ st.subheader("Settings")
99
+ n_tokens = st.slider(
100
+ "#tokens",
101
+ key="n_tokens",
102
+ min_value=100,
103
+ max_value=len(df["tokens"].unique()),
104
+ step=100,
105
+ )
106
+
107
+ dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
108
+ if dim_algo == "SVD":
109
+ svd_n_iter = st.slider(
110
+ "#iterations",
111
+ key="svd_n_iter",
112
+ min_value=1,
113
+ max_value=10,
114
+ step=1,
115
+ )
116
+ elif dim_algo == "UMAP":
117
+ umap_n_neighbors = st.slider(
118
+ "#neighbors",
119
+ key="umap_n_neighbors",
120
+ min_value=2,
121
+ max_value=100,
122
+ step=1,
123
+ )
124
+ umap_min_dist = st.number_input(
125
+ "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
126
+ )
127
+ umap_metric = st.selectbox(
128
+ "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
129
+ )
130
+ else:
131
+ pass
132
+
133
+ with col2:
134
+ sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
135
+
136
+ X = np.array(df["hidden_states"].tolist())
137
+ transformed_hidden_states = None
138
+ if dim_algo == "SVD":
139
+ transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
140
+ elif dim_algo == "PCA":
141
+ transformed_hidden_states = reduce_dim_pca(X)
142
+ elif dim_algo == "UMAP":
143
+ transformed_hidden_states = reduce_dim_umap(
144
+ X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
145
+ )
146
+
147
+ assert isinstance(transformed_hidden_states, np.ndarray)
148
+ df["x"] = transformed_hidden_states[:, 0]
149
+ df["y"] = transformed_hidden_states[:, 1]
150
+ df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
151
+ df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
152
+ df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
153
+ df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
154
+ df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
155
+ df["disagreements"] = df["labels"] != df["preds"]
156
+
157
+ subset = df[:n_tokens]
158
+ disagreements_trace = go.Scatter(
159
+ x=subset[subset["disagreements"]]["x"],
160
+ y=subset[subset["disagreements"]]["y"],
161
+ mode="markers",
162
+ marker=dict(
163
+ size=6,
164
+ color="rgba(0,0,0,0)",
165
+ line=dict(width=1),
166
+ ),
167
+ hoverinfo="skip",
168
+ )
169
+
170
+ st.subheader("Projection Results")
171
+
172
+ fig = px.scatter(
173
+ subset,
174
+ x="x",
175
+ y="y",
176
+ color="labels",
177
+ hover_data=["ids", "preds", "sent0", "sent1", "sent2", "sent3", "sent4"],
178
+ hover_name="tokens",
179
+ title="Colored by label",
180
+ )
181
+ fig.add_trace(disagreements_trace)
182
+ st.plotly_chart(fig)
183
+
184
+ fig = px.scatter(
185
+ subset,
186
+ x="x",
187
+ y="y",
188
+ color="preds",
189
+ hover_data=["ids", "labels", "sent0", "sent1", "sent2", "sent3", "sent4"],
190
+ hover_name="tokens",
191
+ title="Colored by prediction",
192
+ )
193
+ fig.add_trace(disagreements_trace)
194
+ st.plotly_chart(fig)
src/subpages/home.py ADDED
@@ -0,0 +1,163 @@
1
+ import json
2
+ import random
3
+ from typing import Optional
4
+
5
+ import streamlit as st
6
+
7
+ from src.data import get_data
8
+ from src.subpages.page import Context, Page
9
+ from src.utils import PROJ, classmap, color_map_color
10
+
11
+ _SENTENCE_ENCODER_MODEL = (
12
+ "sentence-transformers/all-MiniLM-L6-v2",
13
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
14
+ )[0]
15
+ _MODEL_NAME = (
16
+ "elastic/distilbert-base-uncased-finetuned-conll03-english",
17
+ "gagan3012/bert-tiny-finetuned-ner",
18
+ "socialmediaie/bertweet-base_wnut17_ner",
19
+ "sberbank-ai/bert-base-NER-reptile-5-datasets",
20
+ "aseifert/comma-xlm-roberta-base",
21
+ "dslim/bert-base-NER",
22
+ "aseifert/distilbert-base-german-cased-comma-derstandard",
23
+ )[0]
24
+ _DATASET_NAME = (
25
+ "conll2003",
26
+ "wnut_17",
27
+ "aseifert/comma",
28
+ )[0]
29
+ _CONFIG_NAME = (
30
+ "conll2003",
31
+ "wnut_17",
32
+ "seifertverlag",
33
+ )[0]
34
+
35
+
36
+ class HomePage(Page):
37
+ name = "Home / Setup"
38
+ icon = "house"
39
+
40
+ def _get_widget_defaults(self):
41
+ return {
42
+ "encoder_model_name": _SENTENCE_ENCODER_MODEL,
43
+ "model_name": _MODEL_NAME,
44
+ "ds_name": _DATASET_NAME,
45
+ "ds_split_name": "validation",
46
+ "ds_config_name": _CONFIG_NAME,
47
+ "split_sample_size": 512,
48
+ "randomize_sample": True,
49
+ }
50
+
51
+ def render(self, context: Optional[Context] = None):
52
+ st.title("ExplaiNER")
53
+
54
+ with st.expander("💡", expanded=True):
55
+ st.write(
56
+ "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
57
+ )
58
+ st.write(
59
+ "**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
60
+ )
61
+ st.write(
62
+ "_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
63
+ )
64
+
65
+ col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])
66
+
67
+ with col1:
68
+ random_form_key = f"settings-{random.randint(0, 100000)}"
69
+ # FIXME: for some reason I'm getting the following error if I don't randomize the key:
70
+ """
71
+ 2022-05-05 20:37:16.507 Traceback (most recent call last):
72
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/scriptrunner/script_runner.py", line 443, in _run_script
73
+ exec(code, module.__dict__)
74
+ File "/Users/zoro/code/error-analysis/main.py", line 162, in <module>
75
+ main()
76
+ File "/Users/zoro/code/error-analysis/main.py", line 102, in main
77
+ show_setup()
78
+ File "/Users/zoro/code/error-analysis/section/setup.py", line 68, in show_setup
79
+ st.form_submit_button("Load Model & Data")
80
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 240, in form_submit_button
81
+ return self._form_submit_button(
82
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 260, in _form_submit_button
83
+ return self.dg._button(
84
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/button.py", line 304, in _button
85
+ check_session_state_rules(default_value=None, key=key, writes_allowed=False)
86
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/utils.py", line 74, in check_session_state_rules
87
+ raise StreamlitAPIException(
88
+ streamlit.errors.StreamlitAPIException: Values for st.button, st.download_button, st.file_uploader, and st.form cannot be set using st.session_state.
89
+ """
90
+ with st.form(key=random_form_key):
91
+ st.subheader("Model & Data Selection")
92
+ st.text_input(
93
+ label="NER Model:",
94
+ key="model_name",
95
+ help="Path or name of the model to use",
96
+ )
97
+ st.text_input(
98
+ label="Encoder Model:",
99
+ key="encoder_model_name",
100
+ help="Path or name of the encoder to use for duplicate detection",
101
+ )
102
+ ds_name = st.text_input(
103
+ label="Dataset:",
104
+ key="ds_name",
105
+ help="Path or name of the dataset to use",
106
+ )
107
+ ds_config_name = st.text_input(
108
+ label="Config (optional):",
109
+ key="ds_config_name",
110
+ )
111
+ ds_split_name = st.selectbox(
112
+ label="Split:",
113
+ options=["train", "validation", "test"],
114
+ key="ds_split_name",
115
+ )
116
+ split_sample_size = st.number_input(
117
+ "Sample size:",
118
+ step=16,
119
+ key="split_sample_size",
120
+ help="Sample size for the split, speeds up processing inside streamlit",
121
+ )
122
+ randomize_sample = st.checkbox(
123
+ "Randomize sample",
124
+ key="randomize_sample",
125
+ help="Whether to randomize the sample",
126
+ )
127
+ # breakpoint()
128
+ # st.form_submit_button("Submit")
129
+ st.form_submit_button("Load Model & Data")
130
+
131
+ split = get_data(
132
+ ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample # type: ignore
133
+ )
134
+ labels = list(
135
+ set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
136
+ )
137
+
138
+ with col2a:
139
+ st.subheader("Classes")
140
+ st.write("**Color**")
141
+ colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
142
+ for label in labels:
143
+ if f"color_{label}" not in st.session_state:
144
+ st.session_state[f"color_{label}"] = colors[label]
145
+ st.color_picker(label, key=f"color_{label}")
146
+ with col2b:
147
+ st.subheader("—")
148
+ st.write("**Icon**")
149
+ emojis = list(json.load(open(PROJ / "subpages/emoji-en-US.json")).keys())
150
+ for label in labels:
151
+ if f"icon_{label}" not in st.session_state:
152
+ st.session_state[f"icon_{label}"] = classmap[label]
153
+ st.selectbox(label, key=f"icon_{label}", options=emojis)
154
+ classmap[label] = st.session_state[f"icon_{label}"]
155
+
156
+ # if st.button("Reset to defaults"):
157
+ # st.session_state.update(**get_home_page_defaults())
158
+ # # time.sleep 2 secs
159
+ # import time
160
+ # time.sleep(1)
161
+
162
+ # # st.legacy_caching.clear_cache()
163
+ # st.experimental_rerun()
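
As a quick reference for the setup above (src/subpages/home.py), here is a minimal standalone sketch — assuming the stock datasets library and the default conll2003 dataset — of how the class list is derived from the split's ner_tags feature names; HomePage.render does the same thing via get_data:

from datasets import load_dataset

# Standalone illustration only; inside the app the split comes from src.data.get_data().
split = load_dataset("conll2003", split="validation")
tag_names = split.features["ner_tags"].feature.names  # e.g. ["O", "B-PER", "I-PER", ...]
labels = sorted({name.split("-")[1] for name in tag_names if name != "O"})
print(labels)  # ['LOC', 'MISC', 'ORG', 'PER']
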
src/subpages/inspect.py ADDED
@@ -0,0 +1,41 @@
1
+ """Inspect your whole dataset, either unfiltered or by id."""
2
+ import streamlit as st
3
+
4
+ from src.subpages.page import Context, Page
5
+ from src.utils import aggrid_interactive_table, colorize_classes
6
+
7
+
8
+ class InspectPage(Page):
9
+ name = "Inspect"
10
+ icon = "search"
11
+
12
+ def render(self, context: Context):
13
+ st.title(self.name)
14
+ with st.expander("💡", expanded=True):
15
+ st.write("Inspect your whole dataset, either unfiltered or by id.")
16
+
17
+ df = context.df_tokens
18
+ cols = (
19
+ "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
20
+ )
21
+ if "token_type_ids" not in df.columns:
22
+ cols.remove("token_type_ids")
23
+ df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
24
+
25
+ if st.checkbox("Filter by id", value=True):
26
+ ids = list(sorted(map(int, df.ids.unique())))
27
+ next_id = st.session_state.get("next_id", 0)
28
+
29
+ example_id = st.selectbox("Select an example", ids, index=next_id)
30
+ df = df[df.ids == str(example_id)][1:-1]
31
+ # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses")) # type: ignore
32
+ st.dataframe(colorize_classes(df.round(3).astype(str)))
33
+
34
+ # if st.button("➡️ Next example"):
35
+ # st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
36
+ # st.experimental_rerun()
37
+ # if st.button("⬅️ Previous example"):
38
+ # st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
39
+ # st.experimental_rerun()
40
+ else:
41
+ aggrid_interactive_table(df.round(3))
src/subpages/losses.py ADDED
@@ -0,0 +1,67 @@
1
+ """Show count, mean and median loss per token and label."""
2
+ import streamlit as st
3
+
4
+ from src.subpages.page import Context, Page
5
+ from src.utils import AgGrid, aggrid_interactive_table
6
+
7
+
8
+ @st.cache
9
+ def get_loss_by_token(df_tokens):
10
+ return (
11
+ df_tokens.groupby("tokens")[["losses"]]
12
+ .agg(["count", "mean", "median", "sum"])
13
+ .droplevel(level=0, axis=1) # Get rid of multi-level columns
14
+ .sort_values(by="sum", ascending=False)
15
+ .reset_index()
16
+ )
17
+
18
+
19
+ @st.cache
20
+ def get_loss_by_label(df_tokens):
21
+ return (
22
+ df_tokens.groupby("labels")[["losses"]]
23
+ .agg(["count", "mean", "median", "sum"])
24
+ .droplevel(level=0, axis=1)
25
+ .sort_values(by="mean", ascending=False)
26
+ .reset_index()
27
+ )
28
+
29
+
30
+ class LossesPage(Page):
31
+ name = "Loss by Token/Label"
32
+ icon = "sort-alpha-down"
33
+
34
+ def render(self, context: Context):
35
+ st.title(self.name)
36
+ with st.expander("💡", expanded=True):
37
+ st.write("Show count, mean and median loss per token and label.")
38
+ st.write(
39
+ "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
40
+ )
41
+
42
+ col1, _, col2 = st.columns([8, 1, 6])
43
+
44
+ with col1:
45
+ st.subheader("💬 Loss by Token")
46
+
47
+ st.session_state["_merge_tokens"] = st.checkbox(
48
+ "Merge tokens", value=True, key="merge_tokens"
49
+ )
50
+ loss_by_token = (
51
+ get_loss_by_token(context.df_tokens_merged)
52
+ if st.session_state["merge_tokens"]
53
+ else get_loss_by_token(context.df_tokens_cleaned)
54
+ )
55
+ aggrid_interactive_table(loss_by_token.round(3))
56
+ # st.subheader("🏷️ Loss by Label")
57
+ # loss_by_label = get_loss_by_label(df_tokens_cleaned)
58
+ # st.dataframe(loss_by_label)
59
+
60
+ st.write(
61
+ "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
62
+ )
63
+
64
+ with col2:
65
+ st.subheader("🏷️ Loss by Label")
66
+ loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
67
+ AgGrid(loss_by_label.round(3), height=200)
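
To make the mean/median hint above concrete, here is a small self-contained sketch with made-up numbers (not from any real run): a token that is mislabeled only a handful of times collects a few very large losses, which pull its mean far above its median — the same aggregation get_loss_by_token performs.

import pandas as pd

# Toy token-level losses: "paris" is usually easy but occasionally mislabeled.
df_tokens = pd.DataFrame({
    "tokens": ["paris"] * 6 + ["the"] * 6,
    "losses": [0.01, 0.02, 0.01, 4.5, 5.2, 0.02] + [0.01] * 6,
})
stats = (
    df_tokens.groupby("tokens")[["losses"]]
    .agg(["count", "mean", "median", "sum"])
    .droplevel(level=0, axis=1)  # flatten the multi-level columns
    .sort_values(by="sum", ascending=False)
    .reset_index()
)
print(stats)  # "paris": mean ≈ 1.63 vs. median 0.02 -> worth inspecting for labeling issues
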
src/subpages/lossy_samples.py ADDED
@@ -0,0 +1,115 @@
1
+ """Show every example sorted by loss (descending) for close inspection."""
2
+ import pandas as pd
3
+ import streamlit as st
4
+
5
+ from src.subpages.page import Context, Page
6
+ from src.utils import (
7
+ colorize_classes,
8
+ get_bg_color,
9
+ get_fg_color,
10
+ htmlify_labeled_example,
11
+ )
12
+
13
+
14
+ class LossySamplesPage(Page):
15
+ name = "Samples by Loss"
16
+ icon = "sort-numeric-down-alt"
17
+
18
+ def _get_widget_defaults(self):
19
+ return {
20
+ "skip_correct": True,
21
+ "samples_by_loss_show_df": True,
22
+ }
23
+
24
+ def render(self, context: Context):
25
+ st.title(self.name)
26
+ with st.expander("💡", expanded=True):
27
+ st.write("Show every example sorted by loss (descending) for close inspection.")
28
+ st.write(
29
+ "The **dataframe** is mostly self-explanatory. The cells are color-coded by label; a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right, relative to the top loss."
30
+ )
31
+ st.write(
32
+ "The **numbers to the left**: the top line (black background) shows the sample number (as listed here) and the sample index (from the dataset). Below it, on a yellow background, is the total loss for the given sample."
33
+ )
34
+ st.write(
35
+ "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of tokens also showing the label's icon. If (and only if) the prediction is wrong, a small box after the entity (token) contains the correct target class, with a background color corresponding to that class."
36
+ )
37
+
38
+ st.subheader("💥 Samples ⬇loss")
39
+ skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
40
+ show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
41
+
42
+ st.write(
43
+ """<style>
44
+ thead {
45
+ display: none;
46
+ }
47
+ td {
48
+ white-space: nowrap;
49
+ padding: 0 5px !important;
50
+ }
51
+ </style>""",
52
+ unsafe_allow_html=True,
53
+ )
54
+
55
+ top_indices = (
56
+ context.df.sort_values(by="total_loss", ascending=False)
57
+ .query("total_loss > 0.5")
58
+ .index
59
+ )
60
+
61
+ cnt = 0
62
+ for idx in top_indices:
63
+ sample = context.df_tokens_merged.loc[idx]
64
+
65
+ if isinstance(sample, pd.Series):
66
+ continue
67
+
68
+ if skip_correct and sum(sample.labels != sample.preds) == 0:
69
+ continue
70
+
71
+ if show_df:
72
+
73
+ def colorize_col(col):
74
+ if col.name == "labels" or col.name == "preds":
75
+ bgs = []
76
+ fgs = []
77
+ ops = []
78
+ for v in col.values:
79
+ bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
80
+ fgs.append(get_fg_color(bgs[-1]))
81
+ ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
82
+ return [
83
+ f"background-color: {bg}; color: {fg}; opacity: {op};"
84
+ for bg, fg, op in zip(bgs, fgs, ops)
85
+ ]
86
+ return [""] * len(col)
87
+
88
+ df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
89
+ losses_slice = pd.IndexSlice["losses", :]
90
+ # x = df.T.astype(str)
91
+ # st.dataframe(x)
92
+ # st.dataframe(x.loc[losses_slice])
93
+ styler = (
94
+ df.T.style.apply(colorize_col, axis=1)
95
+ .bar(subset=losses_slice, axis=1)
96
+ .format(precision=3)
97
+ )
98
+ # styler.data = styler.data.astype(str)
99
+ st.write(styler.to_html(), unsafe_allow_html=True)
100
+ st.write("")
101
+ # st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses')) # type: ignore
102
+ # st.write(
103
+ # colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
104
+ # )
105
+
106
+ col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
107
+
108
+ cnt += 1
109
+ counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
110
+ loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
111
+ col1.write(f"{counter}{loss}", unsafe_allow_html=True)
112
+ col1.write("")
113
+
114
+ col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
115
+ # st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)
src/subpages/metrics.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
3
+ """
4
+ import re
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import plotly.express as px
10
+ import streamlit as st
11
+ from seqeval.metrics import classification_report
12
+ from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
13
+
14
+ from src.subpages.page import Context, Page
15
+
16
+
17
+ def _get_evaluation(df):
18
+ y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
19
+ y_pred = df.apply(
20
+ lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
21
+ axis=1,
22
+ )
23
+ report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3) # type: ignore
24
+ return report.replace(
25
+ "precision recall f1-score support",
26
+ "=" * 12 + " precision recall f1-score support",
27
+ )
28
+
29
+
30
+ def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
31
+ cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
32
+ if zero_diagonal:
33
+ np.fill_diagonal(cm, 0)
34
+
35
+ # st.write(plt.rcParams["font.size"])
36
+ # plt.rcParams.update({'font.size': 10.0})
37
+ fig, ax = plt.subplots(figsize=(10, 10))
38
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
39
+ fmt = "d" if normalize is None else ".3f"
40
+ disp.plot(
41
+ cmap="Blues",
42
+ include_values=True,
43
+ xticks_rotation="vertical",
44
+ values_format=fmt,
45
+ ax=ax,
46
+ colorbar=False,
47
+ )
48
+ return fig
49
+
50
+
51
+ class MetricsPage(Page):
52
+ name = "Metrics"
53
+ icon = "graph-up-arrow"
54
+
55
+ def _get_widget_defaults(self):
56
+ return {
57
+ "normalize": True,
58
+ "zero_diagonal": False,
59
+ }
60
+
61
+ def render(self, context: Context):
62
+ st.title(self.name)
63
+ with st.expander("💡", expanded=True):
64
+ st.write(
65
+ "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
66
+ )
67
+ st.write(
68
+ "In the support/F1 scatter plot, you don't want any of the classes to end up in the bottom-right quadrant: those are frequent but error-prone."
69
+ )
70
+
71
+ eval_results = _get_evaluation(context.df)
72
+ if len(eval_results.splitlines()) < 8:
73
+ col1, _, col2 = st.columns([8, 1, 1])
74
+ else:
75
+ col1 = col2 = st
76
+
77
+ col1.subheader("🎯 Evaluation Results")
78
+ col1.code(eval_results)
79
+
80
+ results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
81
+ data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
82
+ df = pd.DataFrame(data, columns="class support f1".split())
83
+ fig = px.scatter(
84
+ df,
85
+ x="support",
86
+ y="f1",
87
+ range_y=(0, 1.05),
88
+ color="class",
89
+ )
90
+ # fig.update_layout(title_text="asdf", title_yanchor="bottom")
91
+ col1.plotly_chart(fig)
92
+
93
+ col2.subheader("🔠 Confusion Matrix")
94
+ normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
95
+ zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
96
+ col2.pyplot(
97
+ plot_confusion_matrix(
98
+ y_true=context.df_tokens_cleaned["labels"],
99
+ y_preds=context.df_tokens_cleaned["preds"],
100
+ labels=context.labels,
101
+ normalize=normalize,
102
+ zero_diagonal=zero_diagonal,
103
+ ),
104
+ )
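
A tiny self-contained sketch of the zero-diagonal option described above, using toy labels (not taken from the app): zeroing the diagonal removes the correct predictions, so only the raw confusion counts between classes remain.

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = ["B-PER", "B-PER", "B-LOC", "B-ORG", "O", "O"]
y_pred = ["B-PER", "B-ORG", "B-LOC", "B-PER", "O", "B-LOC"]
labels = ["O", "B-PER", "B-LOC", "B-ORG"]

cm = confusion_matrix(y_true, y_pred, labels=labels)  # raw counts (normalize=None)
np.fill_diagonal(cm, 0)  # drop correct predictions, keep only the errors
print(cm)  # row = true label, column = predicted label
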
src/subpages/misclassified.py ADDED
@@ -0,0 +1,82 @@
1
+ """This page contains all misclassified examples and allows filtering by specific error types."""
2
+ from collections import defaultdict
3
+
4
+ import pandas as pd
5
+ import streamlit as st
6
+ from sklearn.metrics import confusion_matrix
7
+
8
+ from src.subpages.page import Context, Page
9
+ from src.utils import htmlify_labeled_example
10
+
11
+
12
+ class MisclassifiedPage(Page):
13
+ name = "Misclassified"
14
+ icon = "x-octagon"
15
+
16
+ def render(self, context: Context):
17
+ st.title(self.name)
18
+ with st.expander("💡", expanded=True):
19
+ st.write(
20
+ "This page contains all misclassified examples and allows filtering by specific error types."
21
+ )
22
+
23
+ misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
24
+ misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
25
+ cm = confusion_matrix(
26
+ misclassified_samples.labels,
27
+ misclassified_samples.preds,
28
+ labels=context.labels,
29
+ )
30
+
31
+ # st.pyplot(
32
+ # plot_confusion_matrix(
33
+ # y_preds=misclassified_samples["preds"],
34
+ # y_true=misclassified_samples["labels"],
35
+ # labels=labels,
36
+ # normalize=None,
37
+ # zero_diagonal=True,
38
+ # ),
39
+ # )
40
+ df = pd.DataFrame(cm, index=context.labels, columns=context.labels).astype(str)
41
+ import numpy as np
42
+
43
+ np.fill_diagonal(df.values, "")
44
+ st.dataframe(df.applymap(lambda x: x if x != "0" else ""))
45
+ # import matplotlib.pyplot as plt
46
+ # st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html())
47
+ # selection = aggrid_interactive_table(df)
48
+
49
+ # st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
50
+
51
+ confusions = defaultdict(int)
52
+ for i, row in enumerate(cm):
53
+ for j, _ in enumerate(row):
54
+ if i == j or cm[i][j] == 0:
55
+ continue
56
+ confusions[(context.labels[i], context.labels[j])] += cm[i][j]
57
+
58
+ def format_func(item):
59
+ return (
60
+ f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
61
+ )
62
+
63
+ conf = st.radio(
64
+ "Filter by Class Confusion",
65
+ options=list(zip(confusions.keys(), confusions.values())),
66
+ format_func=format_func,
67
+ )
68
+
69
+ # st.write(
70
+ # f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`"
71
+ # )
72
+
73
+ filtered_indices = misclassified_samples.query(
74
+ f"labels == '{conf[0][0]}' and preds == '{conf[0][1]}'"
75
+ ).index
76
+ for i, idx in enumerate(filtered_indices):
77
+ sample = context.df_tokens_merged.loc[idx]
78
+ st.write(
79
+ htmlify_labeled_example(sample),
80
+ unsafe_allow_html=True,
81
+ )
82
+ st.write("---")
src/subpages/page.py ADDED
@@ -0,0 +1,50 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ import pandas as pd
5
+ from datasets import Dataset # type: ignore
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoModelForSequenceClassification # type: ignore
8
+ from transformers import AutoTokenizer # type: ignore
9
+
10
+
11
+ @dataclass
12
+ class Context:
13
+ """This object facilitates passing around the application's state between different pages."""
14
+
15
+ model: AutoModelForSequenceClassification
16
+ tokenizer: AutoTokenizer
17
+ sentence_encoder: SentenceTransformer
18
+ tags: Any
19
+ df: pd.DataFrame
20
+ df_tokens: pd.DataFrame
21
+ df_tokens_cleaned: pd.DataFrame
22
+ df_tokens_merged: pd.DataFrame
23
+ split_sample_size: int
24
+ ds_name: str
25
+ ds_config_name: str
26
+ ds_split_name: str
27
+ split: Dataset
28
+ labels: list[str]
29
+
30
+
31
+ class Page:
32
+ """This class encapsulates the logic for a single page of the application."""
33
+
34
+ name: str
35
+ """The page's name that will be used in the sidebar menu."""
36
+
37
+ icon: str
38
+ """The page's icon that will be used in the sidebar menu."""
39
+
40
+ def _get_widget_defaults(self):
41
+ """This function holds the default settings for all widgets contained on this page.
42
+
43
+ Returns:
44
+ dict: A dictionary of widget defaults, where the keys are the widget names and the values are the defaults.
45
+ """
46
+ return {}
47
+
48
+ def render(self, context):
49
+ """This function renders the page."""
50
+ ...
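
A hypothetical example of how a new subpage would plug into this interface (the page below is made up for illustration and is not part of the commit): subclass Page, set name and icon, optionally provide widget defaults, and read whatever is needed from the shared Context.

import streamlit as st

from src.subpages.page import Context, Page


class LabelCountsPage(Page):  # hypothetical page, not in the repo
    name = "Label Counts"
    icon = "bar-chart"

    def _get_widget_defaults(self):
        return {"label_counts_normalize": False}

    def render(self, context: Context):
        st.title(self.name)
        normalize = st.checkbox("Show relative frequencies", key="label_counts_normalize")
        st.dataframe(context.df_tokens_cleaned["labels"].value_counts(normalize=normalize))
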
src/subpages/probing.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
3
+ """
4
+ import streamlit as st
5
+
6
+ from src.subpages.page import Context, Page
7
+ from src.utils import device, tag_text
8
+
9
+ _DEFAULT_SENTENCES = """
10
+ Damit hatte er auf ihr letztes , völlig schiefgelaufenes Geschäftsessen angespielt .
11
+ Damit einher geht übrigens auch , dass Marcella , Collocinis Tochter , keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat .
12
+ Nach dem Bell ’ schen Theorem , einer Physik jenseits der Quanten , ist die Welt , die wir für real halten , nicht objektivierbar .
13
+ Dazu muss man wiederum wissen , dass die Aussagekraft von Tests , neben der Sensitivität und Spezifität , ganz entscheidend von der Vortestwahrscheinlichkeit abhängt .
14
+ Haben Sie sich schon eingelebt ? « erkundigte er sich .
15
+ Das Auto ein Totalschaden , mein Beifahrer ein weinender Jammerlappen .
16
+ Seltsam , wunderte sie sich , dass das Stück nach mehr als eineinhalb Jahrhunderten noch so gut in Schuss ist .
17
+ Oder auf den Strich gehen , Strümpfe stricken , Geld hamstern .
18
+ Und Allah ist Allumfassend Allwissend .
19
+ Und Pedro Moacir redete weiter : » Verzicht , Pater Antonio , Verzicht , zu großer Schmerz über Verzicht , Sehnsucht , die sich nicht erfüllt , die sich nicht erfüllen kann , das sind Qualen , die ein Verstummen nach sich ziehen können , oder Härte .
20
+ Mama-San ging mittlerweile fast ausnahmslos nur mit Wei an ihrer Seite aus dem Haus , kaum je mit einem der Mädchen und niemals allein.
21
+ """.strip()
22
+ _DEFAULT_SENTENCES = """
23
+ Elon Musk’s Berghain humiliation — I know the feeling
24
+ Musk was also seen at a local spot called Sisyphos celebrating entrepreneur Adeo Ressi's birthday, according to The Times.
25
+ """.strip()
26
+
27
+
28
+ class ProbingPage(Page):
29
+ name = "Probing"
30
+ icon = "fonts"
31
+
32
+ def _get_widget_defaults(self):
33
+ return {"probing_textarea": _DEFAULT_SENTENCES}
34
+
35
+ def render(self, context: Context):
36
+ st.title("🔠 Interactive Probing")
37
+
38
+ with st.expander("💡", expanded=True):
39
+ st.write(
40
+ "A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors."
41
+ )
42
+
43
+ sentences = st.text_area("Sentences", height=200, key="probing_textarea")
44
+ if not sentences.strip():
45
+ return
46
+ sentences = [sentence.strip() for sentence in sentences.splitlines()]
47
+
48
+ for sent in sentences:
49
+ sent = sent.replace(",", "").replace(" ", " ")
50
+ with st.expander(sent):
51
+ tagged = tag_text(sent, context.tokenizer, context.model, device)
52
+ tagged = tagged.astype(str)
53
+ tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
54
+ tagged["check"] = tagged["probs"].apply(
55
+ lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
56
+ )
57
+ st.dataframe(tagged.drop("hidden_states", axis=1).T)
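
A rough non-Streamlit sketch of the same check (the model name is just an example; the app loads its model elsewhere): tag a sentence with tag_text from src/utils.py and look at the "probs" column, where small values are the low-confidence predictions the page flags with check marks.

from transformers import AutoModelForTokenClassification, AutoTokenizer

from src.utils import device, tag_text

model_name = "elastic/distilbert-base-uncased-finetuned-conll03-english"  # example NER model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)

tagged = tag_text("Elon Musk was seen in Berlin.", tokenizer, model, device)
# tag_text stores 1 // min(softmax) per token in "probs"; the page treats small values
# (e.g. < 100) as low-confidence predictions that should be inspected first.
print(tagged[tagged["probs"] < 100][["tokens", "preds", "probs"]])
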
src/subpages/random_samples.py ADDED
@@ -0,0 +1,50 @@
1
+ """Show random samples. Simple method, but it often turns up interesting things."""
2
+ import pandas as pd
3
+ import streamlit as st
4
+
5
+ from src.subpages.page import Context, Page
6
+ from src.utils import htmlify_labeled_example
7
+
8
+
9
+ class RandomSamplesPage(Page):
10
+ name = "Random Samples"
11
+ icon = "shuffle"
12
+
13
+ def _get_widget_defaults(self):
14
+ return {
15
+ "random_sample_size_min": 128,
16
+ }
17
+
18
+ def render(self, context: Context):
19
+ st.title("🎲 Random Samples")
20
+ with st.expander("💡", expanded=True):
21
+ st.write(
22
+ "Show random samples. Simple method, but it often turns up interesting things."
23
+ )
24
+
25
+ random_sample_size = st.number_input(
26
+ "Random sample size:",
27
+ value=min(st.session_state.random_sample_size_min, context.split_sample_size),
28
+ step=16,
29
+ key="random_sample_size",
30
+ )
31
+
32
+ if st.button("🎲 Resample"):
33
+ st.experimental_rerun()
34
+
35
+ random_indices = context.df.sample(int(random_sample_size)).index
36
+ samples = context.df_tokens_merged.loc[random_indices]
37
+
38
+ for i, idx in enumerate(random_indices):
39
+ sample = samples.loc[idx]
40
+
41
+ if isinstance(sample, pd.Series):
42
+ continue
43
+
44
+ col1, _, col2 = st.columns([0.08, 0.025, 0.8])
45
+
46
+ counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{i+1} | {idx}]</span>"
47
+ loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
48
+ col1.write(f"{counter}{loss}", unsafe_allow_html=True)
49
+ col1.write("")
50
+ col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
src/subpages/raw_data.py ADDED
@@ -0,0 +1,57 @@
1
+ """See the data as seen by your model."""
2
+ import pandas as pd
3
+ import streamlit as st
4
+
5
+ from src.subpages.page import Context, Page
6
+ from src.utils import aggrid_interactive_table
7
+
8
+
9
+ @st.cache
10
+ def convert_df(df):
11
+ return df.to_csv().encode("utf-8")
12
+
13
+
14
+ class RawDataPage(Page):
15
+ name = "Raw data"
16
+ icon = "qr-code"
17
+
18
+ def render(self, context: Context):
19
+ st.title(self.name)
20
+ with st.expander("💡", expanded=True):
21
+ st.write("See the data as seen by your model.")
22
+
23
+ st.subheader("Dataset")
24
+ st.code(
25
+ f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
26
+ )
27
+
28
+ st.write("**Data after processing and inference**")
29
+
30
+ processed_df = (
31
+ context.df_tokens.drop("hidden_states", axis=1).drop("attention_mask", axis=1).round(3)
32
+ )
33
+ cols = (
34
+ "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
35
+ )
36
+ if "token_type_ids" not in processed_df.columns:
37
+ cols.remove("token_type_ids")
38
+ processed_df = processed_df[cols]
39
+ aggrid_interactive_table(processed_df)
40
+ processed_df_csv = convert_df(processed_df)
41
+ st.download_button(
42
+ "Download csv",
43
+ processed_df_csv,
44
+ "processed_data.csv",
45
+ "text/csv",
46
+ )
47
+
48
+ st.write("**Raw data (exploded by tokens)**")
49
+ raw_data_df = context.split.to_pandas().apply(pd.Series.explode) # type: ignore
50
+ aggrid_interactive_table(raw_data_df)
51
+ raw_data_df_csv = convert_df(raw_data_df)
52
+ st.download_button(
53
+ "Download csv",
54
+ raw_data_df_csv,
55
+ "raw_data.csv",
56
+ "text/csv",
57
+ )
src/utils.py ADDED
@@ -0,0 +1,255 @@
1
+ from pathlib import Path
2
+
3
+ import matplotlib as matplotlib
4
+ import matplotlib.cm as cm
5
+ import pandas as pd
6
+ import streamlit as st
7
+ import tokenizers
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
11
+
12
+ PROJ = Path(__file__).parent
13
+
14
+ tokenizer_hash_funcs = {
15
+ tokenizers.Tokenizer: lambda _: None,
16
+ tokenizers.AddedToken: lambda _: None,
17
+ }
18
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ classmap = {
22
+ "O": "O",
23
+ "PER": "🙎",
24
+ "person": "🙎",
25
+ "LOC": "🌎",
26
+ "location": "🌎",
27
+ "ORG": "🏤",
28
+ "corporation": "🏤",
29
+ "product": "📱",
30
+ "creative": "🎷",
31
+ "MISC": "🎷",
32
+ }
33
+
34
+
35
+ def aggrid_interactive_table(df: pd.DataFrame) -> dict:
36
+ """Creates an st-aggrid interactive table based on a dataframe.
37
+
38
+ Args:
39
+ df (pd.DataFrame): Source dataframe
40
+ Returns:
41
+ dict: The selected row
42
+ """
43
+ options = GridOptionsBuilder.from_dataframe(
44
+ df, enableRowGroup=True, enableValue=True, enablePivot=True
45
+ )
46
+
47
+ options.configure_side_bar()
48
+ # options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))
49
+
50
+ options.configure_selection("single")
51
+ selection = AgGrid(
52
+ df,
53
+ enable_enterprise_modules=True,
54
+ gridOptions=options.build(),
55
+ theme="light",
56
+ update_mode=GridUpdateMode.NO_UPDATE,
57
+ allow_unsafe_jscode=True,
58
+ )
59
+
60
+ return selection
61
+
62
+
63
+ def explode_df(df: pd.DataFrame) -> pd.DataFrame:
64
+ """Takes a dataframe and explodes all the fields."""
65
+
66
+ df_tokens = df.apply(pd.Series.explode)
67
+ if "losses" in df.columns:
68
+ df_tokens["losses"] = df_tokens["losses"].astype(float)
69
+ return df_tokens # type: ignore
70
+
71
+
72
+ def align_sample(row: pd.Series):
73
+ """Uses word_ids to align all lists in a sample."""
74
+
75
+ columns = row.axes[0].to_list()
76
+ indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
77
+
78
+ out = {}
79
+
80
+ tokens = []
81
+ for i, tok in enumerate(row.tokens):
82
+ if row.word_ids[i] == -1:
83
+ continue
84
+
85
+ if row.word_ids[i] != row.word_ids[i - 1]:
86
+ tokens.append(tok.lstrip("▁").lstrip("##").rstrip("@@"))
87
+ else:
88
+ tokens[-1] += tok.lstrip("▁").lstrip("##").rstrip("@@")
89
+ out["tokens"] = tokens
90
+
91
+ if "preds" in columns:
92
+ out["preds"] = [row.preds[i] for i in indices]
93
+
94
+ if "labels" in columns:
95
+ out["labels"] = [row.labels[i] for i in indices]
96
+
97
+ if "losses" in columns:
98
+ out["losses"] = [row.losses[i] for i in indices]
99
+
100
+ if "probs" in columns:
101
+ out["probs"] = [row.probs[i] for i in indices]
102
+
103
+ if "hidden_states" in columns:
104
+ out["hidden_states"] = [row.hidden_states[i] for i in indices]
105
+
106
+ if "ids" in columns:
107
+ out["ids"] = row.ids
108
+
109
+ assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
110
+
111
+ return out
112
+
113
+
114
+ @st.cache(
115
+ allow_output_mutation=True,
116
+ hash_funcs=tokenizer_hash_funcs,
117
+ )
118
+ def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
119
+ """Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.
120
+
121
+ Args:
122
+ text (str): The text to be processed
123
+ tokenizer: Tokenizer to use
124
+ model: Model to use
125
+ device (torch.device): The device we want pytorch to use for its calculations.
126
+
127
+ Returns:
128
+ pd.DataFrame: A data frame holding the tagged text.
129
+ """
130
+
131
+ tokens = tokenizer(text).tokens()
132
+ tokenized = tokenizer(text, return_tensors="pt")
133
+ word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
134
+ input_ids = tokenized.input_ids.to(device)
135
+ outputs = model(input_ids, output_hidden_states=True)
136
+ preds = torch.argmax(outputs.logits, dim=2)
137
+ preds = [model.config.id2label[p] for p in preds[0].cpu().numpy()]
138
+ hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
139
+ # hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)
140
+
141
+ probs = 1 // (
142
+ torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
143
+ )
144
+
145
+ df = pd.DataFrame(
146
+ [[tokens, word_ids, preds, probs, hidden_states]],
147
+ columns="tokens word_ids preds probs hidden_states".split(),
148
+ )
149
+ merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
150
+ return explode_df(merged_df).reset_index().drop(columns=["index"])
151
+
152
+
153
+ def get_bg_color(label: str):
154
+ """Retrieves a label's color from the session state."""
155
+ return st.session_state[f"color_{label}"]
156
+
157
+
158
+ def get_fg_color(bg_color_hex: str) -> str:
159
+ """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.
160
+
161
+ Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/
162
+
163
+ Args:
164
+ bg_color_hex (str): The background color given as a HEX string.
165
+
166
+ Returns:
167
+ str: Either "black" or "white".
168
+ """
169
+ r = int(bg_color_hex[1:3], 16)
170
+ g = int(bg_color_hex[3:5], 16)
171
+ b = int(bg_color_hex[5:7], 16)
172
+ yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
173
+ return "black" if (yiq >= 128) else "white"
174
+
175
+
176
+ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
177
+ """Colorizes the errors in the dataframe."""
178
+
179
+ def colorize_row(row):
180
+ return [
181
+ "background-color: "
182
+ + ("white" if (row["labels"] == "IGN" or (row["preds"] == row["labels"])) else "pink")
183
+ + ";"
184
+ ] * len(row)
185
+
186
+ def colorize_col(col):
187
+ if col.name == "labels" or col.name == "preds":
188
+ bgs = []
189
+ fgs = []
190
+ for v in col.values:
191
+ bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
192
+ fgs.append(get_fg_color(bgs[-1]))
193
+ return [f"background-color: {bg}; color: {fg};" for bg, fg in zip(bgs, fgs)]
194
+ return [""] * len(col)
195
+
196
+ df = df.reset_index().drop(columns=["index"]).T
197
+ return df # .style.apply(colorize_col, axis=0)
198
+
199
+
200
+ def htmlify_labeled_example(example: pd.DataFrame) -> str:
201
+ """Builds an HTML (string) representation of a single example.
202
+
203
+ Args:
204
+ example (pd.DataFrame): The example to process.
205
+
206
+ Returns:
207
+ str: An HTML string representation of a single example.
208
+ """
209
+ html = []
210
+
211
+ for _, row in example.iterrows():
212
+ pred = row.preds.split("-")[1] if "-" in row.preds else "O"
213
+ label = row.labels
214
+ label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
215
+
216
+ color = get_bg_color(row.preds.split("-")[1]) if "-" in row.preds else "#000000"
217
+ true_color = get_bg_color(row.labels.split("-")[1]) if "-" in row.labels else "#000000"
218
+
219
+ font_color = get_fg_color(color) if color else "white"
220
+ true_font_color = get_fg_color(true_color) if true_color else "white"
221
+
222
+ is_correct = row.preds == row.labels
223
+ loss_html = (
224
+ ""
225
+ if float(row.losses) < 0.01
226
+ else f"<span style='background-color: yellow; color: {font_color}; padding: 0 5px;'>{row.losses:.3f}</span>"
227
+ )
228
+ loss_html = ""
229
+
230
+ if row.labels == row.preds == "O":
231
+ html.append(f"<span>{row.tokens}</span>")
232
+ elif row.labels == "IGN":
233
+ assert False
234
+ else:
235
+ opacity = "1" if not is_correct else "0.5"
236
+ correct = (
237
+ ""
238
+ if is_correct
239
+ else f"<span title='{label}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{classmap[label_class]}</span>"
240
+ )
241
+ pred_icon = classmap[pred] if pred != "O" and row.preds[:2] != "I-" else ""
242
+ html.append(
243
+ f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title={row.preds}>{pred_icon + ' '}{row.tokens}</span>{correct}{loss_html}"
244
+ )
245
+
246
+ return " ".join(html)
247
+
248
+
249
+ def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
250
+ """Turns a value into a color using a color map."""
251
+ norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
252
+ cmap = cm.get_cmap(cmap_name) # PiYG
253
+ rgba = cmap(norm(abs(value)))
254
+ color = matplotlib.colors.rgb2hex(rgba[:3])
255
+ return color
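
Finally, a short usage sketch for the two color helpers above (the printed values are examples): color_map_color turns a fraction into a hex color from the Set1 colormap, and get_fg_color picks black or white text for it based on the YIQ brightness formula.

from src.utils import color_map_color, get_fg_color

bg = color_map_color(0.25)  # e.g. "#4daf4a" (a green from Set1)
fg = get_fg_color(bg)       # "black", since this background is bright enough (YIQ >= 128)
print(bg, fg)
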