shambhavi3 committed on
Commit 959622f · verified · 1 Parent(s): 0c7c487

Upload app.py

Files changed (1): app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
import gradio as gr

import torch
from torch import nn
import numpy as np
import pandas as pd
import PIL.Image as Image
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertConfig

from pyvene import (
    IntervenableModel,
    VanillaIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    LocalistRepresentationIntervention,
)

from plotnine import (
    ggplot,
    geom_tile,
    aes,
    theme,
    element_text,
    xlab,
    ylab,
    ggsave,
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
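
# This app runs ROME-style causal tracing on a fine-tuned BERT sequence
# classifier: it corrupts the first four input token positions with noise,
# then restores individual activations (block output, MLP activation,
# attention output) layer by layer with pyvene, and plots how much each
# restoration recovers the probability of the chosen label.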

def create_bert(cache_dir=None):
    """Load the fine-tuned BERT sequence classifier, its config, and tokenizer."""
    config = BertConfig.from_pretrained("./cs77_proj/bert_base/checkpoint-3848/config.json")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "./cs77_proj/bert_base/checkpoint-3848", config=config, cache_dir=cache_dir
    )
    print("loaded model")
    return config, tokenizer, model
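
# A minimal usage sketch (assumes the checkpoint directory above exists):
#
#   config, tokenizer, model = create_bert()
#   inputs = tokenizer("some text", return_tensors="pt")
#   logits = model(**inputs)[0]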

def interpret(text, label):
    titles = {
        "block_output": "single restored layer in BERT",
        "mlp_activation": "center of interval of 5 patched mlp layer",
        "attention_output": "center of interval of 5 patched attn layer",
    }

    colors = {
        "block_output": "Purples",
        "mlp_activation": "Greens",
        "attention_output": "Reds",
    }

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    config, tokenizer, gpt = create_bert()
    gpt.to(device)

    base = tokenizer(text, return_tensors="pt").to(device)
    base_token = tokenizer.convert_ids_to_tokens(base["input_ids"][0])

    # Map the label name onto the classifier's class index.
    if label == "hate":
        l = 0
    elif label == "normal":
        l = 1
    else:
        l = 2

    # Corrupt the first four token positions with noise and run once, to set
    # up the corrupted baseline.
    config = corrupted_config(type(gpt))
    intervenable = IntervenableModel(config, gpt)
    _, counterfactual_outputs = intervenable(
        base, unit_locations={"base": [[[0, 1, 2, 3]]]}
    )

    # For each stream, corrupt the input and restore a window of clean
    # activations at every (layer, position), recording the probability
    # assigned to the target label.
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        data = []
        for layer_i in tqdm(range(gpt.config.num_hidden_layers)):
            for pos_i in range(len(base_token)):
                config = restore_corrupted_with_interval_config(
                    layer_i, stream,
                    window=1 if stream == "block_output" else 5,
                )
                n_restores = len(config.representations) - 1
                intervenable = IntervenableModel(config, gpt)
                _, counterfactual_outputs = intervenable(
                    base,
                    [None] + [base] * n_restores,
                    {
                        "sources->base": (
                            [None] + [[[pos_i]]] * n_restores,
                            [[[0, 1, 2, 3]]] + [[[pos_i]]] * n_restores,
                        )
                    },
                )
                logits = counterfactual_outputs[0]
                probabilities = nn.functional.softmax(logits, dim=-1)
                prob_offense = probabilities[0][l].item()
                data.append({"layer": layer_i, "pos": pos_i, "prob": prob_offense})
        df = pd.DataFrame(data)
        df.to_csv(f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.csv")

    # Plot one heatmap per stream: probability of the label as a function of
    # the restored layer (x) and token position (y).
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        df = pd.read_csv(f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.csv")
        df["layer"] = df["layer"].astype(int)
        df["pos"] = df["pos"].astype(int)
        prob_type = "p(" + label + ")"
        df[prob_type] = df["prob"].astype(float)
        custom_labels = base_token
        breaks = list(range(len(custom_labels)))

        plot = (
            ggplot(df, aes(x="layer", y="pos"))
            + geom_tile(aes(fill=prob_type))
            + scale_fill_cmap(colors[stream])
            + xlab(titles[stream])
            + scale_y_reverse(
                limits=(-0.5, len(custom_labels)),
                breaks=breaks,
                labels=custom_labels,
            )
            + theme(figure_size=(6, 9))
            + ylab("")
            + theme(axis_text_y=element_text(angle=90, hjust=1))
        )
        ggsave(
            plot,
            filename=f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.png",
            dpi=200,
        )
        if stream == "mlp_activation":
            mlp_img_path = f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.png"
        elif stream == "block_output":
            bo_path = f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.png"
        else:
            attention_path = f"./cs77_proj/tutorial_data/pyvene_rome_{stream}.png"

    return mlp_img_path, bo_path, attention_path

def restore_corrupted_with_interval_config(
    layer, stream="mlp_activation", window=5, num_layers=12
):
    """Noise the block input, then restore `stream` activations in a window
    of `window` layers centred on `layer`."""
    # window // 2 layers below and -(-window // 2) (i.e. ceil(window / 2))
    # layers above, clipped to the valid layer range.
    start = max(0, layer - window // 2)
    end = min(num_layers, layer - (-window // 2))
    config = IntervenableConfig(
        representations=[
            RepresentationConfig(
                0,              # layer
                "block_input",  # intervention type
            ),
        ]
        + [
            RepresentationConfig(
                i,       # layer
                stream,  # intervention type
            )
            for i in range(start, end)
        ],
        intervention_types=[NoiseIntervention] + [VanillaIntervention] * (end - start),
    )
    return config

class NoiseIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
    """Adds a fixed Gaussian noise vector to the first four token positions."""

    def __init__(self, embed_dim, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.interchange_dim = embed_dim
        rs = np.random.RandomState(1)  # fixed seed so the corruption is reproducible
        self.noise = torch.from_numpy(rs.randn(1, 4, embed_dim)).to(device)
        self.noise_level = 0.7462981581687927  # tuned constant; an earlier value was 0.3462981581687927

    def forward(self, base, source=None, subspaces=None):
        base[..., : self.interchange_dim] += self.noise * self.noise_level
        return base

    def __str__(self):
        return f"NoiseIntervention(embed_dim={self.embed_dim})"

def corrupted_config(model_type):
    """Intervention config that only noises the block input of layer 0."""
    config = IntervenableConfig(
        model_type=model_type,
        representations=[
            RepresentationConfig(
                0,              # layer
                "block_input",  # intervention type
            ),
        ],
        intervention_types=NoiseIntervention,
    )
    return config
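
# Together, corrupted_config and restore_corrupted_with_interval_config form
# the two halves of causal tracing: fully corrupt the input, then selectively
# restore clean activations to localise where the label information lives.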

def main_function(text, label):
    """Run the tracing pipeline and yield the three heatmaps for Gradio."""
    ml_img_path, bo_img_path, atten_img_path = interpret(text, label)
    ml_im = Image.open(ml_img_path)
    bo_im = Image.open(bo_img_path)
    atten_im = Image.open(atten_img_path)
    yield ml_im, bo_im, atten_im

with gr.Blocks() as demo:
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(label="Text")
        label_input = gr.Textbox(label="Label")
        text_button = gr.Button("Show")

    with gr.Tab("Interpretability"):
        with gr.Row():
            image_output1 = gr.Image()
            image_output2 = gr.Image()
            image_output3 = gr.Image()

    text_button.click(
        main_function,
        inputs=[text_input, label_input],
        outputs=[image_output1, image_output2, image_output3],
    )


if __name__ == "__main__":
    demo.launch()
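
# Running `python app.py` starts the Gradio server and prints a local URL.
# Enter the text and one of the label names ("hate", "normal", or the third
# class) in the first tab, then view the three heatmaps under Interpretability.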