Spaces:
Runtime error
Runtime error
Alexander Seifert
commited on
Commit
Β·
597bf7d
1
Parent(s):
9b03751
initial commit
Browse files- .gitignore +165 -0
- README.md +46 -12
- data.py +147 -0
- load.py +84 -0
- main.py +121 -0
- model.py +33 -0
- requirements.txt +14 -0
- subpages/__init__.py +14 -0
- subpages/attention.py +172 -0
- subpages/debug.py +27 -0
- subpages/embeddings.py +155 -0
- subpages/emoji-en-US.json +0 -0
- subpages/faiss.py +58 -0
- subpages/find_duplicates.py +51 -0
- subpages/home.py +149 -0
- subpages/inspect.py +38 -0
- subpages/losses.py +63 -0
- subpages/lossy_samples.py +99 -0
- subpages/metrics.py +98 -0
- subpages/misclassified.py +81 -0
- subpages/page.py +37 -0
- subpages/probing.py +54 -0
- subpages/random_samples.py +51 -0
- subpages/raw_data.py +56 -0
- utils.py +229 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Created by https://www.gitignore.io/api/python,osx,linux
|
3 |
+
# Edit at https://www.gitignore.io/?templates=python,osx,linux
|
4 |
+
|
5 |
+
### Linux ###
|
6 |
+
*~
|
7 |
+
|
8 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
9 |
+
.fuse_hidden*
|
10 |
+
|
11 |
+
# KDE directory preferences
|
12 |
+
.directory
|
13 |
+
|
14 |
+
# Linux trash folder which might appear on any partition or disk
|
15 |
+
.Trash-*
|
16 |
+
|
17 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
18 |
+
.nfs*
|
19 |
+
|
20 |
+
### OSX ###
|
21 |
+
# General
|
22 |
+
.DS_Store
|
23 |
+
.AppleDouble
|
24 |
+
.LSOverride
|
25 |
+
|
26 |
+
# Icon must end with two \r
|
27 |
+
Icon
|
28 |
+
|
29 |
+
# Thumbnails
|
30 |
+
._*
|
31 |
+
|
32 |
+
# Files that might appear in the root of a volume
|
33 |
+
.DocumentRevisions-V100
|
34 |
+
.fseventsd
|
35 |
+
.Spotlight-V100
|
36 |
+
.TemporaryItems
|
37 |
+
.Trashes
|
38 |
+
.VolumeIcon.icns
|
39 |
+
.com.apple.timemachine.donotpresent
|
40 |
+
|
41 |
+
# Directories potentially created on remote AFP share
|
42 |
+
.AppleDB
|
43 |
+
.AppleDesktop
|
44 |
+
Network Trash Folder
|
45 |
+
Temporary Items
|
46 |
+
.apdisk
|
47 |
+
|
48 |
+
### Python ###
|
49 |
+
# Byte-compiled / optimized / DLL files
|
50 |
+
__pycache__/
|
51 |
+
*.py[cod]
|
52 |
+
*$py.class
|
53 |
+
|
54 |
+
# C extensions
|
55 |
+
*.so
|
56 |
+
|
57 |
+
# Distribution / packaging
|
58 |
+
.Python
|
59 |
+
build/
|
60 |
+
develop-eggs/
|
61 |
+
dist/
|
62 |
+
downloads/
|
63 |
+
eggs/
|
64 |
+
.eggs/
|
65 |
+
lib/
|
66 |
+
lib64/
|
67 |
+
parts/
|
68 |
+
sdist/
|
69 |
+
var/
|
70 |
+
wheels/
|
71 |
+
pip-wheel-metadata/
|
72 |
+
share/python-wheels/
|
73 |
+
*.egg-info/
|
74 |
+
.installed.cfg
|
75 |
+
*.egg
|
76 |
+
MANIFEST
|
77 |
+
|
78 |
+
# PyInstaller
|
79 |
+
# Usually these files are written by a python script from a template
|
80 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
81 |
+
*.manifest
|
82 |
+
*.spec
|
83 |
+
|
84 |
+
# Installer logs
|
85 |
+
pip-log.txt
|
86 |
+
pip-delete-this-directory.txt
|
87 |
+
|
88 |
+
# Unit test / coverage reports
|
89 |
+
htmlcov/
|
90 |
+
.tox/
|
91 |
+
.nox/
|
92 |
+
.coverage
|
93 |
+
.coverage.*
|
94 |
+
.cache
|
95 |
+
nosetests.xml
|
96 |
+
coverage.xml
|
97 |
+
*.cover
|
98 |
+
.hypothesis/
|
99 |
+
.pytest_cache/
|
100 |
+
|
101 |
+
# Translations
|
102 |
+
*.mo
|
103 |
+
*.pot
|
104 |
+
|
105 |
+
# Scrapy stuff:
|
106 |
+
.scrapy
|
107 |
+
|
108 |
+
# Sphinx documentation
|
109 |
+
docs/_build/
|
110 |
+
|
111 |
+
# PyBuilder
|
112 |
+
target/
|
113 |
+
|
114 |
+
# pyenv
|
115 |
+
.python-version
|
116 |
+
|
117 |
+
# pipenv
|
118 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
119 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
120 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
121 |
+
# install all needed dependencies.
|
122 |
+
#Pipfile.lock
|
123 |
+
|
124 |
+
# celery beat schedule file
|
125 |
+
celerybeat-schedule
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Spyder project settings
|
131 |
+
.spyderproject
|
132 |
+
.spyproject
|
133 |
+
|
134 |
+
# Rope project settings
|
135 |
+
.ropeproject
|
136 |
+
|
137 |
+
# Mr Developer
|
138 |
+
.mr.developer.cfg
|
139 |
+
.project
|
140 |
+
.pydevproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# End of https://www.gitignore.io/api/python,osx,linux
|
154 |
+
|
155 |
+
.idea/
|
156 |
+
.ipynb_checkpoints/
|
157 |
+
node_modules/
|
158 |
+
data/books/
|
159 |
+
docx/cache/
|
160 |
+
docx/data/
|
161 |
+
save_dir/
|
162 |
+
cache_dir/
|
163 |
+
outputs/
|
164 |
+
models/
|
165 |
+
runs/
|
README.md
CHANGED
@@ -1,12 +1,46 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ExplaiNER
|
2 |
+
|
3 |
+
Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Due to the lack of tooling, practitioners often write throwaway code or, worse, skip understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements.
|
4 |
+
|
5 |
+
|
6 |
+
## Sections
|
7 |
+
|
8 |
+
### Probing
|
9 |
+
|
10 |
+
A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
|
11 |
+
|
12 |
+
### Embeddings
|
13 |
+
|
14 |
+
For every token in the dataset, we take its hidden state and using TruncatedSVD we project it onto a two-dimensional plane. Data points are colored by label, with mislabeled examples signified by a small black border.
|
15 |
+
|
16 |
+
### Metrics
|
17 |
+
|
18 |
+
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
|
19 |
+
|
20 |
+
### Misclassified
|
21 |
+
|
22 |
+
asdf
|
23 |
+
|
24 |
+
### Loss by Token/Label
|
25 |
+
|
26 |
+
Shows count, mean and median loss per token and label.
|
27 |
+
|
28 |
+
### Samples by Loss
|
29 |
+
|
30 |
+
Shows every example sorted by loss (descending) for close inspection.
|
31 |
+
|
32 |
+
### Random Samples
|
33 |
+
|
34 |
+
Shows random samples. Simple idea, but often it turns up some interesting things.
|
35 |
+
|
36 |
+
### Inspect
|
37 |
+
|
38 |
+
Inspect your whole dataset, either unfiltered or by id.
|
39 |
+
|
40 |
+
### Raw data
|
41 |
+
|
42 |
+
See the data as seen by your model.
|
43 |
+
|
44 |
+
### Debug
|
45 |
+
|
46 |
+
Some debug info.
|
data.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset, DatasetDict, load_dataset # type: ignore
|
7 |
+
from torch.nn.functional import cross_entropy
|
8 |
+
from transformers import DataCollatorForTokenClassification # type: ignore
|
9 |
+
|
10 |
+
from utils import device, tokenizer_hash_funcs
|
11 |
+
|
12 |
+
|
13 |
+
@st.cache(allow_output_mutation=True)
|
14 |
+
def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
|
15 |
+
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
|
16 |
+
split = ds[split_name].select(range(split_sample_size))
|
17 |
+
return split
|
18 |
+
|
19 |
+
|
20 |
+
@st.cache(
|
21 |
+
allow_output_mutation=True,
|
22 |
+
hash_funcs=tokenizer_hash_funcs,
|
23 |
+
)
|
24 |
+
def get_collator(tokenizer) -> DataCollatorForTokenClassification:
|
25 |
+
return DataCollatorForTokenClassification(tokenizer)
|
26 |
+
|
27 |
+
|
28 |
+
def create_word_ids_from_tokens(tokenizer, input_ids: list[int]):
|
29 |
+
word_ids = []
|
30 |
+
wid = -1
|
31 |
+
tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
|
32 |
+
|
33 |
+
for i, tok in enumerate(tokens):
|
34 |
+
if tok in tokenizer.all_special_tokens:
|
35 |
+
word_ids.append(-1)
|
36 |
+
continue
|
37 |
+
|
38 |
+
if not tokens[i - 1].endswith("@@") and tokens[i - 1] != "<unk>":
|
39 |
+
wid += 1
|
40 |
+
|
41 |
+
word_ids.append(wid)
|
42 |
+
|
43 |
+
assert len(word_ids) == len(input_ids)
|
44 |
+
return word_ids
|
45 |
+
|
46 |
+
|
47 |
+
def tokenize_and_align_labels(examples, tokenizer):
|
48 |
+
tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
|
49 |
+
labels = []
|
50 |
+
wids = []
|
51 |
+
|
52 |
+
for idx, label in enumerate(examples["ner_tags"]):
|
53 |
+
try:
|
54 |
+
word_ids = tokenized_inputs.word_ids(batch_index=idx)
|
55 |
+
except ValueError:
|
56 |
+
word_ids = create_word_ids_from_tokens(tokenizer, tokenized_inputs["input_ids"][idx])
|
57 |
+
previous_word_idx = None
|
58 |
+
label_ids = []
|
59 |
+
for word_idx in word_ids:
|
60 |
+
if word_idx is -1 or word_idx is None or word_idx == previous_word_idx:
|
61 |
+
label_ids.append(-100)
|
62 |
+
else:
|
63 |
+
label_ids.append(label[word_idx])
|
64 |
+
previous_word_idx = word_idx
|
65 |
+
wids.append(word_ids)
|
66 |
+
labels.append(label_ids)
|
67 |
+
tokenized_inputs["word_ids"] = wids
|
68 |
+
tokenized_inputs["labels"] = labels
|
69 |
+
return tokenized_inputs
|
70 |
+
|
71 |
+
|
72 |
+
def stringify_ner_tags(batch, tags):
|
73 |
+
return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
|
74 |
+
|
75 |
+
|
76 |
+
def encode_dataset(split, tokenizer):
|
77 |
+
tags = split.features["ner_tags"].feature
|
78 |
+
split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
|
79 |
+
remove_columns = split.column_names
|
80 |
+
ids = split["id"]
|
81 |
+
split = split.map(
|
82 |
+
partial(tokenize_and_align_labels, tokenizer=tokenizer),
|
83 |
+
batched=True,
|
84 |
+
remove_columns=remove_columns,
|
85 |
+
)
|
86 |
+
word_ids = [[id if id is not None else -1 for id in wids] for wids in split["word_ids"]]
|
87 |
+
return split.remove_columns(["word_ids"]), word_ids, ids
|
88 |
+
|
89 |
+
|
90 |
+
def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
91 |
+
# Convert dict of lists to list of dicts suitable for data collator
|
92 |
+
features = [dict(zip(batch, t)) for t in zip(*batch.values())]
|
93 |
+
|
94 |
+
# Pad inputs and labels and put all tensors on device
|
95 |
+
batch = collator(features)
|
96 |
+
input_ids = batch["input_ids"].to(device)
|
97 |
+
attention_mask = batch["attention_mask"].to(device)
|
98 |
+
labels = batch["labels"].to(device)
|
99 |
+
|
100 |
+
with torch.no_grad():
|
101 |
+
# Pass data through model
|
102 |
+
output = model(input_ids, attention_mask, output_hidden_states=True)
|
103 |
+
# logit.size: [batch_size, sequence_length, classes]
|
104 |
+
|
105 |
+
# Predict class with largest logit value on classes axis
|
106 |
+
preds = torch.argmax(output.logits, axis=-1).cpu().numpy() # type: ignore
|
107 |
+
|
108 |
+
# Calculate loss per token after flattening batch dimension with view
|
109 |
+
loss = cross_entropy(
|
110 |
+
output.logits.view(-1, num_classes), labels.view(-1), reduction="none"
|
111 |
+
)
|
112 |
+
|
113 |
+
# Unflatten batch dimension and convert to numpy array
|
114 |
+
loss = loss.view(len(input_ids), -1).cpu().numpy()
|
115 |
+
hidden_states = output.hidden_states[-1].cpu().numpy()
|
116 |
+
|
117 |
+
# logits = output.logits.view(len(input_ids), -1).cpu().numpy()
|
118 |
+
|
119 |
+
return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
|
120 |
+
|
121 |
+
|
122 |
+
def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
|
123 |
+
split_encoded = split_encoded.map(
|
124 |
+
partial(
|
125 |
+
forward_pass_with_label,
|
126 |
+
model=model,
|
127 |
+
collator=collator,
|
128 |
+
num_classes=tags.num_classes,
|
129 |
+
),
|
130 |
+
batched=True,
|
131 |
+
batch_size=8,
|
132 |
+
)
|
133 |
+
df: pd.DataFrame = split_encoded.to_pandas() # type: ignore
|
134 |
+
|
135 |
+
df["tokens"] = df["input_ids"].apply(
|
136 |
+
lambda x: tokenizer.convert_ids_to_tokens(x) # type: ignore
|
137 |
+
)
|
138 |
+
df["labels"] = df["labels"].apply(
|
139 |
+
lambda x: ["IGN" if i == -100 else tags.int2str(int(i)) for i in x]
|
140 |
+
)
|
141 |
+
df["preds"] = df["preds"].apply(lambda x: [model.config.id2label[i] for i in x])
|
142 |
+
df["preds"] = df.apply(lambda x: x["preds"][: len(x["input_ids"])], axis=1)
|
143 |
+
df["losses"] = df.apply(lambda x: x["losses"][: len(x["input_ids"])], axis=1)
|
144 |
+
df["hidden_states"] = df.apply(lambda x: x["hidden_states"][: len(x["input_ids"])], axis=1)
|
145 |
+
df["total_loss"] = df["losses"].apply(sum)
|
146 |
+
|
147 |
+
return df
|
load.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
from datasets import Dataset # type: ignore
|
6 |
+
|
7 |
+
from data import encode_dataset, get_collator, get_data, get_split_df
|
8 |
+
from model import get_encoder, get_model, get_tokenizer
|
9 |
+
from subpages import Context
|
10 |
+
from utils import align_sample, device, explode_df
|
11 |
+
|
12 |
+
_TOKENIZER_NAME = (
|
13 |
+
"xlm-roberta-base",
|
14 |
+
"gagan3012/bert-tiny-finetuned-ner",
|
15 |
+
"distilbert-base-german-cased",
|
16 |
+
)[0]
|
17 |
+
|
18 |
+
|
19 |
+
def _load_models_and_tokenizer(
|
20 |
+
encoder_model_name: str,
|
21 |
+
model_name: str,
|
22 |
+
tokenizer_name: Optional[str],
|
23 |
+
device: str = "cpu",
|
24 |
+
):
|
25 |
+
sentence_encoder = get_encoder(encoder_model_name, device=device)
|
26 |
+
tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
|
27 |
+
labels = "O B-COMMA".split() if "comma" in model_name else None
|
28 |
+
model = get_model(model_name, labels=labels)
|
29 |
+
return sentence_encoder, model, tokenizer
|
30 |
+
|
31 |
+
|
32 |
+
@st.cache(allow_output_mutation=True)
|
33 |
+
def load_context(
|
34 |
+
encoder_model_name: str,
|
35 |
+
model_name: str,
|
36 |
+
ds_name: str,
|
37 |
+
ds_config_name: str,
|
38 |
+
ds_split_name: str,
|
39 |
+
split_sample_size: int,
|
40 |
+
**kw_args,
|
41 |
+
) -> Context:
|
42 |
+
|
43 |
+
sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
|
44 |
+
encoder_model_name=encoder_model_name,
|
45 |
+
model_name=model_name,
|
46 |
+
tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
|
47 |
+
device=str(device),
|
48 |
+
)
|
49 |
+
collator = get_collator(tokenizer)
|
50 |
+
|
51 |
+
# load data related stuff
|
52 |
+
split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
|
53 |
+
tags = split.features["ner_tags"].feature
|
54 |
+
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
55 |
+
|
56 |
+
# transform into dataframe
|
57 |
+
df = get_split_df(split_encoded, model, tokenizer, collator, tags)
|
58 |
+
df["word_ids"] = word_ids
|
59 |
+
df["ids"] = ids
|
60 |
+
|
61 |
+
# explode, clean, merge
|
62 |
+
df_tokens = explode_df(df)
|
63 |
+
df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
|
64 |
+
df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
|
65 |
+
df_tokens_merged = explode_df(df_merged)
|
66 |
+
|
67 |
+
return Context(
|
68 |
+
**{
|
69 |
+
"model": model,
|
70 |
+
"tokenizer": tokenizer,
|
71 |
+
"sentence_encoder": sentence_encoder,
|
72 |
+
"df": df,
|
73 |
+
"df_tokens": df_tokens,
|
74 |
+
"df_tokens_cleaned": df_tokens_cleaned,
|
75 |
+
"df_tokens_merged": df_tokens_merged,
|
76 |
+
"tags": tags,
|
77 |
+
"labels": tags.names,
|
78 |
+
"split_sample_size": split_sample_size,
|
79 |
+
"ds_name": ds_name,
|
80 |
+
"ds_config_name": ds_config_name,
|
81 |
+
"ds_split_name": ds_split_name,
|
82 |
+
"split": split,
|
83 |
+
}
|
84 |
+
)
|
main.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
from streamlit_option_menu import option_menu
|
4 |
+
|
5 |
+
from load import load_context
|
6 |
+
from subpages import (
|
7 |
+
DebugPage,
|
8 |
+
FindDuplicatesPage,
|
9 |
+
HomePage,
|
10 |
+
LossesPage,
|
11 |
+
LossySamplesPage,
|
12 |
+
MetricsPage,
|
13 |
+
MisclassifiedPage,
|
14 |
+
Page,
|
15 |
+
ProbingPage,
|
16 |
+
RandomSamplesPage,
|
17 |
+
RawDataPage,
|
18 |
+
)
|
19 |
+
from subpages.attention import AttentionPage
|
20 |
+
from subpages.embeddings import EmbeddingsPage
|
21 |
+
from subpages.inspect import InspectPage
|
22 |
+
|
23 |
+
sts = st.sidebar
|
24 |
+
st.set_page_config(
|
25 |
+
layout="wide",
|
26 |
+
page_title="Error Analysis",
|
27 |
+
page_icon="π·οΈ",
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
def _show_menu(pages: list[Page]) -> int:
|
32 |
+
with st.sidebar:
|
33 |
+
page_names = [p.name for p in pages]
|
34 |
+
page_icons = [p.icon for p in pages]
|
35 |
+
selected_menu_item = st.session_state.active_page = option_menu(
|
36 |
+
menu_title="ExplaiNER",
|
37 |
+
options=page_names,
|
38 |
+
icons=page_icons,
|
39 |
+
menu_icon="layout-wtf",
|
40 |
+
default_index=0,
|
41 |
+
)
|
42 |
+
return page_names.index(selected_menu_item)
|
43 |
+
assert False
|
44 |
+
|
45 |
+
|
46 |
+
def _initialize_session_state(pages: list[Page]):
|
47 |
+
if "active_page" not in st.session_state:
|
48 |
+
for page in pages:
|
49 |
+
st.session_state.update(**page.get_widget_defaults())
|
50 |
+
st.session_state.update(st.session_state)
|
51 |
+
|
52 |
+
|
53 |
+
def _write_color_legend(context):
|
54 |
+
def style(x):
|
55 |
+
return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
|
56 |
+
|
57 |
+
labelmap = {
|
58 |
+
"O": "O",
|
59 |
+
"person": "π",
|
60 |
+
"PER": "π",
|
61 |
+
"location": "π",
|
62 |
+
"LOC": "π",
|
63 |
+
"corporation": "π€",
|
64 |
+
"ORG": "π€",
|
65 |
+
"product": "π±",
|
66 |
+
"creative": "π·",
|
67 |
+
"group": "π·",
|
68 |
+
"MISC": "π·",
|
69 |
+
}
|
70 |
+
|
71 |
+
labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
|
72 |
+
colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
|
73 |
+
|
74 |
+
color_legend_df = pd.DataFrame(
|
75 |
+
[labelmap[l] for l in labels], columns=["label"], index=labels
|
76 |
+
).T
|
77 |
+
st.sidebar.write(
|
78 |
+
color_legend_df.T.style.apply(style, axis=0).set_properties(
|
79 |
+
**{"color": "white", "text-align": "center"}
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
def main():
|
85 |
+
pages: list[Page] = [
|
86 |
+
HomePage(),
|
87 |
+
AttentionPage(),
|
88 |
+
EmbeddingsPage(),
|
89 |
+
ProbingPage(),
|
90 |
+
MetricsPage(),
|
91 |
+
MisclassifiedPage(),
|
92 |
+
LossesPage(),
|
93 |
+
LossySamplesPage(),
|
94 |
+
RandomSamplesPage(),
|
95 |
+
FindDuplicatesPage(),
|
96 |
+
InspectPage(),
|
97 |
+
RawDataPage(),
|
98 |
+
DebugPage(),
|
99 |
+
]
|
100 |
+
|
101 |
+
_initialize_session_state(pages)
|
102 |
+
|
103 |
+
selected_page_idx = _show_menu(pages)
|
104 |
+
selected_page = pages[selected_page_idx]
|
105 |
+
|
106 |
+
if isinstance(selected_page, HomePage):
|
107 |
+
selected_page.render()
|
108 |
+
return
|
109 |
+
|
110 |
+
if "model_name" not in st.session_state:
|
111 |
+
# this can happen if someone loads another page directly (without going through home)
|
112 |
+
st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
|
113 |
+
return
|
114 |
+
|
115 |
+
context = load_context(**st.session_state)
|
116 |
+
_write_color_legend(context)
|
117 |
+
selected_page.render(context)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
main()
|
model.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from transformers import AutoModelForTokenClassification # type: ignore
|
4 |
+
from transformers import AutoTokenizer # type: ignore
|
5 |
+
|
6 |
+
|
7 |
+
@st.experimental_singleton()
|
8 |
+
def get_model(model_name: str, labels=None):
|
9 |
+
if labels is None:
|
10 |
+
return AutoModelForTokenClassification.from_pretrained(
|
11 |
+
model_name,
|
12 |
+
output_attentions=True,
|
13 |
+
) # type: ignore
|
14 |
+
else:
|
15 |
+
id2label = {idx: tag for idx, tag in enumerate(labels)}
|
16 |
+
label2id = {tag: idx for idx, tag in enumerate(labels)}
|
17 |
+
return AutoModelForTokenClassification.from_pretrained(
|
18 |
+
model_name,
|
19 |
+
output_attentions=True,
|
20 |
+
num_labels=len(labels),
|
21 |
+
id2label=id2label,
|
22 |
+
label2id=label2id,
|
23 |
+
) # type: ignore
|
24 |
+
|
25 |
+
|
26 |
+
@st.experimental_singleton()
|
27 |
+
def get_encoder(model_name: str, device: str = "cpu"):
|
28 |
+
return SentenceTransformer(model_name, device=device)
|
29 |
+
|
30 |
+
|
31 |
+
@st.experimental_singleton()
|
32 |
+
def get_tokenizer(tokenizer_name: str):
|
33 |
+
return AutoTokenizer.from_pretrained(tokenizer_name)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
pandas
|
3 |
+
scikit-learn
|
4 |
+
plotly
|
5 |
+
sentence-transformers
|
6 |
+
transformers
|
7 |
+
tokenizers
|
8 |
+
datasets
|
9 |
+
numpy
|
10 |
+
matplotlib
|
11 |
+
seqeval
|
12 |
+
st_aggrid
|
13 |
+
streamlit_option_menu
|
14 |
+
git+git://github.com/aseifert/ecco@streamlit
|
subpages/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from subpages.attention import AttentionPage
|
2 |
+
from subpages.debug import DebugPage
|
3 |
+
from subpages.embeddings import EmbeddingsPage
|
4 |
+
from subpages.find_duplicates import FindDuplicatesPage
|
5 |
+
from subpages.home import HomePage
|
6 |
+
from subpages.inspect import InspectPage
|
7 |
+
from subpages.losses import LossesPage
|
8 |
+
from subpages.lossy_samples import LossySamplesPage
|
9 |
+
from subpages.metrics import MetricsPage
|
10 |
+
from subpages.misclassified import MisclassifiedPage
|
11 |
+
from subpages.page import Context, Page
|
12 |
+
from subpages.probing import ProbingPage
|
13 |
+
from subpages.random_samples import RandomSamplesPage
|
14 |
+
from subpages.raw_data import RawDataPage
|
subpages/attention.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ecco
|
2 |
+
import streamlit as st
|
3 |
+
from streamlit.components.v1 import html
|
4 |
+
|
5 |
+
from subpages.page import Context, Page # type: ignore
|
6 |
+
|
7 |
+
SETUP_HTML = """
|
8 |
+
<script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
|
9 |
+
<script>
|
10 |
+
var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
|
11 |
+
//var ecco_url = 'http://localhost:8000/'
|
12 |
+
|
13 |
+
if (window.ecco === undefined) window.ecco = {}
|
14 |
+
|
15 |
+
// Setup the paths of the script we'll be using
|
16 |
+
requirejs.config({
|
17 |
+
urlArgs: "bust=" + (new Date()).getTime(),
|
18 |
+
nodeRequire: require,
|
19 |
+
paths: {
|
20 |
+
d3: "https://d3js.org/d3.v6.min", // This is only for use in setup.html and basic.html
|
21 |
+
"d3-array": "https://d3js.org/d3-array.v2.min",
|
22 |
+
jquery: "https://code.jquery.com/jquery-3.5.1.min",
|
23 |
+
ecco: ecco_url + 'js/0.0.6/ecco-bundle.min',
|
24 |
+
xregexp: 'https://cdnjs.cloudflare.com/ajax/libs/xregexp/3.2.0/xregexp-all.min'
|
25 |
+
}
|
26 |
+
});
|
27 |
+
|
28 |
+
// Add the css file
|
29 |
+
//requirejs(['d3'],
|
30 |
+
// function (d3) {
|
31 |
+
// d3.select('#css').attr('href', ecco_url + 'html/styles.css')
|
32 |
+
// })
|
33 |
+
|
34 |
+
console.log('Ecco initialize!!')
|
35 |
+
|
36 |
+
// returns a 'basic' object. basic.init() selects the html div we'll be
|
37 |
+
// rendering the html into, adds styles.css to the document.
|
38 |
+
define('basic', ['d3'],
|
39 |
+
function (d3) {
|
40 |
+
return {
|
41 |
+
init: function (viz_id = null) {
|
42 |
+
if (viz_id == null) {
|
43 |
+
viz_id = "viz_" + Math.round(Math.random() * 10000000)
|
44 |
+
}
|
45 |
+
// Select the div rendered below, change its id
|
46 |
+
const div = d3.select('#basic').attr('id', viz_id),
|
47 |
+
div_parent = d3.select('#' + viz_id).node().parentNode
|
48 |
+
|
49 |
+
// Link to CSS file
|
50 |
+
d3.select(div_parent).insert('link')
|
51 |
+
.attr('rel', 'stylesheet')
|
52 |
+
.attr('type', 'text/css')
|
53 |
+
.attr('href', ecco_url + 'html/0.0.2/styles.css')
|
54 |
+
|
55 |
+
return viz_id
|
56 |
+
}
|
57 |
+
}
|
58 |
+
}, function (err) {
|
59 |
+
console.log(err);
|
60 |
+
}
|
61 |
+
)
|
62 |
+
</script>
|
63 |
+
|
64 |
+
<head>
|
65 |
+
<link id='css' rel="stylesheet" type="text/css">
|
66 |
+
</head>
|
67 |
+
<div id="basic"></div>
|
68 |
+
"""
|
69 |
+
|
70 |
+
JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
|
71 |
+
const viz_id = basic.init()
|
72 |
+
|
73 |
+
ecco.interactiveTokensAndFactorSparklines(viz_id, {},
|
74 |
+
{{
|
75 |
+
'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
|
76 |
+
}}
|
77 |
+
}})
|
78 |
+
}}, function (err) {{
|
79 |
+
console.log(err);
|
80 |
+
}})"""
|
81 |
+
|
82 |
+
|
83 |
+
@st.cache(allow_output_mutation=True)
|
84 |
+
def load_ecco_model():
|
85 |
+
model_config = {
|
86 |
+
"embedding": "embeddings.word_embeddings",
|
87 |
+
"type": "mlm",
|
88 |
+
"activations": [r"ffn\.lin1"],
|
89 |
+
"token_prefix": "",
|
90 |
+
"partial_token_prefix": "##",
|
91 |
+
}
|
92 |
+
return ecco.from_pretrained(
|
93 |
+
"elastic/distilbert-base-uncased-finetuned-conll03-english",
|
94 |
+
model_config=model_config,
|
95 |
+
activations=True,
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
class AttentionPage(Page):
|
100 |
+
name = "Activations"
|
101 |
+
icon = "activity"
|
102 |
+
|
103 |
+
def get_widget_defaults(self):
|
104 |
+
return {
|
105 |
+
"act_n_components": 8,
|
106 |
+
"act_default_text": """Now I ask you: \n what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!""",
|
107 |
+
"act_from_layer": 0,
|
108 |
+
"act_to_layer": 6,
|
109 |
+
}
|
110 |
+
|
111 |
+
def render(self, context: Context):
|
112 |
+
st.title(self.name)
|
113 |
+
|
114 |
+
with st.expander("βΉοΈ", expanded=True):
|
115 |
+
st.write(
|
116 |
+
"A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
|
117 |
+
)
|
118 |
+
|
119 |
+
lm = load_ecco_model()
|
120 |
+
|
121 |
+
col1, _, col2 = st.columns([1.5, 0.5, 4])
|
122 |
+
with col1:
|
123 |
+
st.subheader("Settings")
|
124 |
+
n_components = st.slider(
|
125 |
+
"#components",
|
126 |
+
key="act_n_components",
|
127 |
+
min_value=2,
|
128 |
+
max_value=10,
|
129 |
+
step=1,
|
130 |
+
)
|
131 |
+
from_layer = (
|
132 |
+
st.slider(
|
133 |
+
"from layer",
|
134 |
+
key="act_from_layer",
|
135 |
+
value=0,
|
136 |
+
min_value=0,
|
137 |
+
max_value=len(lm.model.transformer.layer) - 1,
|
138 |
+
step=1,
|
139 |
+
)
|
140 |
+
or None
|
141 |
+
)
|
142 |
+
to_layer = (
|
143 |
+
st.slider(
|
144 |
+
"to layer",
|
145 |
+
key="act_to_layer",
|
146 |
+
value=0,
|
147 |
+
min_value=0,
|
148 |
+
max_value=len(lm.model.transformer.layer),
|
149 |
+
step=1,
|
150 |
+
)
|
151 |
+
or None
|
152 |
+
)
|
153 |
+
with col2:
|
154 |
+
st.subheader("β")
|
155 |
+
text = st.text_area("Text", key="act_default_text")
|
156 |
+
|
157 |
+
inputs = lm.tokenizer([text], return_tensors="pt")
|
158 |
+
output = lm(inputs)
|
159 |
+
nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
|
160 |
+
data = nmf.explore(returnData=True)
|
161 |
+
JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
|
162 |
+
const viz_id = basic.init()
|
163 |
+
|
164 |
+
ecco.interactiveTokensAndFactorSparklines(viz_id, {data},
|
165 |
+
{{
|
166 |
+
'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
|
167 |
+
}}
|
168 |
+
}})
|
169 |
+
}}, function (err) {{
|
170 |
+
console.log(err);
|
171 |
+
}})</script>"""
|
172 |
+
html(SETUP_HTML + JS_TEMPLATE, height=800, scrolling=True)
|
subpages/debug.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pip._internal.operations import freeze
|
3 |
+
|
4 |
+
from subpages.page import Context, Page
|
5 |
+
|
6 |
+
|
7 |
+
class DebugPage(Page):
|
8 |
+
name = "Debug"
|
9 |
+
icon = "bug"
|
10 |
+
|
11 |
+
def render(self, context: Context):
|
12 |
+
st.title(self.name)
|
13 |
+
# with st.expander("π‘", expanded=True):
|
14 |
+
# st.write("Some debug info.")
|
15 |
+
|
16 |
+
st.subheader("Installed Packages")
|
17 |
+
# get output of pip freeze from system
|
18 |
+
with st.expander("pip freeze"):
|
19 |
+
st.code("\n".join(freeze.freeze()))
|
20 |
+
|
21 |
+
st.subheader("Streamlit Session State")
|
22 |
+
st.json(st.session_state)
|
23 |
+
st.subheader("Tokenizer")
|
24 |
+
st.code(context.tokenizer)
|
25 |
+
st.subheader("Model")
|
26 |
+
st.code(context.model.config)
|
27 |
+
st.code(context.model)
|
subpages/embeddings.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import plotly.express as px
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
from subpages.page import Context, Page
|
7 |
+
|
8 |
+
|
9 |
+
@st.cache
|
10 |
+
def reduce_dim_svd(X, n_iter, random_state=42):
|
11 |
+
from sklearn.decomposition import TruncatedSVD
|
12 |
+
|
13 |
+
svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
|
14 |
+
return svd.fit_transform(X)
|
15 |
+
|
16 |
+
|
17 |
+
@st.cache
|
18 |
+
def reduce_dim_pca(X, random_state=42):
|
19 |
+
from sklearn.decomposition import PCA
|
20 |
+
|
21 |
+
return PCA(n_components=2, random_state=random_state).fit_transform(X)
|
22 |
+
|
23 |
+
|
24 |
+
@st.cache
|
25 |
+
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
|
26 |
+
from umap import UMAP
|
27 |
+
|
28 |
+
return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
|
29 |
+
|
30 |
+
|
31 |
+
class EmbeddingsPage(Page):
|
32 |
+
name = "Embeddings"
|
33 |
+
icon = "grid-3x3"
|
34 |
+
|
35 |
+
def get_widget_defaults(self):
|
36 |
+
return {
|
37 |
+
"n_tokens": 1_000,
|
38 |
+
"svd_n_iter": 5,
|
39 |
+
"svd_random_state": 42,
|
40 |
+
"umap_n_neighbors": 15,
|
41 |
+
"umap_metric": "euclidean",
|
42 |
+
"umap_min_dist": 0.1,
|
43 |
+
}
|
44 |
+
|
45 |
+
def render(self, context: Context):
|
46 |
+
st.title("Embeddings")
|
47 |
+
|
48 |
+
with st.expander("π‘", expanded=True):
|
49 |
+
st.write(
|
50 |
+
"For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
|
51 |
+
)
|
52 |
+
|
53 |
+
col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
|
54 |
+
df = context.df_tokens_merged.copy()
|
55 |
+
dim_algo = "SVD"
|
56 |
+
n_tokens = 100
|
57 |
+
|
58 |
+
with col1:
|
59 |
+
st.subheader("Settings")
|
60 |
+
n_tokens = st.slider(
|
61 |
+
"#tokens",
|
62 |
+
key="n_tokens",
|
63 |
+
min_value=100,
|
64 |
+
max_value=len(df["tokens"].unique()),
|
65 |
+
step=100,
|
66 |
+
)
|
67 |
+
|
68 |
+
dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
|
69 |
+
if dim_algo == "SVD":
|
70 |
+
svd_n_iter = st.slider(
|
71 |
+
"#iterations",
|
72 |
+
key="svd_n_iter",
|
73 |
+
min_value=1,
|
74 |
+
max_value=10,
|
75 |
+
step=1,
|
76 |
+
)
|
77 |
+
elif dim_algo == "UMAP":
|
78 |
+
umap_n_neighbors = st.slider(
|
79 |
+
"#neighbors",
|
80 |
+
key="umap_n_neighbors",
|
81 |
+
min_value=2,
|
82 |
+
max_value=100,
|
83 |
+
step=1,
|
84 |
+
)
|
85 |
+
umap_min_dist = st.number_input(
|
86 |
+
"Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
|
87 |
+
)
|
88 |
+
umap_metric = st.selectbox(
|
89 |
+
"Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
pass
|
93 |
+
|
94 |
+
with col2:
|
95 |
+
sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
|
96 |
+
|
97 |
+
X = np.array(df["hidden_states"].tolist())
|
98 |
+
transformed_hidden_states = None
|
99 |
+
if dim_algo == "SVD":
|
100 |
+
transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
|
101 |
+
elif dim_algo == "PCA":
|
102 |
+
transformed_hidden_states = reduce_dim_pca(X)
|
103 |
+
elif dim_algo == "UMAP":
|
104 |
+
transformed_hidden_states = reduce_dim_umap(
|
105 |
+
X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
|
106 |
+
)
|
107 |
+
|
108 |
+
assert isinstance(transformed_hidden_states, np.ndarray)
|
109 |
+
df["x"] = transformed_hidden_states[:, 0]
|
110 |
+
df["y"] = transformed_hidden_states[:, 1]
|
111 |
+
df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
|
112 |
+
df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
|
113 |
+
df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
|
114 |
+
df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
|
115 |
+
df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
|
116 |
+
df["mislabeled"] = df["labels"] != df["preds"]
|
117 |
+
|
118 |
+
subset = df[:n_tokens]
|
119 |
+
mislabeled_examples_trace = go.Scatter(
|
120 |
+
x=subset[subset["mislabeled"]]["x"],
|
121 |
+
y=subset[subset["mislabeled"]]["y"],
|
122 |
+
mode="markers",
|
123 |
+
marker=dict(
|
124 |
+
size=6,
|
125 |
+
color="rgba(0,0,0,0)",
|
126 |
+
line=dict(width=1),
|
127 |
+
),
|
128 |
+
hoverinfo="skip",
|
129 |
+
)
|
130 |
+
|
131 |
+
st.subheader("Projection Results")
|
132 |
+
|
133 |
+
fig = px.scatter(
|
134 |
+
subset,
|
135 |
+
x="x",
|
136 |
+
y="y",
|
137 |
+
color="labels",
|
138 |
+
hover_data=["sent0", "sent1", "sent2", "sent3", "sent4"],
|
139 |
+
hover_name="tokens",
|
140 |
+
title="Colored by label",
|
141 |
+
)
|
142 |
+
fig.add_trace(mislabeled_examples_trace)
|
143 |
+
st.plotly_chart(fig)
|
144 |
+
|
145 |
+
fig = px.scatter(
|
146 |
+
subset,
|
147 |
+
x="x",
|
148 |
+
y="y",
|
149 |
+
color="preds",
|
150 |
+
hover_data=["sent0", "sent1", "sent2", "sent3", "sent4"],
|
151 |
+
hover_name="tokens",
|
152 |
+
title="Colored by prediction",
|
153 |
+
)
|
154 |
+
fig.add_trace(mislabeled_examples_trace)
|
155 |
+
st.plotly_chart(fig)
|
subpages/emoji-en-US.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
subpages/faiss.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from datasets import Dataset
|
3 |
+
|
4 |
+
from subpages.page import Context, Page # type: ignore
|
5 |
+
from utils import device, explode_df, htmlify_labeled_example, tag_text
|
6 |
+
|
7 |
+
|
8 |
+
class FaissPage(Page):
|
9 |
+
name = "Bla"
|
10 |
+
icon = "x-octagon"
|
11 |
+
|
12 |
+
def render(self, context: Context):
|
13 |
+
dd = Dataset.from_pandas(context.df_tokens_merged, preserve_index=False) # type: ignore
|
14 |
+
|
15 |
+
dd.add_faiss_index(column="hidden_states", index_name="token_index")
|
16 |
+
token_id, text = (
|
17 |
+
6,
|
18 |
+
"Die Wissenschaft ist eine wichtige Grundlage fΓΌr die Entwicklung von neuen Technologien.",
|
19 |
+
)
|
20 |
+
token_id, text = (
|
21 |
+
15,
|
22 |
+
"AuΓer der unbewussten Beeinflussung eines Resultats gibt es auch noch andere Motive die das reine strahlende Licht der Wissenschaft etwas zu trΓΌben vermΓΆgen.",
|
23 |
+
)
|
24 |
+
token_id, text = (
|
25 |
+
3,
|
26 |
+
"Mit mehr Instrumenten einer besseren prΓ€ziseren Datenbasis ist auch ein viel besseres smarteres Risikomanagement mΓΆglich.",
|
27 |
+
)
|
28 |
+
token_id, text = (
|
29 |
+
7,
|
30 |
+
"Es gilt die akademische Viertelstunde das heiΓt Beginn ist fΓΌnfzehn Minuten spΓ€ter.",
|
31 |
+
)
|
32 |
+
token_id, text = (
|
33 |
+
7,
|
34 |
+
"Damit einher geht ΓΌbrigens auch dass Marcella Collocinis Tochter keine wie auch immer geartete strafrechtliche Verfolgung zu befΓΌrchten hat.",
|
35 |
+
)
|
36 |
+
token_id, text = (
|
37 |
+
16,
|
38 |
+
"After Steve Jobs met with Bill Gates of Microsoft back in 1993, they went to Cupertino and made the deal.",
|
39 |
+
)
|
40 |
+
|
41 |
+
tagged = tag_text(text, context.tokenizer, context.model, device)
|
42 |
+
hidden_states = tagged["hidden_states"]
|
43 |
+
# tagged.drop("hidden_states", inplace=True, axis=1)
|
44 |
+
# hidden_states_vec = svd.transform([hidden_states[token_id]])[0].astype(np.float32)
|
45 |
+
hidden_states_vec = hidden_states[token_id]
|
46 |
+
tagged = tagged.astype(str)
|
47 |
+
tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
|
48 |
+
tagged["check"] = tagged["probs"].apply(
|
49 |
+
lambda x: "β
β
" if int(x) < 100 else "β
" if int(x) < 1000 else ""
|
50 |
+
)
|
51 |
+
st.dataframe(tagged.drop("hidden_states", axis=1).T)
|
52 |
+
results = dd.get_nearest_examples("token_index", hidden_states_vec, k=10)
|
53 |
+
for i, (dist, idx, token) in enumerate(
|
54 |
+
zip(results.scores, results.examples["ids"], results.examples["tokens"])
|
55 |
+
):
|
56 |
+
st.code(f"{dist:.3f} {token}")
|
57 |
+
sample = context.df_tokens_merged.query(f"ids == '{idx}'")
|
58 |
+
st.write(f"[{i};{idx}] " + htmlify_labeled_example(sample), unsafe_allow_html=True)
|
subpages/find_duplicates.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers.util import cos_sim
|
3 |
+
|
4 |
+
from subpages.page import Context, Page
|
5 |
+
|
6 |
+
|
7 |
+
@st.cache()
|
8 |
+
def get_sims(texts: list[str], sentence_encoder):
|
9 |
+
embeddings = sentence_encoder.encode(texts, batch_size=8, convert_to_numpy=True)
|
10 |
+
return cos_sim(embeddings, embeddings)
|
11 |
+
|
12 |
+
|
13 |
+
class FindDuplicatesPage(Page):
|
14 |
+
name = "Find Duplicates"
|
15 |
+
icon = "fingerprint"
|
16 |
+
|
17 |
+
def get_widget_defaults(self):
|
18 |
+
return {
|
19 |
+
"cutoff": 0.95,
|
20 |
+
}
|
21 |
+
|
22 |
+
def render(self, context: Context):
|
23 |
+
st.title("Find Duplicates")
|
24 |
+
with st.expander("π‘", expanded=True):
|
25 |
+
st.write("Find potential duplicates in the data using cosine similarity.")
|
26 |
+
|
27 |
+
cutoff = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, key="cutoff")
|
28 |
+
# split.add_faiss_index(column="embeddings", index_name="sent_index")
|
29 |
+
# st.write("Index is ready")
|
30 |
+
# sentence_encoder.encode(["hello world"], batch_size=8)
|
31 |
+
# st.write(split["tokens"][0])
|
32 |
+
texts = [" ".join(ts) for ts in context.split["tokens"]]
|
33 |
+
sims = get_sims(texts, context.sentence_encoder)
|
34 |
+
|
35 |
+
candidates = []
|
36 |
+
for i in range(len(sims)):
|
37 |
+
for j in range(i + 1, len(sims)):
|
38 |
+
if sims[i][j] >= cutoff:
|
39 |
+
candidates.append((sims[i][j], i, j))
|
40 |
+
candidates.sort(reverse=False)
|
41 |
+
|
42 |
+
for (sim, i, j) in candidates[:100]:
|
43 |
+
st.markdown(f"**Possible duplicate ({i}, {j}, sim: {sim:.3f}):**")
|
44 |
+
st.markdown("* " + " ".join(context.split["tokens"][i]))
|
45 |
+
st.markdown("* " + " ".join(context.split["tokens"][j]))
|
46 |
+
|
47 |
+
# st.write("queries")
|
48 |
+
# results = split.get_nearest_examples("sent_index", np.array(split["embeddings"][0], dtype=np.float32), k=2)
|
49 |
+
# results = split.get_nearest_examples_batch("sent_index", queries, k=2)
|
50 |
+
# st.write(results.total_examples[0]["id"][1])
|
51 |
+
# st.write(results.total_examples[0])
|
subpages/home.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
from pandas import wide_to_long
|
7 |
+
|
8 |
+
from data import get_data
|
9 |
+
from subpages.page import Context, Page
|
10 |
+
from utils import color_map_color
|
11 |
+
|
12 |
+
_SENTENCE_ENCODER_MODEL = (
|
13 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
14 |
+
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
15 |
+
)[0]
|
16 |
+
_MODEL_NAME = (
|
17 |
+
"elastic/distilbert-base-uncased-finetuned-conll03-english",
|
18 |
+
"gagan3012/bert-tiny-finetuned-ner",
|
19 |
+
"socialmediaie/bertweet-base_wnut17_ner",
|
20 |
+
"sberbank-ai/bert-base-NER-reptile-5-datasets",
|
21 |
+
"aseifert/comma-xlm-roberta-base",
|
22 |
+
"dslim/bert-base-NER",
|
23 |
+
"aseifert/distilbert-base-german-cased-comma-derstandard",
|
24 |
+
)[0]
|
25 |
+
_DATASET_NAME = (
|
26 |
+
"conll2003",
|
27 |
+
"wnut_17",
|
28 |
+
"aseifert/comma",
|
29 |
+
)[0]
|
30 |
+
_CONFIG_NAME = (
|
31 |
+
"conll2003",
|
32 |
+
"wnut_17",
|
33 |
+
"seifertverlag",
|
34 |
+
)[0]
|
35 |
+
|
36 |
+
|
37 |
+
class HomePage(Page):
|
38 |
+
name = "Home / Setup"
|
39 |
+
icon = "house"
|
40 |
+
|
41 |
+
def get_widget_defaults(self):
|
42 |
+
return {
|
43 |
+
"encoder_model_name": _SENTENCE_ENCODER_MODEL,
|
44 |
+
"model_name": _MODEL_NAME,
|
45 |
+
"ds_name": _DATASET_NAME,
|
46 |
+
"ds_split_name": "validation",
|
47 |
+
"ds_config_name": _CONFIG_NAME,
|
48 |
+
"split_sample_size": 512,
|
49 |
+
}
|
50 |
+
|
51 |
+
def render(self, context: Optional[Context] = None):
|
52 |
+
st.title("ExplaiNER")
|
53 |
+
|
54 |
+
with st.expander("π‘", expanded=True):
|
55 |
+
st.write(
|
56 |
+
"**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements."
|
57 |
+
)
|
58 |
+
|
59 |
+
col1, _, col2a, col2b = st.columns([1, 0.05, 0.15, 0.15])
|
60 |
+
|
61 |
+
with col1:
|
62 |
+
random_form_key = f"settings-{random.randint(0, 100000)}"
|
63 |
+
# FIXME: for some reason I'm getting the following error if I don't randomize the key:
|
64 |
+
"""
|
65 |
+
2022-05-05 20:37:16.507 Traceback (most recent call last):
|
66 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/scriptrunner/script_runner.py", line 443, in _run_script
|
67 |
+
exec(code, module.__dict__)
|
68 |
+
File "/Users/zoro/code/error-analysis/main.py", line 162, in <module>
|
69 |
+
main()
|
70 |
+
File "/Users/zoro/code/error-analysis/main.py", line 102, in main
|
71 |
+
show_setup()
|
72 |
+
File "/Users/zoro/code/error-analysis/section/setup.py", line 68, in show_setup
|
73 |
+
st.form_submit_button("Load Model & Data")
|
74 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 240, in form_submit_button
|
75 |
+
return self._form_submit_button(
|
76 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 260, in _form_submit_button
|
77 |
+
return self.dg._button(
|
78 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/button.py", line 304, in _button
|
79 |
+
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
|
80 |
+
File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/utils.py", line 74, in check_session_state_rules
|
81 |
+
raise StreamlitAPIException(
|
82 |
+
streamlit.errors.StreamlitAPIException: Values for st.button, st.download_button, st.file_uploader, and st.form cannot be set using st.session_state.
|
83 |
+
"""
|
84 |
+
with st.form(key=random_form_key):
|
85 |
+
st.subheader("Model & Data Selection")
|
86 |
+
st.text_input(
|
87 |
+
label="NER Model:",
|
88 |
+
key="model_name",
|
89 |
+
help="Path or name of the model to use",
|
90 |
+
)
|
91 |
+
st.text_input(
|
92 |
+
label="Encoder Model:",
|
93 |
+
key="encoder_model_name",
|
94 |
+
help="Path or name of the encoder to use",
|
95 |
+
)
|
96 |
+
ds_name = st.text_input(
|
97 |
+
label="Dataset:",
|
98 |
+
key="ds_name",
|
99 |
+
help="Path or name of the dataset to use",
|
100 |
+
)
|
101 |
+
ds_config_name = st.text_input(
|
102 |
+
label="Config (optional):",
|
103 |
+
key="ds_config_name",
|
104 |
+
)
|
105 |
+
ds_split_name = st.selectbox(
|
106 |
+
label="Split:",
|
107 |
+
options=["train", "validation", "test"],
|
108 |
+
key="ds_split_name",
|
109 |
+
)
|
110 |
+
split_sample_size = st.number_input(
|
111 |
+
"Sample size:",
|
112 |
+
step=16,
|
113 |
+
key="split_sample_size",
|
114 |
+
help="Sample size for the split, speeds up processing inside streamlit",
|
115 |
+
)
|
116 |
+
# breakpoint()
|
117 |
+
# st.form_submit_button("Submit")
|
118 |
+
st.form_submit_button("Load Model & Data")
|
119 |
+
|
120 |
+
split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
|
121 |
+
labels = list(
|
122 |
+
set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
|
123 |
+
)
|
124 |
+
|
125 |
+
with col2a:
|
126 |
+
st.subheader("Classes")
|
127 |
+
st.write("**Color**")
|
128 |
+
colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
|
129 |
+
for label in labels:
|
130 |
+
if f"color_{label}" not in st.session_state:
|
131 |
+
st.session_state[f"color_{label}"] = colors[label]
|
132 |
+
st.color_picker(label, key=f"color_{label}")
|
133 |
+
with col2b:
|
134 |
+
st.subheader("β")
|
135 |
+
st.write("**Icon**")
|
136 |
+
emojis = list(json.load(open("subpages/emoji-en-US.json")).keys())
|
137 |
+
for label in labels:
|
138 |
+
if f"icon_{label}" not in st.session_state:
|
139 |
+
st.session_state[f"icon_{label}"] = "π€" # labels[label]
|
140 |
+
st.selectbox(label, key=f"icon_{label}", options=emojis)
|
141 |
+
|
142 |
+
# if st.button("Reset to defaults"):
|
143 |
+
# st.session_state.update(**get_home_page_defaults())
|
144 |
+
# # time.sleep 2 secs
|
145 |
+
# import time
|
146 |
+
# time.sleep(1)
|
147 |
+
|
148 |
+
# # st.legacy_caching.clear_cache()
|
149 |
+
# st.experimental_rerun()
|
subpages/inspect.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from subpages.page import Context, Page
|
4 |
+
from utils import aggrid_interactive_table, colorize_classes
|
5 |
+
|
6 |
+
|
7 |
+
class InspectPage(Page):
|
8 |
+
name = "Inspect"
|
9 |
+
icon = "search"
|
10 |
+
|
11 |
+
def render(self, context: Context):
|
12 |
+
st.title(self.name)
|
13 |
+
with st.expander("π‘", expanded=True):
|
14 |
+
st.write("Inspect your whole dataset, either unfiltered or by id.")
|
15 |
+
|
16 |
+
df = context.df_tokens
|
17 |
+
cols = (
|
18 |
+
"ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
|
19 |
+
)
|
20 |
+
if "token_type_ids" not in df.columns:
|
21 |
+
cols.remove("token_type_ids")
|
22 |
+
df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
|
23 |
+
|
24 |
+
if st.checkbox("Filter by id", value=True):
|
25 |
+
ids = list(sorted(map(int, df.ids.unique())))
|
26 |
+
next_id = st.session_state.get("next_id", 0)
|
27 |
+
|
28 |
+
example_id = st.selectbox("Select an example", ids, index=next_id)
|
29 |
+
df = df[df.ids == str(example_id)][1:-1]
|
30 |
+
# st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses")) # type: ignore
|
31 |
+
st.dataframe(colorize_classes(df.round(3).astype(str)))
|
32 |
+
|
33 |
+
if st.button("Next example"):
|
34 |
+
st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
|
35 |
+
if st.button("Previous example"):
|
36 |
+
st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
|
37 |
+
else:
|
38 |
+
aggrid_interactive_table(df.round(3))
|
subpages/losses.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from subpages.page import Context, Page
|
4 |
+
from utils import AgGrid, aggrid_interactive_table
|
5 |
+
|
6 |
+
|
7 |
+
@st.cache
|
8 |
+
def get_loss_by_token(df_tokens):
|
9 |
+
return (
|
10 |
+
df_tokens.groupby("tokens")[["losses"]]
|
11 |
+
.agg(["count", "mean", "median", "sum"])
|
12 |
+
.droplevel(level=0, axis=1) # Get rid of multi-level columns
|
13 |
+
.sort_values(by="sum", ascending=False)
|
14 |
+
.reset_index()
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
@st.cache
|
19 |
+
def get_loss_by_label(df_tokens):
|
20 |
+
return (
|
21 |
+
df_tokens.groupby("labels")[["losses"]]
|
22 |
+
.agg(["count", "mean", "median", "sum"])
|
23 |
+
.droplevel(level=0, axis=1)
|
24 |
+
.sort_values(by="mean", ascending=False)
|
25 |
+
.reset_index()
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class LossesPage(Page):
|
30 |
+
name = "Loss by Token/Label"
|
31 |
+
icon = "sort-alpha-down"
|
32 |
+
|
33 |
+
def render(self, context: Context):
|
34 |
+
st.title(self.name)
|
35 |
+
with st.expander("π‘", expanded=True):
|
36 |
+
st.write("Show count, mean and median loss per token and label.")
|
37 |
+
|
38 |
+
col1, _, col2 = st.columns([8, 1, 6])
|
39 |
+
|
40 |
+
with col1:
|
41 |
+
st.subheader("π¬ Loss by Token")
|
42 |
+
|
43 |
+
st.session_state["_merge_tokens"] = st.checkbox(
|
44 |
+
"Merge tokens", value=True, key="merge_tokens"
|
45 |
+
)
|
46 |
+
loss_by_token = (
|
47 |
+
get_loss_by_token(context.df_tokens_merged)
|
48 |
+
if st.session_state["merge_tokens"]
|
49 |
+
else get_loss_by_token(context.df_tokens_cleaned)
|
50 |
+
)
|
51 |
+
aggrid_interactive_table(loss_by_token.round(3))
|
52 |
+
# st.subheader("π·οΈ Loss by Label")
|
53 |
+
# loss_by_label = get_loss_by_label(df_tokens_cleaned)
|
54 |
+
# st.dataframe(loss_by_label)
|
55 |
+
|
56 |
+
st.write(
|
57 |
+
"_Attention: This statistic disregards that tokens have contextual representations._"
|
58 |
+
)
|
59 |
+
|
60 |
+
with col2:
|
61 |
+
st.subheader("π·οΈ Loss by Label")
|
62 |
+
loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
|
63 |
+
AgGrid(loss_by_label.round(3), height=200)
|
subpages/lossy_samples.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from subpages.page import Context, Page
|
5 |
+
from utils import colorize_classes, get_bg_color, get_fg_color, htmlify_labeled_example
|
6 |
+
|
7 |
+
|
8 |
+
class LossySamplesPage(Page):
|
9 |
+
name = "Samples by Loss"
|
10 |
+
icon = "sort-numeric-down-alt"
|
11 |
+
|
12 |
+
def get_widget_defaults(self):
|
13 |
+
return {
|
14 |
+
"skip_correct": True,
|
15 |
+
"samples_by_loss_show_df": True,
|
16 |
+
}
|
17 |
+
|
18 |
+
def render(self, context: Context):
|
19 |
+
st.title(self.name)
|
20 |
+
with st.expander("π‘", expanded=True):
|
21 |
+
st.write("Show every example sorted by loss (descending) for close inspection.")
|
22 |
+
|
23 |
+
st.subheader("π₯ Samples β¬loss")
|
24 |
+
skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
|
25 |
+
show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
|
26 |
+
|
27 |
+
st.write(
|
28 |
+
"""<style>
|
29 |
+
thead {
|
30 |
+
display: none;
|
31 |
+
}
|
32 |
+
td {
|
33 |
+
white-space: nowrap;
|
34 |
+
padding: 0 5px !important;
|
35 |
+
}
|
36 |
+
</style>""",
|
37 |
+
unsafe_allow_html=True,
|
38 |
+
)
|
39 |
+
|
40 |
+
top_indices = (
|
41 |
+
context.df.sort_values(by="total_loss", ascending=False)
|
42 |
+
.query("total_loss > 0.5")
|
43 |
+
.index
|
44 |
+
)
|
45 |
+
|
46 |
+
cnt = 0
|
47 |
+
for idx in top_indices:
|
48 |
+
sample = context.df_tokens_merged.loc[idx]
|
49 |
+
|
50 |
+
if isinstance(sample, pd.Series):
|
51 |
+
continue
|
52 |
+
|
53 |
+
if skip_correct and sum(sample.labels != sample.preds) == 0:
|
54 |
+
continue
|
55 |
+
|
56 |
+
if show_df:
|
57 |
+
|
58 |
+
def colorize_col(col):
|
59 |
+
if col.name == "labels" or col.name == "preds":
|
60 |
+
bgs = []
|
61 |
+
fgs = []
|
62 |
+
ops = []
|
63 |
+
for v in col.values:
|
64 |
+
bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
|
65 |
+
fgs.append(get_fg_color(bgs[-1]))
|
66 |
+
ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
|
67 |
+
return [
|
68 |
+
f"background-color: {bg}; color: {fg}; opacity: {op};"
|
69 |
+
for bg, fg, op in zip(bgs, fgs, ops)
|
70 |
+
]
|
71 |
+
return [""] * len(col)
|
72 |
+
|
73 |
+
df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
|
74 |
+
losses_slice = pd.IndexSlice["losses", :]
|
75 |
+
# x = df.T.astype(str)
|
76 |
+
# st.dataframe(x)
|
77 |
+
# st.dataframe(x.loc[losses_slice])
|
78 |
+
styler = (
|
79 |
+
df.T.style.apply(colorize_col, axis=1)
|
80 |
+
.bar(subset=losses_slice, axis=1)
|
81 |
+
.format(precision=3)
|
82 |
+
)
|
83 |
+
# styler.data = styler.data.astype(str)
|
84 |
+
st.write(styler.to_html(), unsafe_allow_html=True)
|
85 |
+
# st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses')) # type: ignore
|
86 |
+
# st.write(
|
87 |
+
# colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
|
88 |
+
# )
|
89 |
+
|
90 |
+
col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
|
91 |
+
|
92 |
+
cnt += 1
|
93 |
+
counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
|
94 |
+
loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>πΏ {sample.losses.sum():.3f}</span>"
|
95 |
+
col1.write(f"{counter}{loss}", unsafe_allow_html=True)
|
96 |
+
col1.write("")
|
97 |
+
|
98 |
+
col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
|
99 |
+
# st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)
|
subpages/metrics.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import plotly.express as px
|
7 |
+
import streamlit as st
|
8 |
+
from seqeval.metrics import classification_report
|
9 |
+
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
|
10 |
+
|
11 |
+
from subpages.page import Context, Page
|
12 |
+
|
13 |
+
|
14 |
+
def _get_evaluation(df):
|
15 |
+
y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
|
16 |
+
y_pred = df.apply(
|
17 |
+
lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
|
18 |
+
axis=1,
|
19 |
+
)
|
20 |
+
report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3) # type: ignore
|
21 |
+
return report.replace(
|
22 |
+
"precision recall f1-score support",
|
23 |
+
"=" * 12 + " precision recall f1-score support",
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
|
28 |
+
cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
|
29 |
+
if zero_diagonal:
|
30 |
+
np.fill_diagonal(cm, 0)
|
31 |
+
|
32 |
+
# st.write(plt.rcParams["font.size"])
|
33 |
+
# plt.rcParams.update({'font.size': 10.0})
|
34 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
35 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
|
36 |
+
fmt = "d" if normalize is None else ".3f"
|
37 |
+
disp.plot(
|
38 |
+
cmap="Blues",
|
39 |
+
include_values=True,
|
40 |
+
xticks_rotation="vertical",
|
41 |
+
values_format=fmt,
|
42 |
+
ax=ax,
|
43 |
+
colorbar=False,
|
44 |
+
)
|
45 |
+
return fig
|
46 |
+
|
47 |
+
|
48 |
+
class MetricsPage(Page):
|
49 |
+
name = "Metrics"
|
50 |
+
icon = "graph-up-arrow"
|
51 |
+
|
52 |
+
def get_widget_defaults(self):
|
53 |
+
return {
|
54 |
+
"normalize": True,
|
55 |
+
"zero_diagonal": False,
|
56 |
+
}
|
57 |
+
|
58 |
+
def render(self, context: Context):
|
59 |
+
st.title(self.name)
|
60 |
+
with st.expander("π‘", expanded=True):
|
61 |
+
st.write(
|
62 |
+
"The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
|
63 |
+
)
|
64 |
+
|
65 |
+
eval_results = _get_evaluation(context.df)
|
66 |
+
if len(eval_results.splitlines()) < 8:
|
67 |
+
col1, _, col2 = st.columns([8, 1, 1])
|
68 |
+
else:
|
69 |
+
col1 = col2 = st
|
70 |
+
|
71 |
+
col1.subheader("π― Evaluation Results")
|
72 |
+
col1.code(eval_results)
|
73 |
+
|
74 |
+
results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
|
75 |
+
data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
|
76 |
+
df = pd.DataFrame(data, columns="class support f1".split())
|
77 |
+
fig = px.scatter(
|
78 |
+
df,
|
79 |
+
x="support",
|
80 |
+
y="f1",
|
81 |
+
range_y=(0, 1.05),
|
82 |
+
color="class",
|
83 |
+
)
|
84 |
+
# fig.update_layout(title_text="asdf", title_yanchor="bottom")
|
85 |
+
col1.plotly_chart(fig)
|
86 |
+
|
87 |
+
col2.subheader("π Confusion Matrix")
|
88 |
+
normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
|
89 |
+
zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
|
90 |
+
col2.pyplot(
|
91 |
+
plot_confusion_matrix(
|
92 |
+
y_true=context.df_tokens_cleaned["labels"],
|
93 |
+
y_preds=context.df_tokens_cleaned["preds"],
|
94 |
+
labels=context.labels,
|
95 |
+
normalize=normalize,
|
96 |
+
zero_diagonal=zero_diagonal,
|
97 |
+
),
|
98 |
+
)
|
subpages/misclassified.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
from sklearn.metrics import confusion_matrix
|
6 |
+
|
7 |
+
from subpages.page import Context, Page
|
8 |
+
from utils import htmlify_labeled_example
|
9 |
+
|
10 |
+
|
11 |
+
class MisclassifiedPage(Page):
|
12 |
+
name = "Misclassified"
|
13 |
+
icon = "x-octagon"
|
14 |
+
|
15 |
+
def render(self, context: Context):
|
16 |
+
st.title(self.name)
|
17 |
+
with st.expander("π‘", expanded=True):
|
18 |
+
st.write(
|
19 |
+
"This page contains all misclassified examples and allows filtering by specific error types."
|
20 |
+
)
|
21 |
+
|
22 |
+
misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
|
23 |
+
misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
|
24 |
+
cm = confusion_matrix(
|
25 |
+
misclassified_samples.labels,
|
26 |
+
misclassified_samples.preds,
|
27 |
+
labels=context.labels,
|
28 |
+
)
|
29 |
+
|
30 |
+
# st.pyplot(
|
31 |
+
# plot_confusion_matrix(
|
32 |
+
# y_preds=misclassified_samples["preds"],
|
33 |
+
# y_true=misclassified_samples["labels"],
|
34 |
+
# labels=labels,
|
35 |
+
# normalize=None,
|
36 |
+
# zero_diagonal=True,
|
37 |
+
# ),
|
38 |
+
# )
|
39 |
+
df = pd.DataFrame(cm, index=context.labels, columns=context.labels).astype(str)
|
40 |
+
import numpy as np
|
41 |
+
|
42 |
+
np.fill_diagonal(df.values, "")
|
43 |
+
st.dataframe(df.applymap(lambda x: x if x != "0" else ""))
|
44 |
+
# import matplotlib.pyplot as plt
|
45 |
+
# st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html())
|
46 |
+
# selection = aggrid_interactive_table(df)
|
47 |
+
|
48 |
+
# st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
|
49 |
+
|
50 |
+
confusions = defaultdict(int)
|
51 |
+
for i, row in enumerate(cm):
|
52 |
+
for j, _ in enumerate(row):
|
53 |
+
if i == j or cm[i][j] == 0:
|
54 |
+
continue
|
55 |
+
confusions[(context.labels[i], context.labels[j])] += cm[i][j]
|
56 |
+
|
57 |
+
def format_func(item):
|
58 |
+
return (
|
59 |
+
f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
|
60 |
+
)
|
61 |
+
|
62 |
+
conf = st.radio(
|
63 |
+
"Filter by Class Confusion",
|
64 |
+
options=list(zip(confusions.keys(), confusions.values())),
|
65 |
+
format_func=format_func,
|
66 |
+
)
|
67 |
+
|
68 |
+
# st.write(
|
69 |
+
# f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`"
|
70 |
+
# )
|
71 |
+
|
72 |
+
filtered_indices = misclassified_samples.query(
|
73 |
+
f"labels == '{conf[0][0]}' and preds == '{conf[0][1]}'"
|
74 |
+
).index
|
75 |
+
for i, idx in enumerate(filtered_indices):
|
76 |
+
sample = context.df_tokens_merged.loc[idx]
|
77 |
+
st.write(
|
78 |
+
htmlify_labeled_example(sample),
|
79 |
+
unsafe_allow_html=True,
|
80 |
+
)
|
81 |
+
st.write("---")
|
subpages/page.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from datasets import Dataset # type: ignore
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from transformers import AutoModelForSequenceClassification # type: ignore
|
8 |
+
from transformers import AutoTokenizer # type: ignore
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class Context:
|
13 |
+
model: AutoModelForSequenceClassification
|
14 |
+
tokenizer: AutoTokenizer
|
15 |
+
sentence_encoder: SentenceTransformer
|
16 |
+
tags: Any
|
17 |
+
df: pd.DataFrame
|
18 |
+
df_tokens: pd.DataFrame
|
19 |
+
df_tokens_cleaned: pd.DataFrame
|
20 |
+
df_tokens_merged: pd.DataFrame
|
21 |
+
split_sample_size: int
|
22 |
+
ds_name: str
|
23 |
+
ds_config_name: str
|
24 |
+
ds_split_name: str
|
25 |
+
split: Dataset
|
26 |
+
labels: list[str]
|
27 |
+
|
28 |
+
|
29 |
+
class Page:
|
30 |
+
name: str
|
31 |
+
icon: str
|
32 |
+
|
33 |
+
def get_widget_defaults(self):
|
34 |
+
return {}
|
35 |
+
|
36 |
+
def render(self, context):
|
37 |
+
...
|
subpages/probing.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from subpages.page import Context, Page
|
4 |
+
from utils import device, tag_text
|
5 |
+
|
6 |
+
_DEFAULT_SENTENCES = """
|
7 |
+
Damit hatte er auf ihr letztes , vΓΆllig schiefgelaufenes GeschΓ€ftsessen angespielt .
|
8 |
+
Damit einher geht ΓΌbrigens auch , dass Marcella , Collocinis Tochter , keine wie auch immer geartete strafrechtliche Verfolgung zu befΓΌrchten hat .
|
9 |
+
Nach dem Bell β schen Theorem , einer Physik jenseits der Quanten , ist die Welt , die wir fΓΌr real halten , nicht objektivierbar .
|
10 |
+
Dazu muss man wiederum wissen , dass die Aussagekraft von Tests , neben der SensitivitΓ€t und SpezifitΓ€t , ganz entscheidend von der Vortestwahrscheinlichkeit abhΓ€ngt .
|
11 |
+
Haben Sie sich schon eingelebt ? Β« erkundigte er sich .
|
12 |
+
Das Auto ein Totalschaden , mein Beifahrer ein weinender Jammerlappen .
|
13 |
+
Seltsam , wunderte sie sich , dass das StΓΌck nach mehr als eineinhalb Jahrhunderten noch so gut in Schuss ist .
|
14 |
+
Oder auf den Strich gehen , StrΓΌmpfe stricken , Geld hamstern .
|
15 |
+
Und Allah ist Allumfassend Allwissend .
|
16 |
+
Und Pedro Moacir redete weiter : Β» Verzicht , Pater Antonio , Verzicht , zu groΓer Schmerz ΓΌber Verzicht , Sehnsucht , die sich nicht erfΓΌllt , die sich nicht erfΓΌllen kann , das sind Qualen , die ein Verstummen nach sich ziehen kΓΆnnen , oder HΓ€rte .
|
17 |
+
Mama-San ging mittlerweile fast ausnahmslos nur mit Wei an ihrer Seite aus dem Haus , kaum je mit einem der MΓ€dchen und niemals allein.
|
18 |
+
""".strip()
|
19 |
+
_DEFAULT_SENTENCES = """
|
20 |
+
Elon Muskβs Berghain humiliation β I know the feeling
|
21 |
+
Musk was also seen at a local spot called Sisyphos celebrating entrepreneur Adeo Ressi's birthday, according to The Times.
|
22 |
+
""".strip()
|
23 |
+
|
24 |
+
|
25 |
+
class ProbingPage(Page):
|
26 |
+
name = "Probing"
|
27 |
+
icon = "fonts"
|
28 |
+
|
29 |
+
def get_widget_defaults(self):
|
30 |
+
return {"probing_textarea": _DEFAULT_SENTENCES}
|
31 |
+
|
32 |
+
def render(self, context: Context):
|
33 |
+
st.title("π Interactive Probing")
|
34 |
+
|
35 |
+
with st.expander("π‘", expanded=True):
|
36 |
+
st.write(
|
37 |
+
"A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors."
|
38 |
+
)
|
39 |
+
|
40 |
+
sentences = st.text_area("Sentences", height=200, key="probing_textarea")
|
41 |
+
if not sentences.strip():
|
42 |
+
return
|
43 |
+
sentences = [sentence.strip() for sentence in sentences.splitlines()]
|
44 |
+
|
45 |
+
for sent in sentences:
|
46 |
+
sent = sent.replace(",", "").replace(" ", " ")
|
47 |
+
with st.expander(sent):
|
48 |
+
tagged = tag_text(sent, context.tokenizer, context.model, device)
|
49 |
+
tagged = tagged.astype(str)
|
50 |
+
tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
|
51 |
+
tagged["check"] = tagged["probs"].apply(
|
52 |
+
lambda x: "β
β
" if int(x) < 100 else "β
" if int(x) < 1000 else ""
|
53 |
+
)
|
54 |
+
st.dataframe(tagged.drop("hidden_states", axis=1).T)
|
subpages/random_samples.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from subpages.page import Context, Page
|
5 |
+
from utils import htmlify_labeled_example
|
6 |
+
|
7 |
+
|
8 |
+
class RandomSamplesPage(Page):
|
9 |
+
name = "Random Samples"
|
10 |
+
icon = "shuffle"
|
11 |
+
|
12 |
+
def get_widget_defaults(self):
|
13 |
+
return {
|
14 |
+
"random_sample_size_min": 128,
|
15 |
+
}
|
16 |
+
|
17 |
+
def render(self, context: Context):
|
18 |
+
st.title("π² Random Samples")
|
19 |
+
with st.expander("π‘", expanded=True):
|
20 |
+
st.write(
|
21 |
+
"Show random samples. Simple idea, but often it turns up some interesting things."
|
22 |
+
)
|
23 |
+
|
24 |
+
random_sample_size = st.number_input(
|
25 |
+
"Random sample size:",
|
26 |
+
value=min(st.session_state.random_sample_size_min, context.split_sample_size),
|
27 |
+
step=16,
|
28 |
+
key="random_sample_size",
|
29 |
+
)
|
30 |
+
|
31 |
+
if st.button("π² Resample"):
|
32 |
+
st.experimental_rerun()
|
33 |
+
|
34 |
+
random_indices = context.df.sample(int(random_sample_size)).index
|
35 |
+
samples = context.df_tokens_merged.loc[random_indices]
|
36 |
+
return
|
37 |
+
|
38 |
+
for i, idx in enumerate(random_indices):
|
39 |
+
sample = samples.loc[idx]
|
40 |
+
|
41 |
+
if isinstance(sample, pd.Series):
|
42 |
+
continue
|
43 |
+
|
44 |
+
col1, _, col2 = st.columns([0.08, 0.025, 0.8])
|
45 |
+
|
46 |
+
counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: wh^; padding: 0 5px'>[{i+1} | {idx}]</span>"
|
47 |
+
loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>πΏ {sample.losses.sum():.3f}</span>"
|
48 |
+
col1.write(f"{counter}{loss}", unsafe_allow_html=True)
|
49 |
+
col1.write("")
|
50 |
+
st.write(sample.astype(str))
|
51 |
+
col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
|
subpages/raw_data.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from subpages.page import Context, Page
|
5 |
+
from utils import aggrid_interactive_table
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache
|
9 |
+
def convert_df(df):
|
10 |
+
return df.to_csv().encode("utf-8")
|
11 |
+
|
12 |
+
|
13 |
+
class RawDataPage(Page):
|
14 |
+
name = "Raw data"
|
15 |
+
icon = "qr-code"
|
16 |
+
|
17 |
+
def render(self, context: Context):
|
18 |
+
st.title(self.name)
|
19 |
+
with st.expander("π‘", expanded=True):
|
20 |
+
st.write("See the data as seen by your model.")
|
21 |
+
|
22 |
+
st.subheader("Dataset")
|
23 |
+
st.code(
|
24 |
+
f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
|
25 |
+
)
|
26 |
+
|
27 |
+
st.write("**Data after processing and inference**")
|
28 |
+
|
29 |
+
processed_df = (
|
30 |
+
context.df_tokens.drop("hidden_states", axis=1).drop("attention_mask", axis=1).round(3)
|
31 |
+
)
|
32 |
+
cols = (
|
33 |
+
"ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
|
34 |
+
)
|
35 |
+
if "token_type_ids" not in processed_df.columns:
|
36 |
+
cols.remove("token_type_ids")
|
37 |
+
processed_df = processed_df[cols]
|
38 |
+
aggrid_interactive_table(processed_df)
|
39 |
+
processed_df_csv = convert_df(processed_df)
|
40 |
+
st.download_button(
|
41 |
+
"Download csv",
|
42 |
+
processed_df_csv,
|
43 |
+
"processed_data.csv",
|
44 |
+
"text/csv",
|
45 |
+
)
|
46 |
+
|
47 |
+
st.write("**Raw data (exploded by tokens)**")
|
48 |
+
raw_data_df = context.split.to_pandas().apply(pd.Series.explode) # type: ignore
|
49 |
+
aggrid_interactive_table(raw_data_df)
|
50 |
+
raw_data_df_csv = convert_df(raw_data_df)
|
51 |
+
st.download_button(
|
52 |
+
"Download csv",
|
53 |
+
raw_data_df_csv,
|
54 |
+
"raw_data.csv",
|
55 |
+
"text/csv",
|
56 |
+
)
|
utils.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib as matplotlib
|
2 |
+
import matplotlib.cm as cm
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
import tokenizers
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
|
9 |
+
|
10 |
+
tokenizer_hash_funcs = {
|
11 |
+
tokenizers.Tokenizer: lambda _: None,
|
12 |
+
tokenizers.AddedToken: lambda _: None,
|
13 |
+
}
|
14 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
|
17 |
+
|
18 |
+
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
|
19 |
+
"""Creates an st-aggrid interactive table based on a dataframe.
|
20 |
+
Args:
|
21 |
+
df (pd.DataFrame]): Source dataframe
|
22 |
+
Returns:
|
23 |
+
dict: The selected row
|
24 |
+
"""
|
25 |
+
options = GridOptionsBuilder.from_dataframe(
|
26 |
+
df, enableRowGroup=True, enableValue=True, enablePivot=True
|
27 |
+
)
|
28 |
+
|
29 |
+
options.configure_side_bar()
|
30 |
+
# options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))
|
31 |
+
|
32 |
+
options.configure_selection("single")
|
33 |
+
selection = AgGrid(
|
34 |
+
df,
|
35 |
+
enable_enterprise_modules=True,
|
36 |
+
gridOptions=options.build(),
|
37 |
+
theme="light",
|
38 |
+
update_mode=GridUpdateMode.NO_UPDATE,
|
39 |
+
allow_unsafe_jscode=True,
|
40 |
+
)
|
41 |
+
|
42 |
+
return selection
|
43 |
+
|
44 |
+
|
45 |
+
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
|
46 |
+
df_tokens = df.apply(pd.Series.explode)
|
47 |
+
if "losses" in df.columns:
|
48 |
+
df_tokens["losses"] = df_tokens["losses"].astype(float)
|
49 |
+
return df_tokens # type: ignore
|
50 |
+
|
51 |
+
|
52 |
+
def align_sample(row: pd.Series):
|
53 |
+
"""Use word_ids to align all lists in a sample."""
|
54 |
+
|
55 |
+
columns = row.axes[0].to_list()
|
56 |
+
indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
|
57 |
+
|
58 |
+
out = {}
|
59 |
+
|
60 |
+
tokens = []
|
61 |
+
for i, tok in enumerate(row.tokens):
|
62 |
+
if row.word_ids[i] == -1:
|
63 |
+
continue
|
64 |
+
|
65 |
+
if row.word_ids[i] != row.word_ids[i - 1]:
|
66 |
+
tokens.append(tok.lstrip("β").lstrip("##").rstrip("@@"))
|
67 |
+
else:
|
68 |
+
tokens[-1] += tok.lstrip("β").lstrip("##").rstrip("@@")
|
69 |
+
out["tokens"] = tokens
|
70 |
+
|
71 |
+
if "labels" in columns:
|
72 |
+
out["labels"] = [row.labels[i] for i in indices]
|
73 |
+
|
74 |
+
if "preds" in columns:
|
75 |
+
out["preds"] = [row.preds[i] for i in indices]
|
76 |
+
|
77 |
+
if "losses" in columns:
|
78 |
+
out["losses"] = [row.losses[i] for i in indices]
|
79 |
+
|
80 |
+
if "probs" in columns:
|
81 |
+
out["probs"] = [row.probs[i] for i in indices]
|
82 |
+
|
83 |
+
if "hidden_states" in columns:
|
84 |
+
out["hidden_states"] = [row.hidden_states[i] for i in indices]
|
85 |
+
|
86 |
+
if "ids" in columns:
|
87 |
+
out["ids"] = row.ids
|
88 |
+
|
89 |
+
assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
@st.cache(
|
95 |
+
allow_output_mutation=True,
|
96 |
+
hash_funcs=tokenizer_hash_funcs,
|
97 |
+
)
|
98 |
+
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
|
99 |
+
"""Create an (exploded) DataFrame with the predicted labels and probabilities."""
|
100 |
+
|
101 |
+
tokens = tokenizer(text).tokens()
|
102 |
+
tokenized = tokenizer(text, return_tensors="pt")
|
103 |
+
word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
|
104 |
+
input_ids = tokenized.input_ids.to(device)
|
105 |
+
outputs = model(input_ids, output_hidden_states=True)
|
106 |
+
preds = torch.argmax(outputs.logits, dim=2)
|
107 |
+
preds = [model.config.id2label[p] for p in preds[0].cpu().numpy()]
|
108 |
+
hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
|
109 |
+
# hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)
|
110 |
+
|
111 |
+
probs = 1 // (
|
112 |
+
torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
|
113 |
+
)
|
114 |
+
|
115 |
+
df = pd.DataFrame(
|
116 |
+
[[tokens, word_ids, preds, probs, hidden_states]],
|
117 |
+
columns="tokens word_ids preds probs hidden_states".split(),
|
118 |
+
)
|
119 |
+
merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
|
120 |
+
return explode_df(merged_df).reset_index().drop(columns=["index"])
|
121 |
+
|
122 |
+
|
123 |
+
def get_bg_color(label):
|
124 |
+
return st.session_state[f"color_{label}"]
|
125 |
+
|
126 |
+
|
127 |
+
def get_fg_color(hex_color: str) -> str:
|
128 |
+
"""Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/"""
|
129 |
+
r = int(hex_color[1:3], 16)
|
130 |
+
g = int(hex_color[3:5], 16)
|
131 |
+
b = int(hex_color[5:7], 16)
|
132 |
+
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
|
133 |
+
return "black" if (yiq >= 128) else "white"
|
134 |
+
|
135 |
+
|
136 |
+
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
|
137 |
+
"""Colorize the errors in the dataframe."""
|
138 |
+
|
139 |
+
def colorize_row(row):
|
140 |
+
return [
|
141 |
+
"background-color: "
|
142 |
+
+ ("white" if (row["labels"] == "IGN" or (row["preds"] == row["labels"])) else "pink")
|
143 |
+
+ ";"
|
144 |
+
] * len(row)
|
145 |
+
|
146 |
+
def colorize_col(col):
|
147 |
+
if col.name == "labels" or col.name == "preds":
|
148 |
+
bgs = []
|
149 |
+
fgs = []
|
150 |
+
for v in col.values:
|
151 |
+
bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
|
152 |
+
fgs.append(get_fg_color(bgs[-1]))
|
153 |
+
return [f"background-color: {bg}; color: {fg};" for bg, fg in zip(bgs, fgs)]
|
154 |
+
return [""] * len(col)
|
155 |
+
|
156 |
+
df = df.reset_index().drop(columns=["index"]).T
|
157 |
+
return df # .style.apply(colorize_col, axis=0)
|
158 |
+
|
159 |
+
|
160 |
+
def htmlify_labeled_example(example: pd.DataFrame) -> str:
|
161 |
+
html = []
|
162 |
+
classmap = {
|
163 |
+
"O": "O",
|
164 |
+
"PER": "π",
|
165 |
+
"person": "π",
|
166 |
+
"LOC": "π",
|
167 |
+
"location": "π",
|
168 |
+
"ORG": "π€",
|
169 |
+
"corporation": "π€",
|
170 |
+
"product": "π±",
|
171 |
+
"creative": "π·",
|
172 |
+
"MISC": "π·",
|
173 |
+
}
|
174 |
+
|
175 |
+
for _, row in example.iterrows():
|
176 |
+
pred = row.preds.split("-")[1] if "-" in row.preds else "O"
|
177 |
+
label = row.labels
|
178 |
+
label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
|
179 |
+
|
180 |
+
color = get_bg_color(row.preds.split("-")[1]) if "-" in row.preds else "#000000"
|
181 |
+
true_color = get_bg_color(row.labels.split("-")[1]) if "-" in row.labels else "#000000"
|
182 |
+
|
183 |
+
font_color = get_fg_color(color) if color else "white"
|
184 |
+
true_font_color = get_fg_color(true_color) if true_color else "white"
|
185 |
+
|
186 |
+
is_correct = row.preds == row.labels
|
187 |
+
loss_html = (
|
188 |
+
""
|
189 |
+
if float(row.losses) < 0.01
|
190 |
+
else f"<span style='background-color: yellow; color: font_color; padding: 0 5px;'>{row.losses:.3f}</span>"
|
191 |
+
)
|
192 |
+
loss_html = ""
|
193 |
+
|
194 |
+
if row.labels == row.preds == "O":
|
195 |
+
html.append(f"<span>{row.tokens}</span>")
|
196 |
+
elif row.labels == "IGN":
|
197 |
+
assert False
|
198 |
+
else:
|
199 |
+
opacity = "1" if not is_correct else "0.5"
|
200 |
+
correct = (
|
201 |
+
""
|
202 |
+
if is_correct
|
203 |
+
else f"<span title='{label}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{classmap[label_class]}</span>"
|
204 |
+
)
|
205 |
+
pred_icon = classmap[pred] if pred != "O" and row.preds[:2] != "I-" else ""
|
206 |
+
html.append(
|
207 |
+
f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title={row.preds}>{pred_icon + ' '}{row.tokens}</span>{correct}{loss_html}"
|
208 |
+
)
|
209 |
+
|
210 |
+
return " ".join(html)
|
211 |
+
|
212 |
+
|
213 |
+
def htmlify_example(example: pd.DataFrame) -> str:
|
214 |
+
corr_html = " ".join(
|
215 |
+
[
|
216 |
+
f", {row.tokens}" if row.labels == "B-COMMA" else row.tokens
|
217 |
+
for _, row in example.iterrows()
|
218 |
+
]
|
219 |
+
).strip()
|
220 |
+
return f"<em>{corr_html}</em>"
|
221 |
+
|
222 |
+
|
223 |
+
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
|
224 |
+
"""Turn a value into a color using a color map."""
|
225 |
+
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
|
226 |
+
cmap = cm.get_cmap(cmap_name) # PiYG
|
227 |
+
rgba = cmap(norm(abs(value)))
|
228 |
+
color = matplotlib.colors.rgb2hex(rgba[:3])
|
229 |
+
return color
|