DecompX / app.py
mohsenfayyaz's picture
Create app.py
094135a
raw
history blame
8.32 kB
import functools

import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer

from DecompX.src.globenc_utils import GlobencConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
# Checkpoints offered in the model dropdown of the demo.
MODELS = ["WillHeld/roberta-base-sst2"]

# Every figure drawn by this app uses matplotlib's "ggplot" style sheet.
plt.style.use("ggplot")
def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart of per-token scores on a new matplotlib figure.

    Each token gets one bar, first token at the top. Non-negative scores are
    drawn green and negative scores red; the token labels on the right-hand
    y-axis use the same colours. The two entries of ``label_names`` are written
    under the left and right ends of the x-axis.

    Args:
        tokens: Token strings, one per bar.
        logits: Per-token scores (any array-like), same length as ``tokens``.
        label_names: Two tick labels for the left (negative) and right
            (positive) ends of the x-axis.
        title: Figure title.
        file_name: Currently unused; kept so existing callers keep working.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn. It is also left as
        pyplot's current figure, so callers relying on ``plt`` state still work.
    """
    positive_color, negative_color = "#019875", "#B8293D"
    scores = np.asarray(logits, dtype=float)
    colors = [positive_color if is_pos else negative_color for is_pos in (scores >= 0)]
    positions = range(len(tokens))

    fig, ax = plt.subplots(figsize=(4.5, 5))
    ax.barh(positions, scores, color=colors)
    ax.axvline(0, color="black", ls="-", lw=2, alpha=0.2)
    ax.invert_yaxis()  # first token at the top

    # Symmetric limits with a 0.2 margin; `initial` keeps empty input from raising.
    max_limit = np.max(np.abs(scores), initial=0.0) + 0.2
    # If every score is positive, collapse the negative half to a sliver.
    min_limit = -0.01 if scores.size and np.min(scores) > 0 else -max_limit
    ax.set_xlim(min_limit, max_limit)
    ax.set_xticks([min_limit, max_limit])
    ax.set_xticklabels(label_names, fontsize=14, fontweight="bold")

    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.yaxis.tick_right()
    for tick_label, color in zip(ax.get_yticklabels(), colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")
    for tick_label, color in zip(ax.get_xticklabels(), [negative_color, positive_color]):
        tick_label.set_color(color)

    ax.set_title(title)
    fig.tight_layout()
    return fig
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render tokens as an HTML ``<pre>`` with one coloured background per token.

    Non-negative scores are shaded with matplotlib's ``Greens`` colormap and
    negative scores with ``Reds``; a larger magnitude gives a darker shade.

    Args:
        importance: 1-D scores, one per token, shape (sent_len,).
        tokenized_text: Token strings aligned with ``importance``.
        discrete: If True, colour by rank instead of by magnitude. Ranks are
            non-negative, so every token is shaded green in this mode.
        prefix: Text emitted before the first token (e.g. a row label).
        no_cls_sep: If True, drop the first and last token/score before
            rendering.

    Returns:
        The HTML string.
    """
    from html import escape

    importance = np.asarray(importance, dtype=float)
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]
    # Scale into [-1/1.5, 1/1.5] so even the strongest token stays mid-colormap.
    peak = np.abs(importance).max(initial=0.0)
    if peak > 0:  # all-zero / empty input would otherwise divide by zero -> NaN colours
        importance = importance / peak / 1.5
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    greens = matplotlib.colormaps.get_cmap("Greens")
    reds = matplotlib.colormaps.get_cmap("Reds")
    parts = ["<pre style='color:black; padding: 3px;'>" + prefix]
    for token, score in zip(tokenized_text, importance):
        r, g, b, a = (greens if score >= 0 else reds)(abs(score))
        # White text on very dark backgrounds. It is the only `color` declaration
        # in the span, so a later `color:black` can no longer override it.
        text_color = "white" if abs(score) > 0.9 else "black"
        parts.append(
            f"<span style='background-color: rgba({r * 255}, {g * 255}, {b * 255}, {a}); "
            f"color:{text_color}; border-radius: 5px; padding: 3px;"
            f"font-weight: 800;'>"
        )
        # Show angle brackets as square ones (e.g. "<s>" -> "[s]"), then escape
        # the rest so user-derived token text cannot inject markup or entities.
        parts.append(escape(token.replace("<", "[").replace(">", "]")))
        parts.append("</span> ")
    parts.append("</pre>")
    return "".join(parts)
def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build an HTML preview of the token importances stored for one row.

    Two optional columns of ``decompx_outputs_df`` are rendered, in this order:

    * ``importance_last_layer_aggregated``: one line from row ``[0, :]`` of the
      stored matrix, honouring ``discrete``.
    * ``importance_last_layer_classifier``: one line per class, from column
      ``[:, label]`` of the stored matrix (always rendered non-discrete).

    A column that is missing, or whose entry at ``idx`` is ``None``, is
    skipped. Token strings come from the ``tokens`` column.

    Args:
        decompx_outputs_df: DataFrame with a ``tokens`` column and, optionally,
            the importance columns above.
        idx: Positional (``iloc``) row to preview.
        discrete: Forwarded to ``print_importance`` for the aggregated line.

    Returns:
        HTML wrapped in a scrollable white ``<div>`` (properly closed).
    """
    no_cls_sep = False
    df = decompx_outputs_df
    rows = []
    for col in ("importance_last_layer_aggregated", "importance_last_layer_classifier"):
        if col not in df or df[col].iloc[idx] is None:
            continue
        matrix = df[col].iloc[idx]
        tokens = df["tokens"].iloc[idx]
        short_name = col.split("_")[-1]
        if "classifier" in col:
            # One line per class; these lines are always rendered non-discrete.
            for label in range(matrix.shape[-1]):
                rows.append(
                    print_importance(
                        matrix[:, label],
                        tokens,
                        prefix=f"{short_name} Label{label}:".ljust(20),
                        no_cls_sep=no_cls_sep,
                        discrete=False,
                    )
                )
        else:
            rows.append(
                print_importance(
                    matrix[0, :],
                    tokens,
                    prefix=f"{short_name}:".ljust(20),
                    no_cls_sep=no_cls_sep,
                    discrete=discrete,
                )
            )
    return (
        "<div style='overflow:auto; background-color:white; padding: 10px;'>"
        + "".join(rows)
        + "</div>"
    )
@functools.lru_cache(maxsize=len(MODELS))
def _load_tokenizer_and_model(model_name):
    """Load and memoise the tokenizer and DecompX classifier for ``model_name``.

    Returns:
        ``(tokenizer, model)`` with the model already switched to eval mode.

    Raises:
        ValueError: If the name contains neither "roberta" nor "bert".
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # "roberta" must be checked first because it also contains "bert".
    if "roberta" in model_name:
        clf_model = RobertaForSequenceClassification.from_pretrained(model_name)
    elif "bert" in model_name:
        clf_model = BertForSequenceClassification.from_pretrained(model_name)
    else:
        raise ValueError(f"Not implemented model: {model_name}")
    clf_model.eval()
    return tokenizer, clf_model


def run_decompx(text, model):
    """
    Provide DecompX Token Explanation of Model on Text

    Args:
        text: Sentence to explain.
        model: Hugging Face checkpoint name (chosen from ``MODELS`` in the UI).

    Returns:
        ``(plt, html)``: pyplot's state module, whose current figure is the bar
        chart for the predicted label, and the HTML from ``print_preview``.

    Raises:
        gr.Error: If ``text`` is empty/blank or no model is selected.
        ValueError: If the checkpoint is neither a RoBERTa nor a BERT model.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to explain.")
    if not model:
        raise gr.Error("Please select a model.")

    # Row 0 is the user's text. NOTE(review): the dummy second sentence
    # presumably keeps the batch dimension > 1 so the `.squeeze()` calls below
    # cannot collapse it — confirm before removing.
    SENTENCES = [text, "nothing"]
    CONFIGS = {
        "DecompX":
        GlobencConfig(
            include_biases=True,
            bias_decomp_type="absdot",
            include_LN1=True,
            include_FFN=True,
            FFN_approx_type="GeLU_ZO",
            include_LN2=True,
            aggregation="vector",
            include_classifier_w_pooler=True,
            tanh_approx_type="ZO",
            output_all_layers=True,
            output_attention=None,
            output_res1=None,
            output_LN1=None,
            output_FFN=None,
            output_res2=None,
            output_encoder=None,
            output_aggregated="norm",
            output_pooler="norm",
            output_classifier=True,
        ),
    }
    model_name = model
    # LOAD MODEL AND TOKENIZER (cached across requests)
    tokenizer, clf_model = _load_tokenizer_and_model(model_name)
    tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)
    # RUN DECOMPX
    with torch.no_grad():
        logits, hidden_states, globenc_last_layer_outputs, globenc_all_layers_outputs = clf_model(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            globenc_config=CONFIGS["DecompX"]
        )
    decompx_outputs = {
        "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),  # Last layer & only CLS -> (batch, emb_dim)
    }
    ### globenc_last_layer_outputs.classifier ~ (8, 55, 2) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes)
    # Trim each row to its real (unpadded) length.
    importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
    decompx_outputs["importance_last_layer_classifier"] = importance
    ### globenc_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
    importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
    importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
    decompx_outputs["importance_all_layers_aggregated"] = importance
    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    idx = 0
    pred_label = np.argmax(decompx_outputs_df.iloc[idx]["logits"], axis=-1)
    label = decompx_outputs_df.iloc[idx]["importance_last_layer_classifier"][:, pred_label]
    # Drop the first/last special tokens from the chart.
    tokens = decompx_outputs_df.iloc[idx]["tokens"][1:-1]
    label = label[1:-1]
    peak = np.max(np.abs(label), initial=0.0)
    if peak > 0:  # avoid 0/0 -> NaN for all-zero or empty attributions
        label = label / peak
    # plot_clf opens a new pyplot figure per call; close the previous ones so
    # they do not accumulate for the lifetime of the server process.
    plt.close("all")
    plot_clf(tokens, label, ['-', '+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
    return plt, print_preview(decompx_outputs_df)
demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        # Preselect the first checkpoint so the dropdown is never empty on load.
        gr.components.Dropdown(label="Model", choices=MODELS, value=MODELS[0]),
    ],
    outputs=["plot", "html"],
    examples=[["Building a translation demo with Gradio is so easy!", "WillHeld/roberta-base-sst2"]],
    cache_examples=False,
    title="DecompX Demo",
    description="Token-level DecompX explanation of a text classifier's prediction: enter a sentence and choose a model.",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()