Spaces:
Runtime error
Runtime error
taka-yamakoshi
committed on
Commit
·
c6dd7aa
1
Parent(s):
a440ac3
mask out
Browse files
app.py
CHANGED
@@ -16,20 +16,7 @@ from transformers import AlbertTokenizer, AlbertForMaskedLM
|
|
16 |
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
|
17 |
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
|
18 |
|
19 |
-
|
20 |
-
def load_model():
|
21 |
-
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
|
22 |
-
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
23 |
-
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
|
24 |
-
return tokenizer,model
|
25 |
-
|
26 |
-
def clear_data():
|
27 |
-
for key in st.session_state:
|
28 |
-
del st.session_state[key]
|
29 |
-
|
30 |
-
if __name__=='__main__':
|
31 |
-
|
32 |
-
# Config
|
33 |
max_width = 1500
|
34 |
padding_top = 0
|
35 |
padding_right = 2
|
@@ -56,9 +43,61 @@ if __name__=='__main__':
|
|
56 |
st.markdown(define_margins, unsafe_allow_html=True)
|
57 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
58 |
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
62 |
sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
|
63 |
sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
|
64 |
input_ids_1 = tokenizer(sent_1).input_ids
|
@@ -69,3 +108,4 @@ if __name__=='__main__':
|
|
69 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
70 |
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
|
71 |
st.write([tokenizer.decode([token]) for token in preds])
|
|
|
|
16 |
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
|
17 |
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
|
18 |
|
19 |
+
def wide_setup():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
max_width = 1500
|
21 |
padding_top = 0
|
22 |
padding_right = 2
|
|
|
43 |
st.markdown(define_margins, unsafe_allow_html=True)
|
44 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
45 |
|
46 |
+
@st.cache(show_spinner=True,allow_output_mutation=True)
|
47 |
+
def load_model():
|
48 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
|
49 |
+
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
50 |
+
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
|
51 |
+
return tokenizer,model
|
52 |
+
|
53 |
+
def clear_data():
|
54 |
+
for key in st.session_state:
|
55 |
+
del st.session_state[key]
|
56 |
+
|
57 |
+
if __name__=='__main__':
|
58 |
+
wide_setup()
|
59 |
+
|
60 |
+
if 'page_status' not in st.session_state:
|
61 |
+
st.session_state['page_status'] = 'type_in'
|
62 |
+
|
63 |
+
if st.session_state['page_status']=='type_in':
|
64 |
+
tokenizer,model = load_model()
|
65 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
66 |
+
|
67 |
+
st.write('1. Type in the sentences and click "Tokenize"')
|
68 |
+
sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
|
69 |
+
sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
|
70 |
+
if st.sidebar.button('Tokenize'):
|
71 |
+
st.session_state['page_status'] = 'tokenized'
|
72 |
+
st.session_state['sent_1'] = sent_1
|
73 |
+
st.session_state['sent_2'] = sent_2
|
74 |
+
|
75 |
+
if st.session_state['page_status']=='tokenized':
|
76 |
+
tokenizer,model = load_model()
|
77 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
78 |
+
sent_1 = st.session_state['sent_1']
|
79 |
+
sent_2 = st.session_state['sent_2']
|
80 |
+
if 'masked_pos_1' not in st.session_state:
|
81 |
+
st.session_state['masked_pos_1'] = []
|
82 |
+
if 'masked_pos_2' not in st.session_state:
|
83 |
+
st.session_state['masked_pos_2'] = []
|
84 |
+
|
85 |
+
st.write('2. Select sites to mask out and click "Confirm"')
|
86 |
+
input_sent = tokenizer(sent_1).input_ids
|
87 |
+
decoded_sent = [tokenizer.decode([token]) for token in input_sent]
|
88 |
+
char_nums = [len(word)+2 for word in decoded_sent]
|
89 |
+
cols = st.columns(char_nums)
|
90 |
+
with cols[0]:
|
91 |
+
st.write(decoded_sent[0])
|
92 |
+
with cols[-1]:
|
93 |
+
st.write(decoded_sent[-1])
|
94 |
+
for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
|
95 |
+
with col:
|
96 |
+
if st.button(word,key=f'word_{word_id}'):
|
97 |
+
st.session_state['masked_pos_1'].append(word_id)
|
98 |
+
st.write(f'Masked words: {" ".join([decoded_sent[word_id+1] for word_id in st.session_state["masked_pos_1"]])}')
|
99 |
|
100 |
+
'''
|
101 |
sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
|
102 |
sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
|
103 |
input_ids_1 = tokenizer(sent_1).input_ids
|
|
|
108 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
109 |
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
|
110 |
st.write([tokenizer.decode([token]) for token in preds])
|
111 |
+
'''
|