import gradio as gr
import pandas as pd
import numpy as np
import torch
from torch import nn
import PIL.Image as Image
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertConfig

from pyvene import (
    IntervenableModel,
    VanillaIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    LocalistRepresentationIntervention,
)

from plotnine import (
    ggplot,
    geom_tile,
    aes,
    theme,
    element_text,
    xlab, ylab,
    ggsave,
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap
# Use a GPU when available; the model and the noise tensor both live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_bert(cache_dir=None):
    """Loads the fine-tuned BERT sequence classifier plus its config and tokenizer."""
    config = BertConfig.from_pretrained("./cs772_proj/bert_base/checkpoint-3848/config.json")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "./cs772_proj/bert_base/checkpoint-3848", config=config, cache_dir=cache_dir
    )
    print("loaded model")
    return config, tokenizer, model
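# A minimal smoke test for the loaded classifier (a sketch: the checkpoint path
# above is local to this Space, and the label order assumed in interpret() below
# — 0 = hate, 1 = normal, anything else = 2 — must match how the checkpoint was
# fine-tuned, so adjust both together if you swap in a different model):
#
#   config, tokenizer, model = create_bert()
#   enc = tokenizer("some input text", return_tensors="pt")
#   probs = nn.functional.softmax(model(**enc).logits, dim=-1)
#   print(probs)  # one row of three class probabilities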
def interpret(text, label):
    """ROME-style causal tracing: corrupt the first tokens, restore one stream
    at a time, and plot how the probability of the requested class recovers."""
    titles = {
        "block_output": "single restored layer in BERT",
        "mlp_activation": "center of interval of 5 patched MLP layers",
        "attention_output": "center of interval of 5 patched attn layers",
    }
    colors = {
        "block_output": "Purples",
        "mlp_activation": "Greens",
        "attention_output": "Reds",
    }
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    config, tokenizer, model = create_bert()
    model.to(device)

    base = tokenizer(text, return_tensors="pt").to(device)
    base_token = tokenizer.convert_ids_to_tokens(base["input_ids"][0])

    # Map the requested class name to its logit index; anything other than
    # "hate" or "normal" falls through to index 2.
    if label == "hate":
        l = 0
    elif label == "normal":
        l = 1
    else:
        l = 2

    # Corrupted run: add noise to the embeddings of the first four tokens.
    config = corrupted_config(type(model))
    intervenable = IntervenableModel(config, model)
    _, counterfactual_outputs = intervenable(
        base, unit_locations={"base": ([[[0, 1, 2, 3]]])}
    )

    # Restored runs: corrupt the same four tokens, then restore `stream` at a
    # single position over a window of layers and record the class probability.
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        data = []
        for layer_i in tqdm(range(model.config.num_hidden_layers)):
            for pos_i in range(len(base_token)):
                config = restore_corrupted_with_interval_config(
                    layer_i, stream,
                    window=1 if stream == "block_output" else 5
                )
                n_restores = len(config.representations) - 1
                intervenable = IntervenableModel(config, model)
                _, counterfactual_outputs = intervenable(
                    base,
                    [None] + [base] * n_restores,
                    {
                        "sources->base": (
                            [None] + [[[pos_i]]] * n_restores,
                            [[[0, 1, 2, 3]]] + [[[pos_i]]] * n_restores,
                        )
                    },
                )
                logits = counterfactual_outputs[0]
                probabilities = nn.functional.softmax(logits, dim=-1)
                prob_label = probabilities[0][l].item()
                data.append({"layer": layer_i, "pos": pos_i, "prob": prob_label})
        df = pd.DataFrame(data)
        df.to_csv(f"./cs772_proj/tutorial_data/pyvene_rome_{stream}.csv")
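    # Each CSV row is one (layer, position) cell of the heatmap; "prob" is the
    # softmax probability of the requested class after the restored run. The
    # files are re-read below, so the plotting pass can also be rerun on its own.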
    # Plot one heatmap per stream: layers on x, token positions on y.
    img_paths = {}
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        df = pd.read_csv(f"./cs772_proj/tutorial_data/pyvene_rome_{stream}.csv")
        df["layer"] = df["layer"].astype(int)
        df["pos"] = df["pos"].astype(int)
        prob_type = f"p({label})"
        df[prob_type] = df["prob"].astype(float)
        custom_labels = base_token
        breaks = list(range(len(custom_labels)))
        plot = (
            ggplot(df, aes(x="layer", y="pos"))
            + geom_tile(aes(fill=prob_type))
            + scale_fill_cmap(colors[stream]) + xlab(titles[stream])
            + scale_y_reverse(
                limits=(-0.5, len(custom_labels)),
                breaks=breaks, labels=custom_labels)
            + theme(figure_size=(6, 9)) + ylab("")
            + theme(axis_text_y=element_text(angle=90, hjust=1))
        )
        img_paths[stream] = f"./cs772_proj/tutorial_data/pyvene_rome_{stream}.png"
        ggsave(plot, filename=img_paths[stream], dpi=200)
    return (
        img_paths["mlp_activation"],
        img_paths["block_output"],
        img_paths["attention_output"],
    )
def restore_corrupted_with_interval_config(
        layer, stream="mlp_activation", window=5, num_layers=12):
    """Corrupts the block input at layer 0, then restores `stream` over a
    window of layers centered on `layer`, clamped to [0, num_layers)."""
    start = max(0, layer - window // 2)
    end = min(num_layers, layer - (-window // 2))  # ceil(window / 2) above the center
    config = IntervenableConfig(
        representations=[
            RepresentationConfig(
                0,              # layer
                "block_input",  # intervention type
            ),
        ] + [
            RepresentationConfig(
                i,       # layer
                stream,  # intervention type
            ) for i in range(start, end)
        ],
        intervention_types=[NoiseIntervention] + [VanillaIntervention] * (end - start),
    )
    return config
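# For example, layer=6 with window=5 and num_layers=12 restores layers 4..8
# (five layers centered on 6), while layer=0 is clamped to layers 0..2. The
# first representation is always the noised block input at layer 0.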
class NoiseIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
    """Adds fixed Gaussian noise to the embeddings of the first four tokens."""

    def __init__(self, embed_dim, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.interchange_dim = embed_dim
        # Fixed seed so every corrupted run uses the same noise; the shape
        # (1, 4, embed_dim) matches the four token positions corrupted above.
        rs = np.random.RandomState(1)
        self.noise = torch.from_numpy(rs.randn(1, 4, embed_dim)).to(device)
        self.noise_level = 0.7462981581687927

    def forward(self, base, source=None, subspaces=None):
        base[..., : self.interchange_dim] += self.noise * self.noise_level
        return base

    def __str__(self):
        return f"NoiseIntervention(embed_dim={self.embed_dim})"
def corrupted_config(model_type):
    """Config for the fully corrupted run: noise the block input at layer 0."""
    config = IntervenableConfig(
        model_type=model_type,
        representations=[
            RepresentationConfig(
                0,              # layer
                "block_input",  # intervention type
            ),
        ],
        intervention_types=NoiseIntervention,
    )
    return config
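# Typical use (this mirrors the corrupted run in interpret() above):
#
#   intervenable = IntervenableModel(corrupted_config(type(model)), model)
#   _, out = intervenable(base, unit_locations={"base": ([[[0, 1, 2, 3]]])})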
# Legacy model-loading path from an earlier version of this demo, kept for reference:
'''
params = return_params('best_model_json/distilbert.json', 1)
embeddings = None
if params['bert_tokens']:
    train, val, test = createDatasetSplit(params)  # update
else:
    train, val, test, vocab_own = createDatasetSplit(params)
    params['embed_size'] = vocab_own.embeddings.shape[1]
    params['vocab_size'] = vocab_own.embeddings.shape[0]
    embeddings = vocab_own.embeddings
if params['auto_weights']:
    y_test = [ele[2] for ele in test]
    encoder = LabelEncoder()
    encoder.classes_ = np.load(params['class_names'], allow_pickle=True)
    params['weights'] = class_weight.compute_class_weight(
        'balanced', np.unique(y_test), y_test).astype('float32')
model = select_model(params, embeddings)
model = model.eval()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
classes_ = np.load('Data/classes.npy')
'''
def main_function(text, label):
    """Gradio handler: runs the causal-tracing pass and returns the three heatmaps."""
    ml_img_path, bo_img_path, atten_img_path = interpret(text, label)
    ml_im = Image.open(ml_img_path)
    bo_im = Image.open(bo_img_path)
    atten_im = Image.open(atten_img_path)
    return ml_im, bo_im, atten_im

# Legacy attention-heatmap pipeline from an earlier version of this demo,
# kept for reference:
'''
tokens = tokenizer.encode_plus(text)
input_ids = pad_sequences(
    torch.tensor(tokens['input_ids']).unsqueeze(0),
    maxlen=int(params['max_length']),
    dtype="long", value=0, truncating="post", padding="post")
att_masks = custom_att_masks(input_ids)
outs = model(torch.tensor(input_ids),
             attention_mask=torch.tensor(att_masks, dtype=bool),
             labels=None,
             device='cuda')
text_tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
text_tokens_ = text_tokens[:len(tokens['input_ids'])]
avg_attn = torch.mean(outs[1][5], dim=1)
avg_attn_np = avg_attn[0, 0, :len(tokens['input_ids'])].detach().squeeze().numpy()
logits = outs[0]
pred = torch.argmax(logits)
pred_label = classes_[pred]

sns.set_theme(rc={'figure.figsize': (30, 1)})
fig, ax = plt.subplots()
ax = sns.heatmap(np.expand_dims(avg_attn_np, 0),
                 annot=np.expand_dims(np.array(text_tokens_), 0),
                 fmt="", annot_kws={'size': 10}, cmap="magma")
fig = ax.get_figure()
fig.savefig("out.png", bbox_inches='tight')
im = Image.open("out.png")
'''
with gr.Blocks() as demo:
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(label="Text")
        label_input = gr.Textbox(label="Class (hate / normal / other)")
        text_button = gr.Button("Show")
    with gr.Tab("Interpretability"):
        with gr.Row():
            image_output1 = gr.Image(label="mlp_activation")
            image_output2 = gr.Image(label="block_output")
            image_output3 = gr.Image(label="attention_output")
    text_button.click(
        main_function,
        inputs=[text_input, label_input],
        outputs=[image_output1, image_output2, image_output3],
    )

if __name__ == "__main__":
    demo.launch()