import functools
from html import escape as html_escape

import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.figure import Figure
from IPython.display import display, HTML
from transformers import AutoTokenizer

from DecompX.src.globenc_utils import GlobencConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

plt.style.use("ggplot")

MODELS = ["WillHeld/roberta-base-sst2"]

# Colors used by the bar chart: non-negative scores green, negative red.
_POS_COLOR = "#019875"
_NEG_COLOR = "#B8293D"


def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart with one score per token.

    Args:
        tokens: Token strings, one bar each, listed top to bottom.
        logits: Per-token scores (array-like, same length as ``tokens``).
            Bars and token labels are green for scores >= 0, red otherwise.
        label_names: Two tick labels drawn at the left and right x-limits
            (colored red and green respectively).
        title: Figure title.
        file_name: Unused; accepted for backward compatibility. Nothing is
            written to disk.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart. It is created
        without pyplot, so repeated calls (one per request in the web app)
        do not accumulate open pyplot figures.
    """
    scores = np.asarray(logits)
    colors = [_POS_COLOR if is_pos else _NEG_COLOR for is_pos in (scores >= 0)]

    fig = Figure(figsize=(4.5, 5))
    ax = fig.add_subplot()
    positions = range(len(tokens))
    ax.barh(positions, scores, color=colors)
    ax.axvline(0, color="black", ls="-", lw=2, alpha=0.2)
    ax.invert_yaxis()  # first token at the top

    # Symmetric x-range around 0, unless every score is positive, in which
    # case the left edge sits just below 0. `initial` keeps empty input safe.
    max_limit = np.max(np.abs(scores), initial=0.0) + 0.2
    min_limit = -0.01 if scores.size and np.min(scores) > 0 else -max_limit
    ax.set_xlim(min_limit, max_limit)
    ax.set_xticks([min_limit, max_limit])
    ax.set_xticklabels(label_names, fontsize=14, fontweight="bold")
    for tick_label, color in zip(ax.get_xticklabels(), [_NEG_COLOR, _POS_COLOR]):
        tick_label.set_color(color)

    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.yaxis.tick_right()
    for tick_label, color in zip(ax.get_yticklabels(), colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")

    ax.set_title(title)
    fig.tight_layout()
    return fig


def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render per-token importance scores as an HTML heat-map snippet.

    Args:
        importance: 1-D array of scores, one per token (sent_len,).
        tokenized_text: Sequence of token strings aligned with ``importance``.
        discrete: If True, color by rank instead of by magnitude.
        prefix: Text placed before the tokens, inside the ``<pre>`` block so
            padding added with ``str.ljust`` is preserved. Inserted verbatim.
        no_cls_sep: Drop the first and last token/score before rendering.

    Returns:
        An HTML string: a ``<pre>`` block containing one colored ``<span>``
        per token (green shades for scores >= 0, red shades otherwise).
    """
    importance = np.asarray(importance)
    tokens = list(tokenized_text)
    if no_cls_sep:
        importance = importance[1:-1]
        tokens = tokens[1:-1]

    peak = np.max(np.abs(importance), initial=0.0)
    if peak > 0:
        # Normalize to at most 1/1.5 in magnitude so the colormaps never
        # reach their darkest shade. An all-zero input is left as-is
        # instead of becoming NaN.
        importance = importance / peak / 1.5
    if discrete and importance.size:
        # Replace scores by their rank (0..n-1), scaled into [0, 1/1.6).
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    greens = matplotlib.colormaps.get_cmap("Greens")
    reds = matplotlib.colormaps.get_cmap("Reds")

    pieces = ["<pre style='color: black; padding: 3px;'>", prefix]
    for score, token in zip(importance, tokens):
        cmap = greens if score >= 0 else reds
        r, g, b, a = cmap(np.abs(score))
        # NOTE(review): after the scaling above |score| <= 1/1.5, so white
        # text is only used if callers bypass that normalization.
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(score) > 0.9 else ""
        style = f"background-color: rgba({r*255}, {g*255}, {b*255}, {a}); " + text_color
        # Show tokens like "<s>" as "[s]", then escape anything else that is
        # HTML-significant (tokens come from user-supplied text).
        label = html_escape(token.replace("<", "[").replace(">", "]"))
        pieces.append(f"<span style='{style}border-radius: 5px; padding: 3px;'>{label}</span> ")
    pieces.append("</pre>")
    # display(HTML(...)) is intentionally not called; the caller renders it.
    return "".join(pieces)


def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build the HTML preview of the DecompX importances for one example.

    Args:
        decompx_outputs_df: DataFrame with a ``tokens`` column and, optionally,
            ``importance_last_layer_aggregated`` (2-D; row 0 is rendered) and/or
            ``importance_last_layer_classifier`` (2-D, one column per label;
            every label is rendered).
        idx: Positional row index of the example to render.
        discrete: Forwarded to :func:`print_importance` for every row rendered.

    Returns:
        An HTML string wrapping one heat-map block per rendered importance row.
    """
    no_cls_sep = False
    df = decompx_outputs_df
    tokens = df["tokens"].iloc[idx]
    blocks = []

    col = "importance_last_layer_aggregated"
    if col in df and df[col].iloc[idx] is not None:
        blocks.append(print_importance(
            df[col].iloc[idx][0, :],
            tokens,
            prefix=f"{col.split('_')[-1]}:".ljust(20),
            no_cls_sep=no_cls_sep,
            discrete=discrete,
        ))

    col = "importance_last_layer_classifier"
    if col in df and df[col].iloc[idx] is not None:
        per_label = df[col].iloc[idx]
        for label in range(per_label.shape[-1]):
            blocks.append(print_importance(
                per_label[:, label],
                tokens,
                prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                no_cls_sep=no_cls_sep,
                discrete=discrete,
            ))

    return "<div>" + "".join(blocks) + "</div>"


def _decompx_config():
    """Build the DecompX (GlobEnc) configuration used for every explanation."""
    return GlobencConfig(
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    )


@functools.lru_cache(maxsize=4)
def _load_model_and_tokenizer(model_name):
    """Load (once per model id) the tokenizer and DecompX-enabled classifier.

    Raises:
        NotImplementedError: if ``model_name`` contains neither "roberta"
            nor "bert".
    """
    if "roberta" in model_name:
        model_cls = RobertaForSequenceClassification
    elif "bert" in model_name:
        model_cls = BertForSequenceClassification
    else:
        raise NotImplementedError(f"Not implemented model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    clf_model = model_cls.from_pretrained(model_name)
    clf_model.eval()
    return tokenizer, clf_model


def run_decompx(text, model):
    """Provide a DecompX token explanation of ``model``'s prediction on ``text``.

    Args:
        text: Input sentence to classify and explain.
        model: Hugging Face model id; must contain "roberta" or "bert".
            Falls back to ``MODELS[0]`` when empty (e.g. no dropdown choice).

    Returns:
        ``(figure, html)``: a bar chart of each token's DecompX importance
        for the predicted label (normalized to max |value| = 1), and an HTML
        heat-map of the per-label importances.

    Raises:
        gr.Error: if ``text`` is empty or whitespace only.
        NotImplementedError: if the model id is neither RoBERTa nor BERT.
    """
    model_name = model or MODELS[0]
    if not text or not text.strip():
        raise gr.Error("Please enter some text to explain.")

    tokenizer, clf_model = _load_model_and_tokenizer(model_name)

    # NOTE(review): a dummy second sentence keeps the batch size at 2,
    # presumably so the ``.squeeze()`` calls below cannot drop the batch
    # axis. Confirm before removing it.
    sentences = [text, "nothing"]
    tokenized_sentence = tokenizer(sentences, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentence["attention_mask"].sum(dim=-1).tolist()

    with torch.no_grad():
        # return_dict=False yields a plain tuple; the hidden-state and
        # all-layer flags are kept so the 4-way unpacking stays valid, even
        # though those two outputs are not used below.
        logits, _hidden_states, globenc_last_layer_outputs, _all_layers_outputs = clf_model(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            globenc_config=_decompx_config(),
        )

    tokens = [
        tokenizer.convert_ids_to_tokens(ids[:length])
        for ids, length in zip(tokenized_sentence["input_ids"], batch_lengths)
    ]
    # Classifier-level importance, trimmed to each sentence's real length.
    # Assumed (seq_len, num_classes) per sentence -- TODO confirm against
    # DecompX's output specification.
    importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in globenc_last_layer_outputs.classifier]
    ).squeeze()
    importance = [imp[:length, :] for imp, length in zip(importance, batch_lengths)]

    decompx_outputs_df = pd.DataFrame({
        "tokens": tokens,
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        "importance_last_layer_classifier": importance,
    })

    row = decompx_outputs_df.iloc[0]
    pred_label = int(np.argmax(row["logits"]))
    # Drop the first and last tokens from the chart (presumably the
    # tokenizer's boundary tokens, e.g. <s> and </s>).
    plot_tokens = row["tokens"][1:-1]
    label_importance = row["importance_last_layer_classifier"][1:-1, pred_label]
    peak = np.max(np.abs(label_importance), initial=0.0)
    if peak > 0:
        label_importance = label_importance / peak

    fig = plot_clf(
        plot_tokens,
        label_importance,
        ["-", "+"],
        title=f"DecompX for Predicted Label: {pred_label}",
    )
    return fig, print_preview(decompx_outputs_df)


demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Model", choices=MODELS, value=MODELS[0]),
    ],
    outputs=["plot", "html"],
    examples=[["Building a translation demo with Gradio is so easy!", "WillHeld/roberta-base-sst2"]],
    cache_examples=False,
    title="DecompX Demo",
    description="Token-level explanations of a text classifier's prediction, computed with DecompX.",
)

if __name__ == "__main__":
    demo.launch()