Alexander Seifert committed on
Commit 597bf7d · 1 Parent(s): 9b03751

initial commit

.gitignore ADDED
@@ -0,0 +1,165 @@
+
+ # Created by https://www.gitignore.io/api/python,osx,linux
+ # Edit at https://www.gitignore.io/?templates=python,osx,linux
+
+ ### Linux ###
+ *~
+
+ # temporary files which can be created if a process still has a handle open of a deleted file
+ .fuse_hidden*
+
+ # KDE directory preferences
+ .directory
+
+ # Linux trash folder which might appear on any partition or disk
+ .Trash-*
+
+ # .nfs files are created when an open file is removed but is still being accessed
+ .nfs*
+
+ ### OSX ###
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # Mr Developer
+ .mr.developer.cfg
+ .project
+ .pydevproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # End of https://www.gitignore.io/api/python,osx,linux
+
+ .idea/
+ .ipynb_checkpoints/
+ node_modules/
+ data/books/
+ docx/cache/
+ docx/data/
+ save_dir/
+ cache_dir/
+ outputs/
+ models/
+ runs/
README.md CHANGED
@@ -1,12 +1,46 @@
- ---
- title: ExplaiNER
- emoji: 📉
- colorFrom: blue
- colorTo: indigo
- sdk: streamlit
- sdk_version: 1.10.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ExplaiNER
+
+ Error Analysis is an important but often overlooked part of the data science project lifecycle, for which there is still very little tooling available. Due to the lack of tooling, practitioners often write throwaway code or, worse, skip understanding their models' errors altogether. This project tries to provide an extensive toolkit to probe any NER model/dataset combination, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements.
+
+
+ ## Sections
+
+ ### Probing
+
+ A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors.
+
+ ### Embeddings
+
+ For every token in the dataset, we take its hidden state and using TruncatedSVD we project it onto a two-dimensional plane. Data points are colored by label, with mislabeled examples signified by a small black border.
+
+ ### Metrics
+
+ The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
+
+ ### Misclassified
+
+ asdf
+
+ ### Loss by Token/Label
+
+ Shows count, mean and median loss per token and label.
+
+ ### Samples by Loss
+
+ Shows every example sorted by loss (descending) for close inspection.
+
+ ### Random Samples
+
+ Shows random samples. Simple idea, but often it turns up some interesting things.
+
+ ### Inspect
+
+ Inspect your whole dataset, either unfiltered or by id.
+
+ ### Raw data
+
+ See the data as seen by your model.
+
+ ### Debug
+
+ Some debug info.
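
Note on the Metrics section of the README: the text describes a confusion matrix that is normalized by default and can have its diagonal zeroed so that only errors remain. The metrics page itself is not part of this excerpt, so the following is only a minimal, dependency-free sketch of that idea. The function names and the choice of row-wise normalization (each true class sums to 1) are assumptions for illustration, not the app's actual code.

```python
from collections import Counter


def confusion_matrix(y_true, y_pred, labels):
    """Count (true, predicted) pairs; rows are true labels, columns are predictions."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


def zero_diagonal(matrix):
    """Drop the correct predictions so that only the errors remain."""
    return [[0 if i == j else v for j, v in enumerate(row)] for i, row in enumerate(matrix)]


def normalize_rows(matrix):
    """Divide each row by its sum (assumed normalization); empty rows stay at 0.0."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([v / total if total else 0.0 for v in row])
    return normalized


labels = ["O", "PER", "LOC"]
y_true = ["O", "PER", "PER", "LOC", "O"]
y_pred = ["O", "PER", "O", "PER", "O"]

cm = confusion_matrix(y_true, y_pred, labels)
errors = zero_diagonal(cm)  # raw error counts, as the README suggests
```

Zeroing the diagonal before normalizing would turn each row into a distribution over error types only, which is why the README recommends raw counts in that mode.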
data.py ADDED
@@ -0,0 +1,147 @@
+ from functools import partial
+
+ import pandas as pd
+ import streamlit as st
+ import torch
+ from datasets import Dataset, DatasetDict, load_dataset  # type: ignore
+ from torch.nn.functional import cross_entropy
+ from transformers import DataCollatorForTokenClassification  # type: ignore
+
+ from utils import device, tokenizer_hash_funcs
+
+
+ @st.cache(allow_output_mutation=True)
+ def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
+     ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0)  # type: ignore
+     split = ds[split_name].select(range(split_sample_size))
+     return split
+
+
+ @st.cache(
+     allow_output_mutation=True,
+     hash_funcs=tokenizer_hash_funcs,
+ )
+ def get_collator(tokenizer) -> DataCollatorForTokenClassification:
+     return DataCollatorForTokenClassification(tokenizer)
+
+
+ def create_word_ids_from_tokens(tokenizer, input_ids: list[int]):
+     word_ids = []
+     wid = -1
+     tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
+
+     for i, tok in enumerate(tokens):
+         if tok in tokenizer.all_special_tokens:
+             word_ids.append(-1)
+             continue
+
+         if not tokens[i - 1].endswith("@@") and tokens[i - 1] != "<unk>":
+             wid += 1
+
+         word_ids.append(wid)
+
+     assert len(word_ids) == len(input_ids)
+     return word_ids
+
+
+ def tokenize_and_align_labels(examples, tokenizer):
+     tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
+     labels = []
+     wids = []
+
+     for idx, label in enumerate(examples["ner_tags"]):
+         try:
+             word_ids = tokenized_inputs.word_ids(batch_index=idx)
+         except ValueError:
+             word_ids = create_word_ids_from_tokens(tokenizer, tokenized_inputs["input_ids"][idx])
+         previous_word_idx = None
+         label_ids = []
+         for word_idx in word_ids:
+             if word_idx == -1 or word_idx is None or word_idx == previous_word_idx:
+                 label_ids.append(-100)
+             else:
+                 label_ids.append(label[word_idx])
+             previous_word_idx = word_idx
+         wids.append(word_ids)
+         labels.append(label_ids)
+     tokenized_inputs["word_ids"] = wids
+     tokenized_inputs["labels"] = labels
+     return tokenized_inputs
+
+
+ def stringify_ner_tags(batch, tags):
+     return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
+
+
+ def encode_dataset(split, tokenizer):
+     tags = split.features["ner_tags"].feature
+     split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
+     remove_columns = split.column_names
+     ids = split["id"]
+     split = split.map(
+         partial(tokenize_and_align_labels, tokenizer=tokenizer),
+         batched=True,
+         remove_columns=remove_columns,
+     )
+     word_ids = [[id if id is not None else -1 for id in wids] for wids in split["word_ids"]]
+     return split.remove_columns(["word_ids"]), word_ids, ids
+
+
+ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
+     # Convert dict of lists to list of dicts suitable for data collator
+     features = [dict(zip(batch, t)) for t in zip(*batch.values())]
+
+     # Pad inputs and labels and put all tensors on device
+     batch = collator(features)
+     input_ids = batch["input_ids"].to(device)
+     attention_mask = batch["attention_mask"].to(device)
+     labels = batch["labels"].to(device)
+
+     with torch.no_grad():
+         # Pass data through model
+         output = model(input_ids, attention_mask, output_hidden_states=True)
+         # logit.size: [batch_size, sequence_length, classes]
+
+     # Predict class with largest logit value on classes axis
+     preds = torch.argmax(output.logits, axis=-1).cpu().numpy()  # type: ignore
+
+     # Calculate loss per token after flattening batch dimension with view
+     loss = cross_entropy(
+         output.logits.view(-1, num_classes), labels.view(-1), reduction="none"
+     )
+
+     # Unflatten batch dimension and convert to numpy array
+     loss = loss.view(len(input_ids), -1).cpu().numpy()
+     hidden_states = output.hidden_states[-1].cpu().numpy()
+
+     # logits = output.logits.view(len(input_ids), -1).cpu().numpy()
+
+     return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
+
+
+ def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
+     split_encoded = split_encoded.map(
+         partial(
+             forward_pass_with_label,
+             model=model,
+             collator=collator,
+             num_classes=tags.num_classes,
+         ),
+         batched=True,
+         batch_size=8,
+     )
+     df: pd.DataFrame = split_encoded.to_pandas()  # type: ignore
+
+     df["tokens"] = df["input_ids"].apply(
+         lambda x: tokenizer.convert_ids_to_tokens(x)  # type: ignore
+     )
+     df["labels"] = df["labels"].apply(
+         lambda x: ["IGN" if i == -100 else tags.int2str(int(i)) for i in x]
+     )
+     df["preds"] = df["preds"].apply(lambda x: [model.config.id2label[i] for i in x])
+     df["preds"] = df.apply(lambda x: x["preds"][: len(x["input_ids"])], axis=1)
+     df["losses"] = df.apply(lambda x: x["losses"][: len(x["input_ids"])], axis=1)
+     df["hidden_states"] = df.apply(lambda x: x["hidden_states"][: len(x["input_ids"])], axis=1)
+     df["total_loss"] = df["losses"].apply(sum)
+
+     return df
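
Note on `tokenize_and_align_labels` in data.py: the alignment rule is easier to follow in isolation. The sketch below restates it without a tokenizer, under the assumption that `word_ids` already maps every subword token to its source word (`None` or `-1` for special tokens). The helper name `align_labels` and the sample values are hypothetical. The first subword of each word keeps the word's label, and everything else receives `-100`, which is the default `ignore_index` of PyTorch's `cross_entropy`.

```python
IGNORE_INDEX = -100  # same sentinel as in data.py; cross_entropy ignores this target by default


def align_labels(word_ids, word_labels):
    """Spread word-level labels over subword tokens.

    Only the first subword of each word keeps the label; special tokens
    (None or -1) and continuation subwords are masked with IGNORE_INDEX.
    """
    aligned = []
    previous = None
    for wid in word_ids:
        if wid is None or wid == -1 or wid == previous:
            aligned.append(IGNORE_INDEX)
        else:
            aligned.append(word_labels[wid])
        previous = wid
    return aligned


# Hypothetical sample: three words, the second one split into two subwords.
word_ids = [None, 0, 1, 1, 2, None]
word_labels = [3, 0, 5]
aligned = align_labels(word_ids, word_labels)
```

Comparing with `== -1` rather than `is -1` matters here: identity comparison against an integer literal relies on an interpreter detail and triggers a `SyntaxWarning` on Python 3.8 and later.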
load.py ADDED
@@ -0,0 +1,84 @@
+ from typing import Optional
+
+ import pandas as pd
+ import streamlit as st
+ from datasets import Dataset  # type: ignore
+
+ from data import encode_dataset, get_collator, get_data, get_split_df
+ from model import get_encoder, get_model, get_tokenizer
+ from subpages import Context
+ from utils import align_sample, device, explode_df
+
+ _TOKENIZER_NAME = (
+     "xlm-roberta-base",
+     "gagan3012/bert-tiny-finetuned-ner",
+     "distilbert-base-german-cased",
+ )[0]
+
+
+ def _load_models_and_tokenizer(
+     encoder_model_name: str,
+     model_name: str,
+     tokenizer_name: Optional[str],
+     device: str = "cpu",
+ ):
+     sentence_encoder = get_encoder(encoder_model_name, device=device)
+     tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
+     labels = "O B-COMMA".split() if "comma" in model_name else None
+     model = get_model(model_name, labels=labels)
+     return sentence_encoder, model, tokenizer
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_context(
+     encoder_model_name: str,
+     model_name: str,
+     ds_name: str,
+     ds_config_name: str,
+     ds_split_name: str,
+     split_sample_size: int,
+     **kw_args,
+ ) -> Context:
+
+     sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
+         encoder_model_name=encoder_model_name,
+         model_name=model_name,
+         tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
+         device=str(device),
+     )
+     collator = get_collator(tokenizer)
+
+     # load data related stuff
+     split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
+     tags = split.features["ner_tags"].feature
+     split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
+
+     # transform into dataframe
+     df = get_split_df(split_encoded, model, tokenizer, collator, tags)
+     df["word_ids"] = word_ids
+     df["ids"] = ids
+
+     # explode, clean, merge
+     df_tokens = explode_df(df)
+     df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
+     df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
+     df_tokens_merged = explode_df(df_merged)
+
+     return Context(
+         **{
+             "model": model,
+             "tokenizer": tokenizer,
+             "sentence_encoder": sentence_encoder,
+             "df": df,
+             "df_tokens": df_tokens,
+             "df_tokens_cleaned": df_tokens_cleaned,
+             "df_tokens_merged": df_tokens_merged,
+             "tags": tags,
+             "labels": tags.names,
+             "split_sample_size": split_sample_size,
+             "ds_name": ds_name,
+             "ds_config_name": ds_config_name,
+             "ds_split_name": ds_split_name,
+             "split": split,
+         }
+     )
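
Note on the "explode, clean, merge" step in load.py: `explode_df` lives in `utils`, which is not part of this excerpt, so its behavior is not confirmed here. Going by how its output is used (one row per token, filtered on `labels != 'IGN'`), a plausible reading is that it turns one row per sample with list-valued columns into one row per token. The sketch below illustrates that reading with plain dictionaries. `explode_rows` and the sample data are hypothetical and do not reproduce the real helper.

```python
def explode_rows(rows, list_columns):
    """Turn one row per sample into one row per token.

    Every column named in list_columns must hold lists of equal length;
    all other columns are copied unchanged onto each token row.
    """
    exploded = []
    for row in rows:
        lengths = {len(row[col]) for col in list_columns}
        if len(lengths) != 1:
            raise ValueError("list columns must have the same length within a row")
        for i in range(lengths.pop()):
            token_row = {k: v for k, v in row.items() if k not in list_columns}
            token_row.update({col: row[col][i] for col in list_columns})
            exploded.append(token_row)
    return exploded


# Hypothetical sample with special tokens labeled "IGN", as in get_split_df.
rows = [
    {"ids": "a", "tokens": ["<s>", "Anna", "</s>"], "labels": ["IGN", "B-PER", "IGN"]},
]
token_rows = explode_rows(rows, ["tokens", "labels"])

# Equivalent of df_tokens.query("labels != 'IGN'")
cleaned = [r for r in token_rows if r["labels"] != "IGN"]
```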
main.py ADDED
@@ -0,0 +1,121 @@
+ import pandas as pd
+ import streamlit as st
+ from streamlit_option_menu import option_menu
+
+ from load import load_context
+ from subpages import (
+     DebugPage,
+     FindDuplicatesPage,
+     HomePage,
+     LossesPage,
+     LossySamplesPage,
+     MetricsPage,
+     MisclassifiedPage,
+     Page,
+     ProbingPage,
+     RandomSamplesPage,
+     RawDataPage,
+ )
+ from subpages.attention import AttentionPage
+ from subpages.embeddings import EmbeddingsPage
+ from subpages.inspect import InspectPage
+
+ sts = st.sidebar
+ st.set_page_config(
+     layout="wide",
+     page_title="Error Analysis",
+     page_icon="🏷️",
+ )
+
+
+ def _show_menu(pages: list[Page]) -> int:
+     with st.sidebar:
+         page_names = [p.name for p in pages]
+         page_icons = [p.icon for p in pages]
+         selected_menu_item = st.session_state.active_page = option_menu(
+             menu_title="ExplaiNER",
+             options=page_names,
+             icons=page_icons,
+             menu_icon="layout-wtf",
+             default_index=0,
+         )
+         return page_names.index(selected_menu_item)
+     assert False
+
+
+ def _initialize_session_state(pages: list[Page]):
+     if "active_page" not in st.session_state:
+         for page in pages:
+             st.session_state.update(**page.get_widget_defaults())
+     st.session_state.update(st.session_state)
+
+
+ def _write_color_legend(context):
+     def style(x):
+         return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
+
+     labelmap = {
+         "O": "O",
+         "person": "🙎",
+         "PER": "🙎",
+         "location": "🌎",
+         "LOC": "🌎",
+         "corporation": "🏤",
+         "ORG": "🏤",
+         "product": "📱",
+         "creative": "🎷",
+         "group": "🎷",
+         "MISC": "🎷",
+     }
+
+     labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
+     colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
+
+     color_legend_df = pd.DataFrame(
+         [labelmap[l] for l in labels], columns=["label"], index=labels
+     ).T
+     st.sidebar.write(
+         color_legend_df.T.style.apply(style, axis=0).set_properties(
+             **{"color": "white", "text-align": "center"}
+         )
+     )
+
+
+ def main():
+     pages: list[Page] = [
+         HomePage(),
+         AttentionPage(),
+         EmbeddingsPage(),
+         ProbingPage(),
+         MetricsPage(),
+         MisclassifiedPage(),
+         LossesPage(),
+         LossySamplesPage(),
+         RandomSamplesPage(),
+         FindDuplicatesPage(),
+         InspectPage(),
+         RawDataPage(),
+         DebugPage(),
+     ]
+
+     _initialize_session_state(pages)
+
+     selected_page_idx = _show_menu(pages)
+     selected_page = pages[selected_page_idx]
+
+     if isinstance(selected_page, HomePage):
+         selected_page.render()
+         return
+
+     if "model_name" not in st.session_state:
+         # this can happen if someone loads another page directly (without going through home)
+         st.error("Setup not complete. Please click on 'Home / Setup' in the left menu bar.")
+         return
+
+     context = load_context(**st.session_state)
+     _write_color_legend(context)
+     selected_page.render(context)
+
+
+ if __name__ == "__main__":
+     main()
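
Note on `_write_color_legend` in main.py: the legend reduces BIO-style tags such as `B-PER` and `I-PER` to their entity type by taking the second `-`-separated field. This also maps a tag like `B-creative-work` to `creative`, which is one of the keys in `labelmap`. The sketch below isolates that step. The function name `entity_types` is hypothetical, and unlike the `set` used in main.py it keeps first-seen order so that the legend order is stable between runs.

```python
def entity_types(labels):
    """Reduce BIO tags to entity types, preserving first-seen order."""
    seen = []
    for lbl in labels:
        # Same rule as main.py: second field after splitting on "-", else the label itself.
        kind = lbl.split("-")[1] if "-" in lbl else lbl
        if kind not in seen:
            seen.append(kind)
    return seen


types = entity_types(["O", "B-PER", "I-PER", "B-LOC", "B-creative-work"])
```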
model.py ADDED
@@ -0,0 +1,33 @@
+ import streamlit as st
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoModelForTokenClassification  # type: ignore
+ from transformers import AutoTokenizer  # type: ignore
+
+
+ @st.experimental_singleton()
+ def get_model(model_name: str, labels=None):
+     if labels is None:
+         return AutoModelForTokenClassification.from_pretrained(
+             model_name,
+             output_attentions=True,
+         )  # type: ignore
+     else:
+         id2label = {idx: tag for idx, tag in enumerate(labels)}
+         label2id = {tag: idx for idx, tag in enumerate(labels)}
+         return AutoModelForTokenClassification.from_pretrained(
+             model_name,
+             output_attentions=True,
+             num_labels=len(labels),
+             id2label=id2label,
+             label2id=label2id,
+         )  # type: ignore
+
+
+ @st.experimental_singleton()
+ def get_encoder(model_name: str, device: str = "cpu"):
+     return SentenceTransformer(model_name, device=device)
+
+
+ @st.experimental_singleton()
+ def get_tokenizer(tokenizer_name: str):
+     return AutoTokenizer.from_pretrained(tokenizer_name)
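
Note on `get_model` in model.py: when a label list is passed, the function derives two inverse mappings and hands them to `from_pretrained`. The sketch below shows only that mapping step, using the `"O B-COMMA"` label string from load.py. The helper name `build_label_maps` and the duplicate check are additions for illustration and are not part of the commit.

```python
def build_label_maps(labels):
    """Build the id -> tag and tag -> id dictionaries for a label list."""
    if len(set(labels)) != len(labels):
        # Duplicates would make label2id lossy, so reject them (illustrative safeguard).
        raise ValueError("labels must be unique")
    id2label = dict(enumerate(labels))
    label2id = {tag: idx for idx, tag in id2label.items()}
    return id2label, label2id


id2label, label2id = build_label_maps("O B-COMMA".split())
```

Keeping the two dictionaries consistent matters because data.py later decodes predictions through `model.config.id2label`.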
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ streamlit
+ pandas
+ scikit-learn
+ plotly
+ sentence-transformers
+ transformers
+ tokenizers
+ datasets
+ numpy
+ matplotlib
+ seqeval
+ st_aggrid
+ streamlit_option_menu
+ git+git://github.com/aseifert/ecco@streamlit
subpages/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from subpages.attention import AttentionPage
+ from subpages.debug import DebugPage
+ from subpages.embeddings import EmbeddingsPage
+ from subpages.find_duplicates import FindDuplicatesPage
+ from subpages.home import HomePage
+ from subpages.inspect import InspectPage
+ from subpages.losses import LossesPage
+ from subpages.lossy_samples import LossySamplesPage
+ from subpages.metrics import MetricsPage
+ from subpages.misclassified import MisclassifiedPage
+ from subpages.page import Context, Page
+ from subpages.probing import ProbingPage
+ from subpages.random_samples import RandomSamplesPage
+ from subpages.raw_data import RawDataPage
subpages/attention.py ADDED
@@ -0,0 +1,172 @@
+ import ecco
+ import streamlit as st
+ from streamlit.components.v1 import html
+
+ from subpages.page import Context, Page  # type: ignore
+
+ SETUP_HTML = """
+ <script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
+ <script>
+     var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
+     //var ecco_url = 'http://localhost:8000/'
+
+     if (window.ecco === undefined) window.ecco = {}
+
+     // Setup the paths of the script we'll be using
+     requirejs.config({
+         urlArgs: "bust=" + (new Date()).getTime(),
+         nodeRequire: require,
+         paths: {
+             d3: "https://d3js.org/d3.v6.min", // This is only for use in setup.html and basic.html
+             "d3-array": "https://d3js.org/d3-array.v2.min",
+             jquery: "https://code.jquery.com/jquery-3.5.1.min",
+             ecco: ecco_url + 'js/0.0.6/ecco-bundle.min',
+             xregexp: 'https://cdnjs.cloudflare.com/ajax/libs/xregexp/3.2.0/xregexp-all.min'
+         }
+     });
+
+     // Add the css file
+     //requirejs(['d3'],
+     //    function (d3) {
+     //        d3.select('#css').attr('href', ecco_url + 'html/styles.css')
+     //    })
+
+     console.log('Ecco initialize!!')
+
+     // returns a 'basic' object. basic.init() selects the html div we'll be
+     // rendering the html into, adds styles.css to the document.
+     define('basic', ['d3'],
+         function (d3) {
+             return {
+                 init: function (viz_id = null) {
+                     if (viz_id == null) {
+                         viz_id = "viz_" + Math.round(Math.random() * 10000000)
+                     }
+                     // Select the div rendered below, change its id
+                     const div = d3.select('#basic').attr('id', viz_id),
+                         div_parent = d3.select('#' + viz_id).node().parentNode
+
+                     // Link to CSS file
+                     d3.select(div_parent).insert('link')
+                         .attr('rel', 'stylesheet')
+                         .attr('type', 'text/css')
+                         .attr('href', ecco_url + 'html/0.0.2/styles.css')
+
+                     return viz_id
+                 }
+             }
+         }, function (err) {
+             console.log(err);
+         }
+     )
+ </script>
+
+ <head>
+     <link id='css' rel="stylesheet" type="text/css">
+ </head>
+ <div id="basic"></div>
+ """
+
+ JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
+     const viz_id = basic.init()
+
+     ecco.interactiveTokensAndFactorSparklines(viz_id, {},
+     {{
+         'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
+         }}
+     }})
+ }}, function (err) {{
+     console.log(err);
+ }})"""
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_ecco_model():
+     model_config = {
+         "embedding": "embeddings.word_embeddings",
+         "type": "mlm",
+         "activations": [r"ffn\.lin1"],
+         "token_prefix": "",
+         "partial_token_prefix": "##",
+     }
+     return ecco.from_pretrained(
+         "elastic/distilbert-base-uncased-finetuned-conll03-english",
+         model_config=model_config,
+         activations=True,
+     )
+
+
+ class AttentionPage(Page):
+     name = "Activations"
+     icon = "activity"
+
+     def get_widget_defaults(self):
+         return {
+             "act_n_components": 8,
+             "act_default_text": """Now I ask you: \n what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!""",
+             "act_from_layer": 0,
+             "act_to_layer": 6,
+         }
+
+     def render(self, context: Context):
+         st.title(self.name)
+
+         with st.expander("ℹ️", expanded=True):
+             st.write(
+                 "A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
+             )
+
+         lm = load_ecco_model()
+
+         col1, _, col2 = st.columns([1.5, 0.5, 4])
+         with col1:
+             st.subheader("Settings")
+             n_components = st.slider(
+                 "#components",
+                 key="act_n_components",
+                 min_value=2,
+                 max_value=10,
+                 step=1,
+             )
+             from_layer = (
+                 st.slider(
+                     "from layer",
+                     key="act_from_layer",
+                     value=0,
+                     min_value=0,
+                     max_value=len(lm.model.transformer.layer) - 1,
+                     step=1,
+                 )
+                 or None
+             )
+             to_layer = (
+                 st.slider(
+                     "to layer",
+                     key="act_to_layer",
+                     value=0,
+                     min_value=0,
+                     max_value=len(lm.model.transformer.layer),
+                     step=1,
+                 )
+                 or None
+             )
+         with col2:
+             st.subheader("–")
+             text = st.text_area("Text", key="act_default_text")
+
+         inputs = lm.tokenizer([text], return_tensors="pt")
+         output = lm(inputs)
+         nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
+         data = nmf.explore(returnData=True)
+         JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
+             const viz_id = basic.init()
+
+             ecco.interactiveTokensAndFactorSparklines(viz_id, {data},
+             {{
+                 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
+                 }}
+             }})
+         }}, function (err) {{
+             console.log(err);
+         }})</script>"""
+         html(SETUP_HTML + JS_TEMPLATE, height=800, scrolling=True)
subpages/debug.py ADDED
@@ -0,0 +1,27 @@
+ import streamlit as st
+ from pip._internal.operations import freeze
+
+ from subpages.page import Context, Page
+
+
+ class DebugPage(Page):
+     name = "Debug"
+     icon = "bug"
+
+     def render(self, context: Context):
+         st.title(self.name)
+         # with st.expander("💡", expanded=True):
+         #     st.write("Some debug info.")
+
+         st.subheader("Installed Packages")
+         # get output of pip freeze from system
+         with st.expander("pip freeze"):
+             st.code("\n".join(freeze.freeze()))
+
+         st.subheader("Streamlit Session State")
+         st.json(st.session_state)
+         st.subheader("Tokenizer")
+         st.code(context.tokenizer)
+         st.subheader("Model")
+         st.code(context.model.config)
+         st.code(context.model)
subpages/embeddings.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.express as px
3
+ import plotly.graph_objects as go
4
+ import streamlit as st
5
+
6
+ from subpages.page import Context, Page
7
+
8
+
9
+ @st.cache
10
+ def reduce_dim_svd(X, n_iter, random_state=42):
11
+ from sklearn.decomposition import TruncatedSVD
12
+
13
+ svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
14
+ return svd.fit_transform(X)
15
+
16
+
17
+ @st.cache
18
+ def reduce_dim_pca(X, random_state=42):
19
+ from sklearn.decomposition import PCA
20
+
21
+ return PCA(n_components=2, random_state=random_state).fit_transform(X)
22
+
23
+
24
+ @st.cache
25
+ def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
26
+ from umap import UMAP
27
+
28
+ return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
29
+
30
+
31
+ class EmbeddingsPage(Page):
32
+ name = "Embeddings"
33
+ icon = "grid-3x3"
34
+
35
+ def get_widget_defaults(self):
36
+ return {
37
+ "n_tokens": 1_000,
38
+ "svd_n_iter": 5,
39
+ "svd_random_state": 42,
40
+ "umap_n_neighbors": 15,
41
+ "umap_metric": "euclidean",
42
+ "umap_min_dist": 0.1,
43
+ }
44
+
45
+ def render(self, context: Context):
46
+ st.title("Embeddings")
47
+
48
+ with st.expander("πŸ’‘", expanded=True):
49
+ st.write(
50
+ "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
51
+ )
52
+
53
+ col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
54
+ df = context.df_tokens_merged.copy()
55
+ dim_algo = "SVD"
56
+ n_tokens = 100
57
+
58
+ with col1:
59
+ st.subheader("Settings")
60
+ n_tokens = st.slider(
61
+ "#tokens",
62
+ key="n_tokens",
63
+ min_value=100,
64
+ max_value=len(df["tokens"].unique()),
65
+ step=100,
66
+ )
67
+
68
+ dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
69
+ if dim_algo == "SVD":
70
+ svd_n_iter = st.slider(
71
+ "#iterations",
72
+ key="svd_n_iter",
73
+ min_value=1,
74
+ max_value=10,
75
+ step=1,
76
+ )
77
+ elif dim_algo == "UMAP":
78
+ umap_n_neighbors = st.slider(
79
+ "#neighbors",
80
+ key="umap_n_neighbors",
81
+ min_value=2,
82
+ max_value=100,
83
+ step=1,
84
+ )
85
+ umap_min_dist = st.number_input(
86
+ "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
87
+ )
88
+ umap_metric = st.selectbox(
89
+ "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
90
+ )
91
+ else:
92
+ pass
93
+
94
+ with col2:
95
+ sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
96
+
97
+ X = np.array(df["hidden_states"].tolist())
98
+ transformed_hidden_states = None
99
+ if dim_algo == "SVD":
100
+ transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
101
+ elif dim_algo == "PCA":
102
+ transformed_hidden_states = reduce_dim_pca(X)
103
+ elif dim_algo == "UMAP":
104
+ transformed_hidden_states = reduce_dim_umap(
105
+ X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
106
+ )
107
+
108
+ assert isinstance(transformed_hidden_states, np.ndarray)
109
+ df["x"] = transformed_hidden_states[:, 0]
110
+ df["y"] = transformed_hidden_states[:, 1]
111
+ df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
112
+ df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
113
+ df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
114
+ df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
115
+ df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
116
+ df["mislabeled"] = df["labels"] != df["preds"]
117
+
118
+ subset = df[:n_tokens]
119
+ mislabeled_examples_trace = go.Scatter(
120
+ x=subset[subset["mislabeled"]]["x"],
121
+ y=subset[subset["mislabeled"]]["y"],
122
+ mode="markers",
123
+ marker=dict(
124
+ size=6,
125
+ color="rgba(0,0,0,0)",
126
+ line=dict(width=1),
127
+ ),
128
+ hoverinfo="skip",
129
+ )
130
+
131
+ st.subheader("Projection Results")
132
+
133
+ fig = px.scatter(
134
+ subset,
135
+ x="x",
136
+ y="y",
137
+ color="labels",
138
+ hover_data=["sent0", "sent1", "sent2", "sent3", "sent4"],
139
+ hover_name="tokens",
140
+ title="Colored by label",
141
+ )
142
+ fig.add_trace(mislabeled_examples_trace)
143
+ st.plotly_chart(fig)
144
+
145
+ fig = px.scatter(
146
+ subset,
147
+ x="x",
148
+ y="y",
149
+ color="preds",
150
+ hover_data=["sent0", "sent1", "sent2", "sent3", "sent4"],
151
+ hover_name="tokens",
152
+ title="Colored by prediction",
153
+ )
154
+ fig.add_trace(mislabeled_examples_trace)
155
+ st.plotly_chart(fig)
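Note (illustrative sketch, not part of this commit): the page above projects each token's hidden state onto two dimensions through helpers such as `reduce_dim_pca`, whose bodies are not shown in this diff. The snippet below is one way such a projection could look, assuming plain PCA computed via NumPy's SVD. The function name `project_to_2d` and the toy data are hypothetical.

```python
import numpy as np


def project_to_2d(hidden_states: np.ndarray) -> np.ndarray:
    """Project row vectors onto their first two principal components."""
    # Center each feature so the SVD captures variance rather than the mean.
    centered = hidden_states - hidden_states.mean(axis=0)
    # Rows of vt are the principal directions, ordered by singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


# Toy stand-in for per-token hidden states: 6 tokens, 4 dimensions.
rng = np.random.default_rng(42)
X = rng.normal(size=(6, 4))
coords = project_to_2d(X)  # column 0 -> "x", column 1 -> "y" in the scatter plot
```

The two output columns play the role of `df["x"]` and `df["y"]` in the page; the actual helpers may use scikit-learn or UMAP instead.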
subpages/emoji-en-US.json ADDED
The diff for this file is too large to render. See raw diff
 
subpages/faiss.py ADDED
@@ -0,0 +1,58 @@
1
+ import streamlit as st
2
+ from datasets import Dataset
3
+
4
+ from subpages.page import Context, Page # type: ignore
5
+ from utils import device, explode_df, htmlify_labeled_example, tag_text
6
+
7
+
8
+ class FaissPage(Page):
9
+ name = "Bla"
10
+ icon = "x-octagon"
11
+
12
+ def render(self, context: Context):
13
+ dd = Dataset.from_pandas(context.df_tokens_merged, preserve_index=False) # type: ignore
14
+
15
+ dd.add_faiss_index(column="hidden_states", index_name="token_index")
16
+ token_id, text = (
17
+ 6,
18
+ "Die Wissenschaft ist eine wichtige Grundlage für die Entwicklung von neuen Technologien.",
19
+ )
20
+ token_id, text = (
21
+ 15,
22
+ "Außer der unbewussten Beeinflussung eines Resultats gibt es auch noch andere Motive die das reine strahlende Licht der Wissenschaft etwas zu trüben vermögen.",
23
+ )
24
+ token_id, text = (
25
+ 3,
26
+ "Mit mehr Instrumenten einer besseren präziseren Datenbasis ist auch ein viel besseres smarteres Risikomanagement möglich.",
27
+ )
28
+ token_id, text = (
29
+ 7,
30
+ "Es gilt die akademische Viertelstunde das heißt Beginn ist fünfzehn Minuten später.",
31
+ )
32
+ token_id, text = (
33
+ 7,
34
+ "Damit einher geht übrigens auch dass Marcella Collocinis Tochter keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat.",
35
+ )
36
+ token_id, text = (
37
+ 16,
38
+ "After Steve Jobs met with Bill Gates of Microsoft back in 1993, they went to Cupertino and made the deal.",
39
+ )
40
+
41
+ tagged = tag_text(text, context.tokenizer, context.model, device)
42
+ hidden_states = tagged["hidden_states"]
43
+ # tagged.drop("hidden_states", inplace=True, axis=1)
44
+ # hidden_states_vec = svd.transform([hidden_states[token_id]])[0].astype(np.float32)
45
+ hidden_states_vec = hidden_states[token_id]
46
+ tagged = tagged.astype(str)
47
+ tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
48
+ tagged["check"] = tagged["probs"].apply(
49
+ lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
50
+ )
51
+ st.dataframe(tagged.drop("hidden_states", axis=1).T)
52
+ results = dd.get_nearest_examples("token_index", hidden_states_vec, k=10)
53
+ for i, (dist, idx, token) in enumerate(
54
+ zip(results.scores, results.examples["ids"], results.examples["tokens"])
55
+ ):
56
+ st.code(f"{dist:.3f} {token}")
57
+ sample = context.df_tokens_merged.query(f"ids == '{idx}'")
58
+ st.write(f"[{i};{idx}] " + htmlify_labeled_example(sample), unsafe_allow_html=True)
subpages/find_duplicates.py ADDED
@@ -0,0 +1,51 @@
1
+ import streamlit as st
2
+ from sentence_transformers.util import cos_sim
3
+
4
+ from subpages.page import Context, Page
5
+
6
+
7
+ @st.cache()
8
+ def get_sims(texts: list[str], sentence_encoder):
9
+ embeddings = sentence_encoder.encode(texts, batch_size=8, convert_to_numpy=True)
10
+ return cos_sim(embeddings, embeddings)
11
+
12
+
13
+ class FindDuplicatesPage(Page):
14
+ name = "Find Duplicates"
15
+ icon = "fingerprint"
16
+
17
+ def get_widget_defaults(self):
18
+ return {
19
+ "cutoff": 0.95,
20
+ }
21
+
22
+ def render(self, context: Context):
23
+ st.title("Find Duplicates")
24
+ with st.expander("💡", expanded=True):
25
+ st.write("Find potential duplicates in the data using cosine similarity.")
26
+
27
+ cutoff = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, key="cutoff")
28
+ # split.add_faiss_index(column="embeddings", index_name="sent_index")
29
+ # st.write("Index is ready")
30
+ # sentence_encoder.encode(["hello world"], batch_size=8)
31
+ # st.write(split["tokens"][0])
32
+ texts = [" ".join(ts) for ts in context.split["tokens"]]
33
+ sims = get_sims(texts, context.sentence_encoder)
34
+
35
+ candidates = []
36
+ for i in range(len(sims)):
37
+ for j in range(i + 1, len(sims)):
38
+ if sims[i][j] >= cutoff:
39
+ candidates.append((float(sims[i][j]), i, j))
40
+ candidates.sort(reverse=True) # most similar pairs first, so the top 100 shown are the likeliest duplicates
41
+
42
+ for (sim, i, j) in candidates[:100]:
43
+ st.markdown(f"**Possible duplicate ({i}, {j}, sim: {sim:.3f}):**")
44
+ st.markdown("* " + " ".join(context.split["tokens"][i]))
45
+ st.markdown("* " + " ".join(context.split["tokens"][j]))
46
+
47
+ # st.write("queries")
48
+ # results = split.get_nearest_examples("sent_index", np.array(split["embeddings"][0], dtype=np.float32), k=2)
49
+ # results = split.get_nearest_examples_batch("sent_index", queries, k=2)
50
+ # st.write(results.total_examples[0]["id"][1])
51
+ # st.write(results.total_examples[0])
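Note (illustrative sketch, not part of this commit): the page above scans all pairs of sentence embeddings and keeps those whose cosine similarity reaches the cutoff. The dependency-free version below shows the same idea on toy vectors. The helper names `cosine` and `duplicate_candidates` are hypothetical, and listing the most similar pairs first is an assumption of this sketch.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def duplicate_candidates(embeddings, cutoff):
    """Return (similarity, i, j) for every pair i < j at or above the cutoff."""
    pairs = []
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim = cosine(embeddings[i], embeddings[j])
            if sim >= cutoff:
                pairs.append((sim, i, j))
    # Highest similarity first, so the likeliest duplicates are listed on top.
    pairs.sort(reverse=True)
    return pairs


# Toy embeddings: vectors 0 and 1 point the same way, 3 is close, 2 is orthogonal.
vectors = [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [1.0, 0.1]]
candidates = duplicate_candidates(vectors, cutoff=0.95)
```

The pairwise scan is quadratic in the number of sentences, which is acceptable for the sampled split sizes used in the app but would need an index (for example FAISS) for larger datasets.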
subpages/home.py ADDED
@@ -0,0 +1,149 @@
1
+ import json
2
+ import random
3
+ from typing import Optional
4
+
5
+ import streamlit as st
6
+ from pandas import wide_to_long
7
+
8
+ from data import get_data
9
+ from subpages.page import Context, Page
10
+ from utils import color_map_color
11
+
12
+ _SENTENCE_ENCODER_MODEL = (
13
+ "sentence-transformers/all-MiniLM-L6-v2",
14
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
15
+ )[0]
16
+ _MODEL_NAME = (
17
+ "elastic/distilbert-base-uncased-finetuned-conll03-english",
18
+ "gagan3012/bert-tiny-finetuned-ner",
19
+ "socialmediaie/bertweet-base_wnut17_ner",
20
+ "sberbank-ai/bert-base-NER-reptile-5-datasets",
21
+ "aseifert/comma-xlm-roberta-base",
22
+ "dslim/bert-base-NER",
23
+ "aseifert/distilbert-base-german-cased-comma-derstandard",
24
+ )[0]
25
+ _DATASET_NAME = (
26
+ "conll2003",
27
+ "wnut_17",
28
+ "aseifert/comma",
29
+ )[0]
30
+ _CONFIG_NAME = (
31
+ "conll2003",
32
+ "wnut_17",
33
+ "seifertverlag",
34
+ )[0]
35
+
36
+
37
+ class HomePage(Page):
38
+ name = "Home / Setup"
39
+ icon = "house"
40
+
41
+ def get_widget_defaults(self):
42
+ return {
43
+ "encoder_model_name": _SENTENCE_ENCODER_MODEL,
44
+ "model_name": _MODEL_NAME,
45
+ "ds_name": _DATASET_NAME,
46
+ "ds_split_name": "validation",
47
+ "ds_config_name": _CONFIG_NAME,
48
+ "split_sample_size": 512,
49
+ }
50
+
51
+ def render(self, context: Optional[Context] = None):
52
+ st.title("ExplaiNER")
53
+
54
+ with st.expander("💡", expanded=True):
55
+ st.write(
56
+ "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further improvements."
57
+ )
58
+
59
+ col1, _, col2a, col2b = st.columns([1, 0.05, 0.15, 0.15])
60
+
61
+ with col1:
62
+ random_form_key = f"settings-{random.randint(0, 100000)}"
63
+ # FIXME: for some reason I'm getting the following error if I don't randomize the key:
64
+ """
65
+ 2022-05-05 20:37:16.507 Traceback (most recent call last):
66
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/scriptrunner/script_runner.py", line 443, in _run_script
67
+ exec(code, module.__dict__)
68
+ File "/Users/zoro/code/error-analysis/main.py", line 162, in <module>
69
+ main()
70
+ File "/Users/zoro/code/error-analysis/main.py", line 102, in main
71
+ show_setup()
72
+ File "/Users/zoro/code/error-analysis/section/setup.py", line 68, in show_setup
73
+ st.form_submit_button("Load Model & Data")
74
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 240, in form_submit_button
75
+ return self._form_submit_button(
76
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/form.py", line 260, in _form_submit_button
77
+ return self.dg._button(
78
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/button.py", line 304, in _button
79
+ check_session_state_rules(default_value=None, key=key, writes_allowed=False)
80
+ File "/Users/zoro/mambaforge/lib/python3.9/site-packages/streamlit/elements/utils.py", line 74, in check_session_state_rules
81
+ raise StreamlitAPIException(
82
+ streamlit.errors.StreamlitAPIException: Values for st.button, st.download_button, st.file_uploader, and st.form cannot be set using st.session_state.
83
+ """
84
+ with st.form(key=random_form_key):
85
+ st.subheader("Model & Data Selection")
86
+ st.text_input(
87
+ label="NER Model:",
88
+ key="model_name",
89
+ help="Path or name of the model to use",
90
+ )
91
+ st.text_input(
92
+ label="Encoder Model:",
93
+ key="encoder_model_name",
94
+ help="Path or name of the encoder to use",
95
+ )
96
+ ds_name = st.text_input(
97
+ label="Dataset:",
98
+ key="ds_name",
99
+ help="Path or name of the dataset to use",
100
+ )
101
+ ds_config_name = st.text_input(
102
+ label="Config (optional):",
103
+ key="ds_config_name",
104
+ )
105
+ ds_split_name = st.selectbox(
106
+ label="Split:",
107
+ options=["train", "validation", "test"],
108
+ key="ds_split_name",
109
+ )
110
+ split_sample_size = st.number_input(
111
+ "Sample size:",
112
+ step=16,
113
+ key="split_sample_size",
114
+ help="Sample size for the split, speeds up processing inside streamlit",
115
+ )
116
+ # breakpoint()
117
+ # st.form_submit_button("Submit")
118
+ st.form_submit_button("Load Model & Data")
119
+
120
+ split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
121
+ labels = list(
122
+ set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
123
+ )
124
+
125
+ with col2a:
126
+ st.subheader("Classes")
127
+ st.write("**Color**")
128
+ colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
129
+ for label in labels:
130
+ if f"color_{label}" not in st.session_state:
131
+ st.session_state[f"color_{label}"] = colors[label]
132
+ st.color_picker(label, key=f"color_{label}")
133
+ with col2b:
134
+ st.subheader("\u2014")
135
+ st.write("**Icon**")
136
+ emojis = list(json.load(open("subpages/emoji-en-US.json")).keys())
137
+ for label in labels:
138
+ if f"icon_{label}" not in st.session_state:
139
+ st.session_state[f"icon_{label}"] = "🤗" # labels[label]
140
+ st.selectbox(label, key=f"icon_{label}", options=emojis)
141
+
142
+ # if st.button("Reset to defaults"):
143
+ # st.session_state.update(**get_home_page_defaults())
144
+ # # time.sleep 2 secs
145
+ # import time
146
+ # time.sleep(1)
147
+
148
+ # # st.legacy_caching.clear_cache()
149
+ # st.experimental_rerun()
subpages/inspect.py ADDED
@@ -0,0 +1,38 @@
1
+ import streamlit as st
2
+
3
+ from subpages.page import Context, Page
4
+ from utils import aggrid_interactive_table, colorize_classes
5
+
6
+
7
+ class InspectPage(Page):
8
+ name = "Inspect"
9
+ icon = "search"
10
+
11
+ def render(self, context: Context):
12
+ st.title(self.name)
13
+ with st.expander("💡", expanded=True):
14
+ st.write("Inspect your whole dataset, either unfiltered or by id.")
15
+
16
+ df = context.df_tokens
17
+ cols = (
18
+ "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
19
+ )
20
+ if "token_type_ids" not in df.columns:
21
+ cols.remove("token_type_ids")
22
+ df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
23
+
24
+ if st.checkbox("Filter by id", value=True):
25
+ ids = list(sorted(map(int, df.ids.unique())))
26
+ next_id = st.session_state.get("next_id", 0)
27
+
28
+ example_id = st.selectbox("Select an example", ids, index=next_id)
29
+ df = df[df.ids == str(example_id)][1:-1]
30
+ # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses")) # type: ignore
31
+ st.dataframe(colorize_classes(df.round(3).astype(str)))
32
+
33
+ if st.button("Next example"):
34
+ st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
35
+ if st.button("Previous example"):
36
+ st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
37
+ else:
38
+ aggrid_interactive_table(df.round(3))
subpages/losses.py ADDED
@@ -0,0 +1,63 @@
1
+ import streamlit as st
2
+
3
+ from subpages.page import Context, Page
4
+ from utils import AgGrid, aggrid_interactive_table
5
+
6
+
7
+ @st.cache
8
+ def get_loss_by_token(df_tokens):
9
+ return (
10
+ df_tokens.groupby("tokens")[["losses"]]
11
+ .agg(["count", "mean", "median", "sum"])
12
+ .droplevel(level=0, axis=1) # Get rid of multi-level columns
13
+ .sort_values(by="sum", ascending=False)
14
+ .reset_index()
15
+ )
16
+
17
+
18
+ @st.cache
19
+ def get_loss_by_label(df_tokens):
20
+ return (
21
+ df_tokens.groupby("labels")[["losses"]]
22
+ .agg(["count", "mean", "median", "sum"])
23
+ .droplevel(level=0, axis=1)
24
+ .sort_values(by="mean", ascending=False)
25
+ .reset_index()
26
+ )
27
+
28
+
29
+ class LossesPage(Page):
30
+ name = "Loss by Token/Label"
31
+ icon = "sort-alpha-down"
32
+
33
+ def render(self, context: Context):
34
+ st.title(self.name)
35
+ with st.expander("💡", expanded=True):
36
+ st.write("Show count, mean and median loss per token and label.")
37
+
38
+ col1, _, col2 = st.columns([8, 1, 6])
39
+
40
+ with col1:
41
+ st.subheader("💬 Loss by Token")
42
+
43
+ st.session_state["_merge_tokens"] = st.checkbox(
44
+ "Merge tokens", value=True, key="merge_tokens"
45
+ )
46
+ loss_by_token = (
47
+ get_loss_by_token(context.df_tokens_merged)
48
+ if st.session_state["merge_tokens"]
49
+ else get_loss_by_token(context.df_tokens_cleaned)
50
+ )
51
+ aggrid_interactive_table(loss_by_token.round(3))
52
+ # st.subheader("🏷️ Loss by Label")
53
+ # loss_by_label = get_loss_by_label(df_tokens_cleaned)
54
+ # st.dataframe(loss_by_label)
55
+
56
+ st.write(
57
+ "_Attention: This statistic disregards that tokens have contextual representations._"
58
+ )
59
+
60
+ with col2:
61
+ st.subheader("🏷️ Loss by Label")
62
+ loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
63
+ AgGrid(loss_by_label.round(3), height=200)
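Note (illustrative sketch, not part of this commit): `get_loss_by_token` above groups per-token losses and reports count, mean, median and sum, sorted by total loss. The standard-library version below reproduces that aggregation without pandas so the logic is easy to follow. The function name `loss_by_token` and the sample values are hypothetical.

```python
from collections import defaultdict
from statistics import mean, median


def loss_by_token(tokens, losses):
    """Aggregate losses per token string and sort by summed loss, descending."""
    grouped = defaultdict(list)
    for token, loss in zip(tokens, losses):
        grouped[token].append(loss)
    rows = [
        {
            "token": token,
            "count": len(values),
            "mean": mean(values),
            "median": median(values),
            "sum": sum(values),
        }
        for token, values in grouped.items()
    ]
    rows.sort(key=lambda row: row["sum"], reverse=True)
    return rows


rows = loss_by_token(
    ["the", "Paris", "the", "Paris", "in"],
    [0.1, 2.0, 0.3, 1.0, 0.2],
)
```

As the page itself warns, grouping by surface form ignores context: the same token string can be easy in one sentence and hard in another.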
subpages/lossy_samples.py ADDED
@@ -0,0 +1,99 @@
1
+ import pandas as pd
2
+ import streamlit as st
3
+
4
+ from subpages.page import Context, Page
5
+ from utils import colorize_classes, get_bg_color, get_fg_color, htmlify_labeled_example
6
+
7
+
8
+ class LossySamplesPage(Page):
9
+ name = "Samples by Loss"
10
+ icon = "sort-numeric-down-alt"
11
+
12
+ def get_widget_defaults(self):
13
+ return {
14
+ "skip_correct": True,
15
+ "samples_by_loss_show_df": True,
16
+ }
17
+
18
+ def render(self, context: Context):
19
+ st.title(self.name)
20
+ with st.expander("💡", expanded=True):
21
+ st.write("Show every example sorted by loss (descending) for close inspection.")
22
+
23
+ st.subheader("💥 Samples ⬇loss")
24
+ skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
25
+ show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
26
+
27
+ st.write(
28
+ """<style>
29
+ thead {
30
+ display: none;
31
+ }
32
+ td {
33
+ white-space: nowrap;
34
+ padding: 0 5px !important;
35
+ }
36
+ </style>""",
37
+ unsafe_allow_html=True,
38
+ )
39
+
40
+ top_indices = (
41
+ context.df.sort_values(by="total_loss", ascending=False)
42
+ .query("total_loss > 0.5")
43
+ .index
44
+ )
45
+
46
+ cnt = 0
47
+ for idx in top_indices:
48
+ sample = context.df_tokens_merged.loc[idx]
49
+
50
+ if isinstance(sample, pd.Series):
51
+ continue
52
+
53
+ if skip_correct and sum(sample.labels != sample.preds) == 0:
54
+ continue
55
+
56
+ if show_df:
57
+
58
+ def colorize_col(col):
59
+ if col.name == "labels" or col.name == "preds":
60
+ bgs = []
61
+ fgs = []
62
+ ops = []
63
+ for v in col.values:
64
+ bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
65
+ fgs.append(get_fg_color(bgs[-1]))
66
+ ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
67
+ return [
68
+ f"background-color: {bg}; color: {fg}; opacity: {op};"
69
+ for bg, fg, op in zip(bgs, fgs, ops)
70
+ ]
71
+ return [""] * len(col)
72
+
73
+ df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
74
+ losses_slice = pd.IndexSlice["losses", :]
75
+ # x = df.T.astype(str)
76
+ # st.dataframe(x)
77
+ # st.dataframe(x.loc[losses_slice])
78
+ styler = (
79
+ df.T.style.apply(colorize_col, axis=1)
80
+ .bar(subset=losses_slice, axis=1)
81
+ .format(precision=3)
82
+ )
83
+ # styler.data = styler.data.astype(str)
84
+ st.write(styler.to_html(), unsafe_allow_html=True)
85
+ # st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses')) # type: ignore
86
+ # st.write(
87
+ # colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
88
+ # )
89
+
90
+ col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
91
+
92
+ cnt += 1
93
+ counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
94
+ loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
95
+ col1.write(f"{counter}{loss}", unsafe_allow_html=True)
96
+ col1.write("")
97
+
98
+ col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
99
+ # st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)
subpages/metrics.py ADDED
@@ -0,0 +1,98 @@
1
+ import re
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ import streamlit as st
8
+ from seqeval.metrics import classification_report
9
+ from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
10
+
11
+ from subpages.page import Context, Page
12
+
13
+
14
+ def _get_evaluation(df):
15
+ y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
16
+ y_pred = df.apply(
17
+ lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
18
+ axis=1,
19
+ )
20
+ report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3) # type: ignore
21
+ return report.replace(
22
+ "precision recall f1-score support",
23
+ "=" * 12 + " precision recall f1-score support",
24
+ )
25
+
26
+
27
+ def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
28
+ cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
29
+ if zero_diagonal:
30
+ np.fill_diagonal(cm, 0)
31
+
32
+ # st.write(plt.rcParams["font.size"])
33
+ # plt.rcParams.update({'font.size': 10.0})
34
+ fig, ax = plt.subplots(figsize=(10, 10))
35
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
36
+ fmt = "d" if normalize is None else ".3f"
37
+ disp.plot(
38
+ cmap="Blues",
39
+ include_values=True,
40
+ xticks_rotation="vertical",
41
+ values_format=fmt,
42
+ ax=ax,
43
+ colorbar=False,
44
+ )
45
+ return fig
46
+
47
+
48
+ class MetricsPage(Page):
49
+ name = "Metrics"
50
+ icon = "graph-up-arrow"
51
+
52
+ def get_widget_defaults(self):
53
+ return {
54
+ "normalize": True,
55
+ "zero_diagonal": False,
56
+ }
57
+
58
+ def render(self, context: Context):
59
+ st.title(self.name)
60
+ with st.expander("💡", expanded=True):
61
+ st.write(
62
+ "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
63
+ )
64
+
65
+ eval_results = _get_evaluation(context.df)
66
+ if len(eval_results.splitlines()) < 8:
67
+ col1, _, col2 = st.columns([8, 1, 1])
68
+ else:
69
+ col1 = col2 = st
70
+
71
+ col1.subheader("🎯 Evaluation Results")
72
+ col1.code(eval_results)
73
+
74
+ results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
75
+ data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
76
+ df = pd.DataFrame(data, columns="class support f1".split())
77
+ fig = px.scatter(
78
+ df,
79
+ x="support",
80
+ y="f1",
81
+ range_y=(0, 1.05),
82
+ color="class",
83
+ )
84
+ # fig.update_layout(title_text="asdf", title_yanchor="bottom")
85
+ col1.plotly_chart(fig)
86
+
87
+ col2.subheader("🔠 Confusion Matrix")
88
+ normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
89
+ zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
90
+ col2.pyplot(
91
+ plot_confusion_matrix(
92
+ y_true=context.df_tokens_cleaned["labels"],
93
+ y_preds=context.df_tokens_cleaned["preds"],
94
+ labels=context.labels,
95
+ normalize=normalize,
96
+ zero_diagonal=zero_diagonal,
97
+ ),
98
+ )
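Note (illustrative sketch, not part of this commit): the metrics page offers two toggles for the confusion matrix, row normalization and zeroing the diagonal so that only errors remain. The pure-Python version below shows what those two options do to the counts. It is a simplified stand-in for the scikit-learn call used above, and the function name and toy labels are hypothetical.

```python
def confusion_matrix(y_true, y_pred, labels, normalize=False, zero_diagonal=False):
    """Rows are true classes, columns are predicted classes."""
    index = {label: k for k, label in enumerate(labels)}
    cm = [[0.0] * len(labels) for _ in labels]
    for true, pred in zip(y_true, y_pred):
        cm[index[true]][index[pred]] += 1
    if normalize:
        # Divide each row by its total so every true class sums to 1.
        for row in cm:
            total = sum(row)
            if total:
                row[:] = [value / total for value in row]
    if zero_diagonal:
        # Drop correct predictions, leaving only the confusions.
        for k in range(len(labels)):
            cm[k][k] = 0.0
    return cm


labels = ["O", "PER", "LOC"]
y_true = ["PER", "PER", "LOC", "O", "O", "O"]
y_pred = ["PER", "LOC", "LOC", "O", "PER", "O"]

errors = confusion_matrix(y_true, y_pred, labels, zero_diagonal=True)
normalized = confusion_matrix(y_true, y_pred, labels, normalize=True)
```

With the diagonal zeroed, raw counts are usually easier to read than normalized values, which matches the hint in the page's help text.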
subpages/misclassified.py ADDED
@@ -0,0 +1,81 @@
1
+ from collections import defaultdict
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ from sklearn.metrics import confusion_matrix
6
+
7
+ from subpages.page import Context, Page
8
+ from utils import htmlify_labeled_example
9
+
10
+
11
+ class MisclassifiedPage(Page):
12
+ name = "Misclassified"
13
+ icon = "x-octagon"
14
+
15
+ def render(self, context: Context):
16
+ st.title(self.name)
17
+ with st.expander("💡", expanded=True):
18
+ st.write(
19
+ "This page contains all misclassified examples and allows filtering by specific error types."
20
+ )
21
+
22
+ misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
23
+ misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
24
+ cm = confusion_matrix(
25
+ misclassified_samples.labels,
26
+ misclassified_samples.preds,
27
+ labels=context.labels,
28
+ )
29
+
30
+ # st.pyplot(
31
+ # plot_confusion_matrix(
32
+ # y_preds=misclassified_samples["preds"],
33
+ # y_true=misclassified_samples["labels"],
34
+ # labels=labels,
35
+ # normalize=None,
36
+ # zero_diagonal=True,
37
+ # ),
38
+ # )
39
+ df = pd.DataFrame(cm, index=context.labels, columns=context.labels).astype(str)
40
+ import numpy as np
41
+
42
+ np.fill_diagonal(df.values, "")
43
+ st.dataframe(df.applymap(lambda x: x if x != "0" else ""))
44
+ # import matplotlib.pyplot as plt
45
+ # st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html())
46
+ # selection = aggrid_interactive_table(df)
47
+
48
+ # st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
49
+
50
+ confusions = defaultdict(int)
51
+ for i, row in enumerate(cm):
52
+ for j, _ in enumerate(row):
53
+ if i == j or cm[i][j] == 0:
54
+ continue
55
+ confusions[(context.labels[i], context.labels[j])] += cm[i][j]
56
+
57
+ def format_func(item):
58
+ return (
59
+ f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
60
+ )
61
+
62
+ conf = st.radio(
63
+ "Filter by Class Confusion",
64
+ options=list(zip(confusions.keys(), confusions.values())),
65
+ format_func=format_func,
66
+ )
67
+
68
+ # st.write(
69
+ # f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`"
70
+ # )
71
+
72
+ filtered_indices = misclassified_samples.query(
73
+ f"labels == '{conf[0][0]}' and preds == '{conf[0][1]}'"
74
+ ).index
75
+ for i, idx in enumerate(filtered_indices):
76
+ sample = context.df_tokens_merged.loc[idx]
77
+ st.write(
78
+ htmlify_labeled_example(sample),
79
+ unsafe_allow_html=True,
80
+ )
81
+ st.write("---")
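Note (illustrative sketch, not part of this commit): the page above tallies how often each true class is confused with each predicted class and then filters the examples for one selected pair. The snippet below shows that counting and filtering step with a `Counter` instead of a confusion matrix. The function name `count_confusions` and the toy tags are hypothetical.

```python
from collections import Counter


def count_confusions(labels, preds):
    """Count (true, predicted) pairs, keeping only the mismatches."""
    return Counter((t, p) for t, p in zip(labels, preds) if t != p)


labels = ["B-PER", "O", "B-LOC", "O", "B-PER"]
preds = ["B-PER", "B-LOC", "O", "B-LOC", "O"]
confusions = count_confusions(labels, preds)

# Keep only the positions matching one selected confusion, as the radio filter does.
selected = ("O", "B-LOC")
matching = [k for k, pair in enumerate(zip(labels, preds)) if pair == selected]
```

If no token is misclassified the counter is empty, so a caller would need to handle the case where there is nothing to select.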
subpages/page.py ADDED
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ import pandas as pd
5
+ from datasets import Dataset # type: ignore
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoModelForSequenceClassification # type: ignore
8
+ from transformers import AutoTokenizer # type: ignore
9
+
10
+
11
+ @dataclass
12
+ class Context:
13
+ model: AutoModelForSequenceClassification
14
+ tokenizer: AutoTokenizer
15
+ sentence_encoder: SentenceTransformer
16
+ tags: Any
17
+ df: pd.DataFrame
18
+ df_tokens: pd.DataFrame
19
+ df_tokens_cleaned: pd.DataFrame
20
+ df_tokens_merged: pd.DataFrame
21
+ split_sample_size: int
22
+ ds_name: str
23
+ ds_config_name: str
24
+ ds_split_name: str
25
+ split: Dataset
26
+ labels: list[str]
27
+
28
+
29
+ class Page:
30
+ name: str
31
+ icon: str
32
+
33
+ def get_widget_defaults(self):
34
+ return {}
35
+
36
+ def render(self, context):
37
+ ...
subpages/probing.py ADDED
@@ -0,0 +1,54 @@
1
+ import streamlit as st
2
+
3
+ from subpages.page import Context, Page
4
+ from utils import device, tag_text
5
+
6
+ _DEFAULT_SENTENCES = """
7
+ Damit hatte er auf ihr letztes , völlig schiefgelaufenes Geschäftsessen angespielt .
8
+ Damit einher geht übrigens auch , dass Marcella , Collocinis Tochter , keine wie auch immer geartete strafrechtliche Verfolgung zu befürchten hat .
9
+ Nach dem Bell ’ schen Theorem , einer Physik jenseits der Quanten , ist die Welt , die wir für real halten , nicht objektivierbar .
10
+ Dazu muss man wiederum wissen , dass die Aussagekraft von Tests , neben der Sensitivität und Spezifität , ganz entscheidend von der Vortestwahrscheinlichkeit abhängt .
11
+ Haben Sie sich schon eingelebt ? « erkundigte er sich .
12
+ Das Auto ein Totalschaden , mein Beifahrer ein weinender Jammerlappen .
13
+ Seltsam , wunderte sie sich , dass das Stück nach mehr als eineinhalb Jahrhunderten noch so gut in Schuss ist .
14
+ Oder auf den Strich gehen , Strümpfe stricken , Geld hamstern .
15
+ Und Allah ist Allumfassend Allwissend .
16
+ Und Pedro Moacir redete weiter : » Verzicht , Pater Antonio , Verzicht , zu großer Schmerz über Verzicht , Sehnsucht , die sich nicht erfüllt , die sich nicht erfüllen kann , das sind Qualen , die ein Verstummen nach sich ziehen können , oder Härte .
17
+ Mama-San ging mittlerweile fast ausnahmslos nur mit Wei an ihrer Seite aus dem Haus , kaum je mit einem der Mädchen und niemals allein.
18
+ """.strip()
19
+ _DEFAULT_SENTENCES = """
20
+ Elon Musk’s Berghain humiliation \u2014 I know the feeling
21
+ Musk was also seen at a local spot called Sisyphos celebrating entrepreneur Adeo Ressi's birthday, according to The Times.
22
+ """.strip()
23
+
24
+
25
+ class ProbingPage(Page):
26
+ name = "Probing"
27
+ icon = "fonts"
28
+
29
+ def get_widget_defaults(self):
30
+ return {"probing_textarea": _DEFAULT_SENTENCES}
31
+
32
+ def render(self, context: Context):
33
+ st.title("🔠 Interactive Probing")
34
+
35
+ with st.expander("💡", expanded=True):
36
+ st.write(
37
+ "A very direct and interactive way to test your model is by providing it with a list of text inputs and then inspecting the model outputs. The application features a multiline text field so the user can input multiple texts separated by newlines. For each text, the app will show a data frame containing the tokenized string, token predictions, probabilities and a visual indicator for low probability predictions -- these are the ones you should inspect first for prediction errors."
38
+ )
39
+
40
+ sentences = st.text_area("Sentences", height=200, key="probing_textarea")
41
+ if not sentences.strip():
42
+ return
43
+ sentences = [sentence.strip() for sentence in sentences.splitlines()]
44
+
45
+ for sent in sentences:
46
+ sent = sent.replace(",", "").replace("  ", " ") # drop commas, then collapse the resulting double spaces
47
+ with st.expander(sent):
48
+ tagged = tag_text(sent, context.tokenizer, context.model, device)
49
+ tagged = tagged.astype(str)
50
+ tagged["probs"] = tagged["probs"].apply(lambda x: x[:-2])
51
+ tagged["check"] = tagged["probs"].apply(
52
+ lambda x: "✅ ✅" if int(x) < 100 else "✅" if int(x) < 1000 else ""
53
+ )
54
+ st.dataframe(tagged.drop("hidden_states", axis=1).T)
subpages/random_samples.py ADDED
@@ -0,0 +1,51 @@
1
+ import pandas as pd
2
+ import streamlit as st
3
+
4
+ from subpages.page import Context, Page
5
+ from utils import htmlify_labeled_example
6
+
7
+
8
+ class RandomSamplesPage(Page):
9
+ name = "Random Samples"
10
+ icon = "shuffle"
11
+
12
+ def get_widget_defaults(self):
13
+ return {
14
+ "random_sample_size_min": 128,
15
+ }
16
+
17
+ def render(self, context: Context):
18
+ st.title("🎲 Random Samples")
19
+ with st.expander("💡", expanded=True):
20
+ st.write(
21
+ "Show random samples. Simple idea, but often it turns up some interesting things."
22
+ )
23
+
24
+ random_sample_size = st.number_input(
25
+ "Random sample size:",
26
+ value=min(st.session_state.random_sample_size_min, context.split_sample_size),
27
+ step=16,
28
+ key="random_sample_size",
29
+ )
30
+
31
+ if st.button("🎲 Resample"):
32
+ st.experimental_rerun()
33
+
34
+ random_indices = context.df.sample(int(random_sample_size)).index
35
+ samples = context.df_tokens_merged.loc[random_indices]
36
+
37
+
+         for i, idx in enumerate(random_indices):
+             sample = samples.loc[idx]
+
+             # .loc yields a Series when only one row matches idx; skip those.
+             if isinstance(sample, pd.Series):
+                 continue
+
+             col1, _, col2 = st.columns([0.08, 0.025, 0.8])
+
+             counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{i+1} | {idx}]</span>"
+             loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
+             col1.write(f"{counter}{loss}", unsafe_allow_html=True)
+             col1.write("")
+             st.write(sample.astype(str))
+             col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
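The page draws `random_sample_size` rows without replacement, and the widget's default value is capped at the split size. As a rough standard-library illustration of why such a cap matters (asking for more distinct items than exist raises a `ValueError`), the sketch below uses `random.sample`. The function `draw_indices` and its parameters are hypothetical names and are not part of the app.

```python
import random
from typing import List, Optional


def draw_indices(population_size: int, requested: int, seed: Optional[int] = None) -> List[int]:
    """Draw up to `requested` distinct row indices from range(population_size)."""
    rng = random.Random(seed)
    # Cap the request, mirroring min(default size, split size) in the widget.
    k = min(requested, population_size)
    return rng.sample(range(population_size), k)


indices = draw_indices(population_size=10, requested=128, seed=0)
print(len(indices), len(set(indices)))
```

Note that the cap in the page applies only to the default value; a user can still type a larger number into the input field, in which case sampling without replacement would fail.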
subpages/raw_data.py ADDED
@@ -0,0 +1,56 @@
+ import pandas as pd
+ import streamlit as st
+
+ from subpages.page import Context, Page
+ from utils import aggrid_interactive_table
+
+
+ @st.cache
+ def convert_df(df):
+     return df.to_csv().encode("utf-8")
+
+
+ class RawDataPage(Page):
+     name = "Raw data"
+     icon = "qr-code"
+
+     def render(self, context: Context):
+         st.title(self.name)
+         with st.expander("💡", expanded=True):
+             st.write("See the data as seen by your model.")
+
+         st.subheader("Dataset")
+         st.code(
+             f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
+         )
+
+         st.write("**Data after processing and inference**")
+
+         processed_df = (
+             context.df_tokens.drop("hidden_states", axis=1).drop("attention_mask", axis=1).round(3)
+         )
+         cols = (
+             "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
+         )
+         if "token_type_ids" not in processed_df.columns:
+             cols.remove("token_type_ids")
+         processed_df = processed_df[cols]
+         aggrid_interactive_table(processed_df)
+         processed_df_csv = convert_df(processed_df)
+         st.download_button(
+             "Download csv",
+             processed_df_csv,
+             "processed_data.csv",
+             "text/csv",
+         )
+
+         st.write("**Raw data (exploded by tokens)**")
+         raw_data_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
+         aggrid_interactive_table(raw_data_df)
+         raw_data_df_csv = convert_df(raw_data_df)
+         st.download_button(
+             "Download csv",
+             raw_data_df_csv,
+             "raw_data.csv",
+             "text/csv",
+         )
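"Exploded by tokens" means that each list-valued cell is unrolled so that every token gets its own row. The sketch below shows the idea with the standard library only, for a single record whose columns are equal-length lists. The name `explode_record` and the sample record are hypothetical. The pandas call used in the page, `pd.Series.explode` applied column-wise, additionally repeats the original index for each produced row and leaves scalar cells unchanged.

```python
from typing import Dict, Iterator


def explode_record(record: Dict[str, list]) -> Iterator[Dict[str, object]]:
    """Yield one row per token position from a record of equal-length lists."""
    lengths = {len(values) for values in record.values()}
    if len(lengths) > 1:
        raise ValueError("all columns must have the same number of tokens")
    n = lengths.pop() if lengths else 0
    for i in range(n):
        # Pick the i-th entry of every column to form a token-level row.
        yield {column: values[i] for column, values in record.items()}


example = {"tokens": ["New", "York"], "labels": ["B-LOC", "I-LOC"]}
rows = list(explode_record(example))
print(rows)
```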
utils.py ADDED
@@ -0,0 +1,229 @@
+ import matplotlib
+ import matplotlib.cm as cm
+ import pandas as pd
+ import streamlit as st
+ import tokenizers
+ import torch
+ import torch.nn.functional as F
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
+
+ tokenizer_hash_funcs = {
+     tokenizers.Tokenizer: lambda _: None,
+     tokenizers.AddedToken: lambda _: None,
+ }
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def aggrid_interactive_table(df: pd.DataFrame) -> dict:
+     """Creates an st-aggrid interactive table based on a dataframe.
+     Args:
+         df (pd.DataFrame): Source dataframe
+     Returns:
+         dict: The selected row
+     """
+     options = GridOptionsBuilder.from_dataframe(
+         df, enableRowGroup=True, enableValue=True, enablePivot=True
+     )
+
+     options.configure_side_bar()
+     # options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))
+
+     options.configure_selection("single")
+     selection = AgGrid(
+         df,
+         enable_enterprise_modules=True,
+         gridOptions=options.build(),
+         theme="light",
+         update_mode=GridUpdateMode.NO_UPDATE,
+         allow_unsafe_jscode=True,
+     )
+
+     return selection
+
+
+ def explode_df(df: pd.DataFrame) -> pd.DataFrame:
+     df_tokens = df.apply(pd.Series.explode)
+     if "losses" in df.columns:
+         df_tokens["losses"] = df_tokens["losses"].astype(float)
+     return df_tokens  # type: ignore
+
+
+ def align_sample(row: pd.Series):
+     """Use word_ids to align all lists in a sample."""
+
+     columns = row.axes[0].to_list()
+     indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
+
+     out = {}
+
+     tokens = []
+     for i, tok in enumerate(row.tokens):
+         if row.word_ids[i] == -1:
+             continue
+
+         if row.word_ids[i] != row.word_ids[i - 1]:
+             tokens.append(tok.lstrip("▁").lstrip("##").rstrip("@@"))
+         else:
+             tokens[-1] += tok.lstrip("▁").lstrip("##").rstrip("@@")
+     out["tokens"] = tokens
+
+     if "labels" in columns:
+         out["labels"] = [row.labels[i] for i in indices]
+
+     if "preds" in columns:
+         out["preds"] = [row.preds[i] for i in indices]
+
+     if "losses" in columns:
+         out["losses"] = [row.losses[i] for i in indices]
+
+     if "probs" in columns:
+         out["probs"] = [row.probs[i] for i in indices]
+
+     if "hidden_states" in columns:
+         out["hidden_states"] = [row.hidden_states[i] for i in indices]
+
+     if "ids" in columns:
+         out["ids"] = row.ids
+
+     assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
+
+     return out
+
+
+ @st.cache(
+     allow_output_mutation=True,
+     hash_funcs=tokenizer_hash_funcs,
+ )
+ def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
+     """Create an (exploded) DataFrame with the predicted labels and probabilities."""
+
+     tokens = tokenizer(text).tokens()
+     tokenized = tokenizer(text, return_tensors="pt")
+     word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
+     input_ids = tokenized.input_ids.to(device)
+     outputs = model(input_ids, output_hidden_states=True)
+     preds = torch.argmax(outputs.logits, dim=2)
+     preds = [model.config.id2label[p] for p in preds[0].cpu().numpy()]
+     hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
+     # hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)
+
+     # Floor of the reciprocal of the smallest class probability for each token.
+     probs = 1 // (
+         torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
+     )
+
+     df = pd.DataFrame(
+         [[tokens, word_ids, preds, probs, hidden_states]],
+         columns="tokens word_ids preds probs hidden_states".split(),
+     )
+     merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
+     return explode_df(merged_df).reset_index().drop(columns=["index"])
+
+
+ def get_bg_color(label):
+     return st.session_state[f"color_{label}"]
+
+
+ def get_fg_color(hex_color: str) -> str:
+     """Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/"""
+     r = int(hex_color[1:3], 16)
+     g = int(hex_color[3:5], 16)
+     b = int(hex_color[5:7], 16)
+     yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
+     return "black" if (yiq >= 128) else "white"
+
+
+ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
+     """Colorize the errors in the dataframe."""
+
+     def colorize_row(row):
+         return [
+             "background-color: "
+             + ("white" if (row["labels"] == "IGN" or (row["preds"] == row["labels"])) else "pink")
+             + ";"
+         ] * len(row)
+
+     def colorize_col(col):
+         if col.name == "labels" or col.name == "preds":
+             bgs = []
+             fgs = []
+             for v in col.values:
+                 bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
+                 fgs.append(get_fg_color(bgs[-1]))
+             return [f"background-color: {bg}; color: {fg};" for bg, fg in zip(bgs, fgs)]
+         return [""] * len(col)
+
+     df = df.reset_index().drop(columns=["index"]).T
+     return df  # .style.apply(colorize_col, axis=0)
+
+
+ def htmlify_labeled_example(example: pd.DataFrame) -> str:
+     html = []
+     classmap = {
+         "O": "O",
+         "PER": "🙎",
+         "person": "🙎",
+         "LOC": "🌎",
+         "location": "🌎",
+         "ORG": "🏤",
+         "corporation": "🏤",
+         "product": "📱",
+         "creative": "🎷",
+         "MISC": "🎷",
+     }
+
+     for _, row in example.iterrows():
+         pred = row.preds.split("-")[1] if "-" in row.preds else "O"
+         label = row.labels
+         label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
+
+         color = get_bg_color(row.preds.split("-")[1]) if "-" in row.preds else "#000000"
+         true_color = get_bg_color(row.labels.split("-")[1]) if "-" in row.labels else "#000000"
+
+         font_color = get_fg_color(color) if color else "white"
+         true_font_color = get_fg_color(true_color) if true_color else "white"
+
+         is_correct = row.preds == row.labels
+         loss_html = (
+             ""
+             if float(row.losses) < 0.01
+             else f"<span style='background-color: yellow; color: {font_color}; padding: 0 5px;'>{row.losses:.3f}</span>"
+         )
+         loss_html = ""
+
+         if row.labels == row.preds == "O":
+             html.append(f"<span>{row.tokens}</span>")
+         elif row.labels == "IGN":
+             assert False
+         else:
+             opacity = "1" if not is_correct else "0.5"
+             correct = (
+                 ""
+                 if is_correct
+                 else f"<span title='{label}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{classmap[label_class]}</span>"
+             )
+             pred_icon = classmap[pred] if pred != "O" and row.preds[:2] != "I-" else ""
+             html.append(
+                 f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title={row.preds}>{pred_icon + ' '}{row.tokens}</span>{correct}{loss_html}"
+             )
+
+     return " ".join(html)
+
+
+ def htmlify_example(example: pd.DataFrame) -> str:
+     corr_html = " ".join(
+         [
+             f", {row.tokens}" if row.labels == "B-COMMA" else row.tokens
+             for _, row in example.iterrows()
+         ]
+     ).strip()
+     return f"<em>{corr_html}</em>"
+
+
+ def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
+     """Turn a value into a color using a color map."""
+     norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
+     cmap = cm.get_cmap(cmap_name)  # PiYG
+     rgba = cmap(norm(abs(value)))
+     color = matplotlib.colors.rgb2hex(rgba[:3])
+     return color
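`get_fg_color` above picks black or white text depending on a YIQ brightness value computed from the background colour. The standalone sketch below repeats that calculation so it can be tried without Streamlit. The names `yiq_brightness` and `readable_text_color` are hypothetical, and the input is assumed to be a seven-character `#rrggbb` string.

```python
def yiq_brightness(hex_color: str) -> float:
    """Perceived brightness (0 to 255) of a '#rrggbb' colour using YIQ weights."""
    r = int(hex_color[1:3], 16)
    g = int(hex_color[3:5], 16)
    b = int(hex_color[5:7], 16)
    return (r * 299 + g * 587 + b * 114) / 1000


def readable_text_color(hex_color: str) -> str:
    """Return 'black' on bright backgrounds and 'white' on dark ones."""
    return "black" if yiq_brightness(hex_color) >= 128 else "white"


for background in ("#ffffff", "#000000", "#ff0000", "#ffff00"):
    print(background, round(yiq_brightness(background), 1), readable_text_color(background))
```

White and yellow backgrounds come out above the threshold of 128 and get black text, while black and pure red fall below it and get white text. The green channel carries the largest weight, which is why yellow counts as bright and red does not.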