山越貴耀
fix a bug
ceb5190
import pandas as pd
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
from transformers import BertTokenizer,BertForMaskedLM
import io
import time
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_sentence_model():
    """Return the SentenceTransformer used to embed chain sentences for t-SNE.

    Wrapped in ``st.cache`` (with output mutation allowed) so repeated calls
    reuse the loaded model instead of constructing a new one.
    """
    return SentenceTransformer('paraphrase-distilroberta-base-v1')
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model(model_name):
    """Load a BERT masked-language-model checkpoint together with its tokenizer.

    The model is switched to eval mode before being returned. Results are
    cached with ``st.cache``.

    Args:
        model_name: Hugging Face checkpoint name; must start with ``'bert'``.

    Returns:
        ``(tokenizer, model)`` -- a ``BertTokenizer`` and a ``BertForMaskedLM``.

    Raises:
        ValueError: if ``model_name`` does not start with ``'bert'``.
    """
    if not model_name.startswith('bert'):
        # Fail fast: falling through would return None and the caller's
        # ``tokenizer, model = load_model(...)`` would die with an opaque TypeError.
        raise ValueError(f'Unsupported model name: {model_name!r} (only BERT checkpoints are supported)')
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer,model
@st.cache(show_spinner=False)
def load_data(sentence_num):
    """Load the pre-computed chain for one example sentence from ``tsne_out.csv``.

    Keeps rows whose ``sentence_num`` equals the argument and whose
    ``iter_num`` is below 1000. The index is then reset; the previous index is
    kept as an ``index`` column, as ``reset_index()`` does by default.
    """
    table = pd.read_csv('tsne_out.csv')
    keep = (table['sentence_num'] == sentence_num) & (table['iter_num'] < 1000)
    return table[keep].reset_index()
def mask_prob(model,mask_id,sentences,position,temp=1):
    """Return temperature-scaled log-probabilities for one masked position.

    A copy of ``sentences`` is made with column ``position`` overwritten by
    ``mask_id`` in every row; the caller's tensor is left untouched. ``model``
    is run on the copy without gradient tracking. Element ``[0]`` of its
    output is taken as the logits, sliced at ``position``, divided by ``temp``
    and passed through ``log_softmax`` over the last dimension.

    ``sentences`` is indexed as ``[:, position]``, so it must be a 2-D tensor
    of token ids (one row per sentence). The logits are presumably
    (batch, seq_len, vocab) -- confirm against the model in use.
    """
    with_mask = sentences.clone()
    with_mask[:, position] = mask_id
    with torch.no_grad():
        outputs = model(with_mask)
    scaled_logits = outputs[0][:, position] / temp
    return F.log_softmax(scaled_logits, dim=-1)
def sample_words(probs,pos,sentences,tok=None):
    """Resample the token at ``pos`` and tabulate the top candidates.

    Args:
        probs: log-probabilities over the vocabulary (as returned by
            ``mask_prob``); a 2-D tensor with one row per sentence.
        pos: column of ``sentences`` to overwrite with the sampled token.
        sentences: 2-D tensor of token ids; it is cloned, not modified.
        tok: tokenizer used to decode candidate ids. Defaults to the
            module-level ``tokenizer`` for backward compatibility.

    Returns:
        ``(new_sentences, df)``. ``new_sentences`` is a copy of ``sentences``
        with column ``pos`` replaced by one token per row, drawn with
        ``torch.multinomial`` from ``exp(probs)``. ``df`` has columns
        ``word``/``prob`` for the 10 highest-probability ids of row 0.
    """
    # Evaluated lazily: the module-level ``tokenizer`` is only needed when no tokenizer is passed.
    decoder = tokenizer if tok is None else tok
    # Exponentiate once instead of once per candidate.
    word_probs = torch.exp(probs)
    top_ids = torch.argsort(probs[0], descending=True)[:10]
    candidates = [[decoder.decode([candidate]), word_probs[0, candidate].item()]
                  for candidate in top_ids]
    df = pd.DataFrame(data=candidates,columns=['word','prob'])
    chosen_words = torch.multinomial(word_probs, num_samples=1).squeeze(dim=-1)
    new_sentences = sentences.clone()
    new_sentences[:, pos] = chosen_words
    return new_sentences, df
def run_chains(tokenizer,model,mask_id,input_text,num_steps):
    """Sample a serial-reproduction chain starting from ``input_text``.

    On each of ``num_steps`` steps, one position strictly between the first
    and last tokens is chosen uniformly at random. It is masked and resampled
    from the model's prediction via ``mask_prob`` and ``sample_words``.
    Progress is reported in the Streamlit sidebar.

    Returns:
        DataFrame with columns ``step``, ``sentence`` (the decoded sentence
        *before* that step's resampling) and ``next_sample_loc`` (the token
        position resampled at that step).

    Raises:
        ValueError: if the tokenized input has nothing between its first and
            last tokens (e.g. an empty string), so there is nothing to resample.
    """
    init_sent = tokenizer(input_text,return_tensors='pt')['input_ids']
    seq_len = init_sent.shape[1]
    if seq_len <= 2:
        # torch.randint(0, ...) below would otherwise fail with an obscure message.
        raise ValueError('The initial sentence must contain at least one token.')
    sentence = init_sent.clone()
    data_list = []
    st.sidebar.write('Generating samples...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain_progress = st.sidebar.progress(0)
    for step_id in range(num_steps):
        chain_progress.progress((step_id+1)/num_steps)
        # Uniform over positions 1..seq_len-2, i.e. never the first or last token.
        pos = torch.randint(seq_len-2,size=(1,)).item()+1
        data_list.append([step_id,tokenizer.decode(sentence[0].tolist()),pos])
        probs = mask_prob(model,mask_id,sentence,pos)
        sentence,_ = sample_words(probs,pos,sentence)
    return pd.DataFrame(data=data_list,columns=['step','sentence','next_sample_loc'])
def show_tsne_panel(df, step_id):
    """Plot a chain's t-SNE trajectory up to and including ``step_id``.

    Visited points (rows ``0..step_id`` of ``df.x_tsne``/``df.y_tsne``) are
    joined by a thin grey line and coloured along a 'flare' palette. The
    current step is marked with a blue star. Axis limits come from the
    whole chain, so the frame does not move between steps.

    Returns:
        A ``matplotlib.figure.Figure``; the caller is responsible for displaying it.
    """
    from matplotlib.figure import Figure

    def _padded_limits(values):
        # Snap to a grid of one tenth of the data range, padded by one cell per side.
        lo, hi = min(values), max(values)
        # A constant coordinate would make the unit 0 and the limits NaN.
        unit = (hi - lo)/10 or 1.0
        return [(lo//unit-1)*unit, (hi//unit+1)*unit]

    x_tsne, y_tsne = df.x_tsne, df.y_tsne
    xlims = _padded_limits(x_tsne)
    ylims = _padded_limits(y_tsne)
    color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
    # Figure() instead of plt.figure(): the figure is not registered with pyplot,
    # so the many figures drawn during autoplay are not kept alive by the registry.
    fig = Figure(figsize=(5,5),dpi=200)
    ax = fig.add_subplot(1,1,1)
    x_seen = x_tsne.iloc[:step_id+1]
    y_seen = y_tsne.iloc[:step_id+1]
    ax.plot(x_seen,y_seen,linewidth=0.2,color='gray',zorder=1)
    ax.scatter(x_seen,y_seen,s=5,color=color_list[:step_id+1],zorder=2)
    ax.scatter(x_tsne.iloc[step_id:step_id+1],y_tsne.iloc[step_id:step_id+1],s=50,marker='*',color='blue',zorder=3)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)
    ax.axis('off')
    return fig
def run_tsne(chain):
    """Embed each chain sentence and project the embeddings to 2-D with PCA + t-SNE.

    ``[CLS] `` and `` [SEP]`` markers are stripped from ``chain.sentence``
    into a new ``cleaned_sentence`` column. The cleaned sentences are encoded
    with the cached SentenceTransformer, reduced with PCA (up to 50
    components), and then embedded with t-SNE.

    Returns:
        ``chain`` with ``cleaned_sentence``, ``x_tsne`` and ``y_tsne`` columns added.
    """
    import inspect

    st.sidebar.write('Running t-SNE...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain = chain.assign(cleaned_sentence=chain.sentence.str.replace(r'\[CLS\] ', '',regex=True).str.replace(r' \[SEP\]', '',regex=True))
    sentence_model = load_sentence_model()
    sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(), show_progress_bar=False)
    n_samples, n_features = np.shape(sentence_embeddings)
    # PCA cannot keep more components than min(n_samples, n_features); short chains used to crash here.
    big_pca = PCA(n_components = min(50, n_samples, n_features))
    # scikit-learn renamed TSNE's ``n_iter`` to ``max_iter`` and later dropped ``n_iter``.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    # perplexity must be smaller than the number of samples; 30 is the library default.
    tsne = TSNE(n_components = 2, perplexity=min(30.0, max(n_samples - 1, 1)), **{iter_kw: 2000})
    tsne_vals = tsne.fit_transform(big_pca.fit_transform(sentence_embeddings))
    return pd.concat([chain, pd.DataFrame(tsne_vals, columns = ['x_tsne', 'y_tsne'],index=chain.index)], axis = 1)
def autoplay() :
    """Animate the chain from the current step to the last step.

    For each step, a placeholder is filled with the highlighted sentence and
    the t-SNE panel. The placeholder is shown for 0.25 s and then cleared.
    ``st.session_state.step_id`` / ``prev_step_id`` advance with the
    animation, so the final step is remembered afterwards.
    """
    for step_id in range(st.session_state.step_id, len(st.session_state.df), 1):
        # Advance the pointers *before* rendering: show_changed_site() reads them,
        # so updating afterwards made the caption lag one step behind the figure.
        st.session_state.prev_step_id = st.session_state.step_id
        st.session_state.step_id = step_id
        x = st.empty()
        with x.container():
            st.markdown(show_changed_site(), unsafe_allow_html = True)
            fig = show_tsne_panel(st.session_state.df, step_id)
            cols = st.columns([1,2,1])
            with cols[1]:
                st.pyplot(fig)
        # st.pyplot has already rendered the image; release the figure.
        plt.close(fig)
        time.sleep(.25)
        x.empty()
def initialize_buttons() :
    """Draw the step controls in the sidebar and apply any pressed increment.

    Two rows of buttons (+1/+10/+100/+500 and -1/-10/-100/-500) and a
    'Show candidates' checkbox are placed in a sidebar container. If a button
    was pressed, the first pressed increment is added to
    ``st.session_state.step_id``, clamped to ``[0, len(df) - 1]``, and the
    old value is saved in ``prev_step_id``. If the checkbox is ticked, the
    candidate explorer is shown via ``show_candidates()``.
    """
    column_widths = [4,5,6,6]
    label_rows = (['+1','+10','+100','+500'], ['-1','-10','-100','-500'])
    pressed = []
    with st.sidebar.empty().container() :
        for labels in label_rows:
            row = st.columns(column_widths)
            pressed.extend(col.button(label, key=label) for col, label in zip(row, labels))
        show_candidates_checked = st.checkbox('Show candidates')

    # Same order as the buttons above: the first row, then the second.
    increments = np.array([1,10,100,500,-1,-10,-100,-500])
    if any(pressed) :
        step_delta = increments[np.array(pressed)][0]
        current = st.session_state.step_id
        st.session_state.prev_step_id = current
        st.session_state.step_id = min(len(st.session_state.df) - 1, max(0, current + step_delta))

    if show_candidates_checked:
        st.write('Click any word to see each candidate with its probability')
        show_candidates()
def show_candidates():
    """Show the current step's sentence as clickable tokens with a candidate table.

    Each token of the current sentence is decoded into its own column; the
    first and last tokens are plain text and the rest are buttons. Clicking a
    token masks it (``mask_prob``) and shows the candidate table built by
    ``sample_words``; the table element is stored in
    ``st.session_state.curr_table`` so it can be cleared on the next call.

    Uses the module-level ``df``, ``tokenizer``, ``model`` and ``mask_id``
    defined in the ``__main__`` block.
    """
    if 'curr_table' in st.session_state:
        st.session_state.curr_table.empty()
    current_text = df.cleaned_sentence.loc[st.session_state.step_id]
    token_ids = tokenizer(current_text,return_tensors='pt')['input_ids']
    pieces = [tokenizer.decode([tid]) for tid in token_ids[0]]
    # Column widths proportional to each token's length, plus a little padding.
    cols = st.columns([len(piece) + 2 for piece in pieces])
    with cols[0]:
        st.write(pieces[0])
    with cols[-1]:
        st.write(pieces[-1])
    for word_id, (col, piece) in enumerate(zip(cols[1:-1], pieces[1:-1])):
        position = word_id + 1  # offset past the first (non-clickable) token
        with col:
            if not st.button(piece,key=f'word_{word_id}'):
                continue
            log_probs = mask_prob(model,mask_id,token_ids,position)
            _, candidates_df = sample_words(log_probs, position, token_ids)
            st.session_state.curr_table = st.table(candidates_df)
def show_changed_site():
    """Return an HTML paragraph showing the current step's sentence with changes in red.

    Reads ``df``, ``step_id`` and ``prev_step_id`` from ``st.session_state``.
    If the chain has a ``next_sample_loc`` column, the word at the location
    resampled on the previous step is highlighted. In chains built by
    ``run_chains`` that location counts from 1, hence the ``- 1``. Otherwise,
    words of the current sentence that are absent from the previously shown
    sentence are highlighted.

    NOTE(review): the sentence is split on spaces while ``next_sample_loc``
    indexes tokenizer positions, so the highlight may be off when a word spans
    several word pieces or punctuation is attached to a word.
    """
    import html

    df = st.session_state.df
    step_id = st.session_state.step_id
    curr_sent = df.cleaned_sentence.loc[step_id].split(' ')
    if 'next_sample_loc' in df:
        # Step 0 is the initial sentence, so there is no preceding sample. Indexing
        # with step_id-1 == -1 would wrongly highlight the last row's location.
        locs = [df.next_sample_loc.iloc[step_id-1]-1] if step_id > 0 else []
    else:
        prev_sent = df.cleaned_sentence.loc[st.session_state.prev_step_id].split(' ')
        locs = [i for i, word in enumerate(curr_sent) if word not in prev_sent]
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}>Step {st.session_state.step_id}&colon;&nbsp; <span style="font-weight:bold">'
    # Words come from user-typed or model-sampled text and are rendered as HTML, so escape them.
    disp = ' '.join([f'<span style="color:Red">{html.escape(word)}</span>' if i in locs else html.escape(word)
                     for (i, word) in enumerate(curr_sent)])
    suffix = '</span></p>'
    return prefix + disp + suffix
def clear_df():
    """Forget the loaded chain and the current position within it.

    Used as the ``on_change`` callback of the input-type radio and the custom
    sentence box. The step pointers are dropped along with ``df``; a step
    index left over from a longer chain could exceed the next chain's length
    and break the ``.loc[step_id]`` lookups. The ``__main__`` block re-creates
    them at 0 when a new chain is loaded.
    """
    for key in ('df', 'step_id', 'prev_step_id'):
        if key in st.session_state:
            del st.session_state[key]
if __name__=='__main__':
    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    input_type = st.sidebar.radio(
        label='1. Choose the input type',
        on_change=clear_df,
        options=('Use one of the example sentences','Use your own initial sentence')
        )

    # Title
    st.header("Demo: Probing BERT's priors with serial reproduction chains")

    # Load BERT
    tokenizer,model = load_model('bert-base-uncased')
    # encode() wraps "[MASK]" in the special start/end tokens; keep only the middle id.
    mask_id = tokenizer.encode("[MASK]")[1:-1][0]

    # First step: load the dataframe containing sentences
    if input_type=='Use one of the example sentences':
        placeholder = '--- Please select one from below ---'
        # Example sentence -> its sentence_num in tsne_out.csv (insertion order is the menu order).
        example_sentences = {
            'About 170 campers attend the camps each week.': 6,
            "Ali marpet's mother is joy rose.": 2,
            'She grew up with three brothers and ten sisters.': 8,
        }
        sentence = st.sidebar.selectbox("Select the initial sentence",
                                        (placeholder, *example_sentences))
        if sentence!=placeholder:
            st.session_state.df = load_data(example_sentences[sentence])
            st.session_state.finished_sampling = True
    else:
        sentence = st.sidebar.text_input('Type your own sentence here.',on_change=clear_df)
        # t-SNE needs at least two points, so shorter chains cannot be visualised.
        num_steps = st.sidebar.number_input(label='How many steps do you want to run?',value=500,min_value=2)
        if st.sidebar.button('Run chains'):
            if not sentence.strip():
                st.sidebar.warning('Please type an initial sentence before running chains.')
            else:
                chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=num_steps)
                st.session_state.df = run_tsne(chain)
                st.session_state.finished_sampling = True

    st.empty().markdown(
        "Let's explore sentences from BERT's prior! "
        "Use the menu to the left to select a pre-generated chain, "
        "or start a new chain using your own initial sentence."
        if 'df' not in st.session_state else
        "Use the buttons in the sidebar to step through the chain, or watch the autoplay. "
        "Click 'Show candidates' to see the top proposals when each word is masked out."
    )

    if 'df' in st.session_state:
        df = st.session_state.df
        if 'step_id' not in st.session_state:
            st.session_state.prev_step_id = 0
            st.session_state.step_id = 0
        else:
            # Pointers left over from a longer chain must not index past this one.
            last_step = len(df) - 1
            st.session_state.step_id = min(st.session_state.step_id, last_step)
            st.session_state.prev_step_id = min(st.session_state.prev_step_id, last_step)

        explore_type = st.sidebar.radio(
            '2. Choose how to explore the chain',
            options=['Click through steps','Autoplay']
            )
        if explore_type=='Autoplay':
            st.empty()
            st.sidebar.empty()
            autoplay()
        elif explore_type=='Click through steps':
            initialize_buttons()
            with st.container():
                st.markdown(show_changed_site(), unsafe_allow_html = True)
                fig = show_tsne_panel(df, st.session_state.step_id)
                cols = st.columns([1,2,1])
                with cols[1]:
                    st.pyplot(fig)
            # The image has been rendered; release the figure.
            plt.close(fig)