import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from spectral_metric.estimator import CumulativeGradientEstimator
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from spectral_metric.visualize import make_graph
from scipy.stats import entropy
import pandas as pd

from utils import show_most_confused


AVAILABLE_DATASETS = [
    ("clinc_oos", "small"),
    ("clinc_oos", "imbalanced"),
    ("banking77",),
    ("tweet_eval", "emoji"),
    ("tweet_eval", "stance_climate")
]

label_column_mapping = {
    "clinc_oos": "intent",
    "banking77": "label",
    "tweet_eval": "label",
}

st.title("Perform a data-driven analysis using `spectral-metric`")
st.markdown(
    """Today, I would like to analyze this dataset and perform a
 data-driven analysis by `sentence-transformers` to extract features
  and `spectral_metric` to perform a spectral analysis of the dataset.

For support, please submit an issue on [our repo](https://github.com/Dref360/spectral-metric) or [contact me directly](https://github.com/Dref360)
"""
)

st.markdown(
    """
Let's load your dataset; we will run our analysis on the train set.
"""
)

dataset_name = st.selectbox("Select your dataset", AVAILABLE_DATASETS)
if st.button("Start the analysis"):

    label_column = label_column_mapping[dataset_name[0]]

    # We perform the analysis on the train set.
    ds = load_dataset(*dataset_name)["train"]
    class_names = ds.features[label_column].names
    ds  # Streamlit "magic": a bare expression renders a preview of the dataset

    # I use all-MiniLM-L12-v2 as it is a good compromise between speed and performance.
    embedder = SentenceTransformer("all-MiniLM-L12-v2")
    # We will get normalized features for the dataset using our embedder.
    with st.spinner(text="Computing embeddings..."):
        features = embedder.encode(
            ds["text"],
            device="cuda" if torch.cuda.is_available() else "cpu",
            normalize_embeddings=True,
        )
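    # `features` has shape [num_examples, 384]: one L2-normalized embedding per utterance.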

    st.markdown(
        """
    ### Running the spectral analysis

    Now that we have our embeddings extracted by our sentence embedder, we can make an in-depth analysis of these features.

    To do so, we will use CSG (Branchaud-Charron et al., 2019), a technique that combines Probability Product Kernels (Jebara et al., 2004) and spectral clustering to analyze a dataset without training a model.

    In this app, we won't use the actual CSG metric, but we will use the $W$ matrix.
    This matrix is computed as follows:
    * Run a probabilistic k-NN on the dataset (optionally via Monte Carlo sampling).
    * Compute the average prediction per class, which gives the $S$ matrix.
    * Symmetrize this matrix with the Bray-Curtis distance, a metric designed to compare samples from a distribution.

    These steps are all done by `spectral_metric.estimator.CumulativeGradientEstimator`.
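
    As a rough sketch, the symmetrization step looks like this (`S` and `num_classes` are illustrative names; the library's internals may differ):

    ```python
    from scipy.spatial.distance import braycurtis

    # Similarity between classes i and j: 1 minus the Bray-Curtis distance
    # between their average k-NN prediction vectors (rows of S).
    W = np.array([
        [1 - braycurtis(S[i], S[j]) for j in range(num_classes)]
        for i in range(num_classes)
    ])
    ```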
    """
    )
    X, y = features, np.array(ds[label_column])  # Your dataset with shape [N, ?], [N]
    estimator = CumulativeGradientEstimator(M_sample=250, k_nearest=9, distance="cosine")
    estimator.fit(data=X, target=y)
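    # After fitting, `estimator.W` holds the [n_classes, n_classes] class-similarity
    # matrix; higher off-diagonal values mean two classes are more easily confused.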

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(estimator.W, ax=ax, cmap="rocket_r")
    ax.set_title(f"Similarity between classes in {dataset_name[0]}")
    st.pyplot(fig)

    st.markdown(
        """
        This figure will be hard to read on most datasets, so we need to go deeper. 
        Let's do the following analysis:
        1. Find the class with the highest entropy, i.e., the class that is most confused with the others.
        2. Find the 10 pairs of classes that are the most confused.
        3. Find the items in these pairs that contribute to the confusion.
    """
    )


    # Normalize each row of W into a distribution; a high-entropy row means the
    # class's similarity is spread across many classes, i.e., it is easily confused.
    entropy_per_class = entropy(estimator.W / estimator.W.sum(-1)[:, None], axis=-1)
    st.markdown(
        f"Most confused class (highest entropy): {class_names[np.argmax(entropy_per_class)]}",
    )
    st.markdown(
        f"Least confused class (lowest entropy): {class_names[np.argmin(entropy_per_class)]}",
    )

    # Sort all (i, j) indices of W by similarity, descending, and drop the diagonal.
    pairs = list(zip(*np.unravel_index(np.argsort(estimator.W, axis=None), estimator.W.shape)))[::-1]
    pairs = [(i, j) for i, j in pairs if i != j]

    # W is symmetric, so each pair appears twice, as (i, j) and (j, i);
    # keep every other entry to deduplicate, then take the top 10.
    lst = []
    for i, j in pairs[::2][:10]:
        lst.append({"Intent A": class_names[i], "Intent B": class_names[j], "Similarity": estimator.W[i, j]})

    st.title("Most similar pairs")
    st.dataframe(pd.DataFrame(lst).sort_values("Similarity", ascending=False))


    st.markdown("""
    ## Analysis
    By looking at the top-10 most similar pairs, we get some good insights into the dataset.
    While this does not guarantee that a classifier trained downstream will struggle with these pairs,
    we know that these intents are similar.
    Consequently, the classifier might not be able to separate them easily.


    Let's now look at which utterance is contributing the most to the confusion.
    """)

    # `pairs` still contains each pair twice, so the second distinct pair is at index 2.
    first_pair = pairs[0]
    second_pair = pairs[2]
    # show_most_confused (from utils.py) lists the utterances of one class that are
    # most confused with the other class.
    st.dataframe(pd.DataFrame({**show_most_confused(ds, first_pair[0], first_pair[1], estimator, class_names),
                               **show_most_confused(ds, first_pair[1], first_pair[0], estimator, class_names)}),
                 width=1000)

    st.markdown("### We can do the same for the second pair")

    st.dataframe(pd.DataFrame({**show_most_confused(ds, second_pair[0], second_pair[1], estimator, class_names),
                               **show_most_confused(ds, second_pair[1], second_pair[0], estimator, class_names)}),
                 width=1000)

    st.markdown(f"""
    From the top-5 most confused examples per pair, we can see that the sentences are quite similar.
    While a human could easily separate the two intents, the sentences are made of the same words, which might confuse the classifier.

    Some sentences could even be seen as mislabelled.
    Of course, these features come from a model that was not trained to separate these classes;
    they come from a general-purpose language model.
    The goal of this analysis is to give the data scientist insights before they train an expensive model.
    If we were to train a model on this dataset, it could probably handle the confusion between `{class_names[first_pair[0]]}`
    and `{class_names[first_pair[1]]}`, but maybe not easily.


    ## Conclusion

    In this tutorial, we covered how to conduct a data-driven analysis of a text classification dataset.
    Using sentence embeddings and the `spectral_metric` library, we found the intents most likely to be confused and the utterances that cause this confusion.

    Following our analysis, we could take the following actions:
    1. Upweight the confused classes during training so that the model learns to separate them better (see the sketch below).
    2. Merge similar classes.
    3. Analyze the confusing sentences to find mislabelled examples.
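
    As a rough illustration of the first action (a sketch, not a recipe; the 2x weight and the median threshold are arbitrary choices):

    ```python
    import torch

    # Hypothetical: give extra weight to high-entropy (most confused) classes.
    confused = entropy_per_class > np.median(entropy_per_class)
    class_weights = torch.tensor(np.where(confused, 2.0, 1.0), dtype=torch.float)
    loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    ```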

    If you have any questions, suggestions, or ideas for this library, please reach out:

    1. [email protected]
    2. [@Dref360 on Github](https://github.com/Dref360)


    If you have a dataset that you think would be a good fit for this analysis, let me know too!
    """)