import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from spectral_metric.estimator import CumulativeGradientEstimator
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from spectral_metric.visualize import make_graph
from scipy.stats import entropy
import pandas as pd
from utils import show_most_confused
AVAILABLE_DATASETS = [
    ("clinc_oos", "small"),
    ("clinc_oos", "imbalanced"),
    ("banking77",),
    ("tweet_eval", "emoji"),
    ("tweet_eval", "stance_climate"),
]

label_column_mapping = {
    "clinc_oos": "intent",
    "banking77": "label",
    "tweet_eval": "label",
}
st.title("Perform a data-driven analysis using `spectral-metric`")

st.markdown(
    """Today, I would like to perform a data-driven analysis of a text classification dataset,
using `sentence-transformers` to extract features
and `spectral_metric` to run a spectral analysis of the dataset.

For support, please submit an issue on [our repo](https://github.com/Dref360/spectral-metric) or [contact me directly](https://github.com/Dref360).
"""
)
st.markdown(
    """
Let's load your dataset; we will run our analysis on the train set.
"""
)
dataset_name = st.selectbox("Select your dataset", AVAILABLE_DATASETS)

if st.button("Start the analysis"):
    label_column = label_column_mapping[dataset_name[0]]
    # We perform the analysis on the train set.
    ds = load_dataset(*dataset_name)["train"]
    class_names = ds.features[label_column].names
    st.write(ds)  # Display a summary of the loaded split.
    # I use all-MiniLM-L12-v2 as it is a good compromise between speed and performance.
    embedder = SentenceTransformer("all-MiniLM-L12-v2")
    # We will get **normalized** features for the dataset using our embedder.
    with st.spinner(text="Computing embeddings..."):
        features = embedder.encode(
            ds["text"],
            device="cuda" if torch.cuda.is_available() else "cpu",
            normalize_embeddings=True,
        )
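    # Quick sanity check (an illustrative addition): with normalize_embeddings=True,
    # every feature vector should have unit L2 norm, so the cosine similarity
    # between two embeddings reduces to a dot product.
    norms = np.linalg.norm(features, axis=1)
    assert np.allclose(norms, 1.0, atol=1e-3), "Embeddings should be L2-normalized"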
    st.markdown(
        """
### Running the spectral analysis

Now that we have our embeddings extracted by our sentence embedder, we can make an in-depth analysis of these features.
To do so, we will use CSG (Branchaud-Charron et al, 2019), a technique that combines Probability Product Kernels (Jebara et al, 2004) and spectral clustering to analyze a dataset without training a model.

In this app, we won't use the actual CSG metrics; instead, we will use the $W$ matrix, which is computed as follows:
* Run a probabilistic K-NN on the dataset (optionally done via Monte Carlo sampling).
* Compute the average prediction per class, which gives the $S$ matrix.
* Symmetrize this matrix using the Bray-Curtis distance, a metric designed to compare samples drawn from a distribution (see the sketch below).

These steps are all handled by `spectral_metric.estimator.CumulativeGradientEstimator`.
"""
    )
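    # A minimal sketch of the symmetrization step described above. This is an
    # illustration of the idea, not the library's exact internals: given a
    # class-similarity matrix S, compare its rows with the Bray-Curtis distance.
    from scipy.spatial.distance import braycurtis

    def symmetrize_bray_curtis(S: np.ndarray) -> np.ndarray:
        """W[i, j] = 1 - BrayCurtis(S[i], S[j]), symmetric by construction."""
        n_classes = S.shape[0]
        W = np.zeros((n_classes, n_classes))
        for i in range(n_classes):
            for j in range(n_classes):
                W[i, j] = 1 - braycurtis(S[i], S[j])
        return W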
    X, y = features, np.array(ds[label_column])  # Your dataset with shape [N, ?], [N]
    estimator = CumulativeGradientEstimator(M_sample=250, k_nearest=9, distance="cosine")
    estimator.fit(data=X, target=y)

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(estimator.W, ax=ax, cmap="rocket_r")
    ax.set_title(f"Similarity between classes in {dataset_name[0]}")
    st.pyplot(fig)
    st.markdown(
        """
This figure will be hard to read on most datasets, so we need to dig deeper.
Let's do the following analysis:
1. Find the class with the highest entropy, i.e., the class that is most confused with the others.
2. Find the 5 pairs of classes that are the most confused.
3. Find the items in these pairs that contribute the most to the confusion.
"""
    )
    # Normalize each row of W into a distribution, then compute its entropy:
    # a high-entropy row means that class is confused with many other classes.
    entropy_per_class = entropy(estimator.W / estimator.W.sum(-1)[:, None], axis=-1)
    st.markdown(
        f"Most confused class (highest entropy): {class_names[np.argmax(entropy_per_class)]}"
    )
    st.markdown(
        f"Least confused class (lowest entropy): {class_names[np.argmin(entropy_per_class)]}"
    )
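    # Optional (an illustrative addition): show the full per-class entropy
    # ranking so the reader can inspect more than the two extremes.
    st.dataframe(
        pd.DataFrame({"class": class_names, "entropy": entropy_per_class}).sort_values(
            "entropy", ascending=False
        )
    )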
    # Sort all (i, j) cells of W by similarity, most similar first.
    pairs = list(zip(*np.unravel_index(np.argsort(estimator.W, axis=None), estimator.W.shape)))[::-1]
    # Drop the diagonal: a class is always similar to itself.
    pairs = [(i, j) for i, j in pairs if i != j]
    # W is symmetric, so each pair appears twice; keep every other entry.
    lst = []
    for i, j in pairs[::2][:10]:
        lst.append({"Intent A": class_names[i], "Intent B": class_names[j], "Similarity": estimator.W[i, j]})
    st.title("Most similar pairs")
    st.dataframe(pd.DataFrame(lst).sort_values("Similarity", ascending=False))
st.markdown(""" | |
## Analysis | |
By looking at the top-10 most similar pairs, we get some good insights on the dataset. | |
While this does not 100% indicates that the classifier trained downstream will have issues with these pairs, | |
we know that these intents are similar. | |
In consequence, the classifier might not be able to separate them easily. | |
Let's now look at which utterance is contributing the most to the confusion. | |
""") | |
    # W is symmetric, so pairs[1] is the mirror of pairs[0];
    # the next distinct pair is pairs[2].
    first_pair = pairs[0]
    second_pair = pairs[2]
    # For each class in the pair, show its utterances that are most confused with the other class.
    st.dataframe(
        pd.DataFrame({
            **show_most_confused(ds, first_pair[0], first_pair[1], estimator, class_names),
            **show_most_confused(ds, first_pair[1], first_pair[0], estimator, class_names),
        }),
        width=1000,
    )
    st.markdown("### We can do the same for the second pair")
    st.dataframe(
        pd.DataFrame({
            **show_most_confused(ds, second_pair[0], second_pair[1], estimator, class_names),
            **show_most_confused(ds, second_pair[1], second_pair[0], estimator, class_names),
        }),
        width=1000,
    )
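    # `show_most_confused` comes from this repo's `utils` module. As a rough,
    # hypothetical stand-in (an assumption, not the actual implementation), one
    # could rank the utterances of class_a by how many of their nearest
    # neighbours in embedding space belong to class_b:
    from sklearn.neighbors import NearestNeighbors

    def naive_most_confused(texts, features, y, class_a, class_b, k=9, top=5):
        nn = NearestNeighbors(n_neighbors=k + 1, metric="cosine").fit(features)
        mask = y == class_a
        _, neighbor_idx = nn.kneighbors(features[mask])
        # Fraction of neighbours (excluding the sample itself) labeled class_b.
        confusion = (y[neighbor_idx[:, 1:]] == class_b).mean(axis=1)
        most_confused = np.argsort(-confusion)[:top]
        return np.asarray(texts)[mask][most_confused].tolist()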
st.markdown(f""" | |
From the top-5 most confused examples per pair, we can see that the sentences are quite similar. | |
While a human could easily separate the two intents, we see that the sentences are made of the same words which might confuse the classifier. | |
Some sentences could be seen as mislabelled. | |
Of course, these features come from a model that was not trained to separate these classes, | |
they come from a general-purpose language model. | |
The goal of this analysis is to give insights to the data scientist before they train an expensive model. | |
If we were to train a model on this dataset, the model could probably handle the confusion between `{class_names[first_pair[0]]}` | |
and `{class_names[first_pair[1]]}`, | |
but maybe not easily. | |
## Conclusion | |
In this tutorial, we covered how to conduct a data-driven analysis for on a text classification dataset. | |
By using sentence embedding and the `spectral_metric` library, we found the intents that would be the most likely to be confused and which utterances caused this confusion. | |
Following our analysis, we could take the following actions: | |
1. Upweight the classes that are confused during training for the model to better learn to separate them. | |
2. Merge similar classes together. | |
3. Analyse sentences that are confusing to find mislabelled sentences. | |
If you have any questions, suggestions or ideas for this library please reach out: | |
1. [email protected] | |
2. [@Dref360 on Github](https://github.com/Dref360) | |
If you have a dataset that you think would be a good fit for this analysis let me know too! | |
""") | |