File size: 6,731 Bytes
2609434
 
 
 
 
 
 
 
 
 
 
 
 
b6df2fe
 
 
 
 
 
2609434
f6cb372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd36097
 
f6cb372
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from transformers import pipeline, AutoTokenizer
import pandas as pd
import numpy as np
import torch
import streamlit as st

# Set to False to force CPU inference even when a CUDA device is present.
USE_GPU = True

# Torch device handed to both fill-mask pipelines: the first CUDA device when
# GPU use is enabled and CUDA is available, otherwise the CPU.
device = torch.device(
    "cuda:0" if USE_GPU and torch.cuda.is_available() else "cpu"
)

# Hugging Face model id loaded by the English fill-mask pipeline.
MODEL_NAME_ENGLISH = "facebook/xlm-v-base"
#SENTENCE_MODEL_NAME_ENGLISH = 'sentence-transformers/all-MiniLM-L6-v2'
#WORD_MODEL_NAME_ENGLISH = 'vocab-transformers/distilbert-word2vec_256k-MLM_best'

# Hugging Face model id loaded by the Chinese fill-mask pipeline.
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"

# Pass mark: the page reports success when the target word's predicted
# probability is at least this value.
WORD_PROBABILITY_THRESHOLD = 0.02
#WORD_PROBABILITY_THRESHOLD_ENGLISH = 0.02
#WORD_PROBABILITY_THRESHOLD_CHINESE = 0.02
# Number of candidate fillers requested from the fill-mask pipeline (top_k).
TOP_K_WORDS = 10

# Language labels: shown as the radio-button options and used to dispatch to
# the per-language word picker and grader.
ENGLISH_LANG = "English"
CHINESE_LANG = "Chinese"

# Fixed pool of Chinese practice words that get_chinese_word() draws from.
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']

@st.cache_resource
def get_model_chinese():
    """Build the fill-mask pipeline for ``MODEL_NAME_CHINESE`` on ``device``.

    Decorated with ``st.cache_resource`` so the pipeline object is cached by
    Streamlit rather than rebuilt on every call.
    """
    chinese_mask_filler = pipeline(
        task="fill-mask", model=MODEL_NAME_CHINESE, device=device
    )
    return chinese_mask_filler

@st.cache_resource
def get_model_english():
    """Build the fill-mask pipeline for ``MODEL_NAME_ENGLISH`` on ``device``.

    Decorated with ``st.cache_resource`` so the pipeline object is cached by
    Streamlit rather than rebuilt on every call.
    """
    english_mask_filler = pipeline(
        task="fill-mask", model=MODEL_NAME_ENGLISH, device=device
    )
    return english_mask_filler

@st.cache_data
def get_wordlist_chinese():
    """Load ``wordlist_chinese.csv`` (relative to the working directory).

    Returns:
        The CSV contents as a pandas DataFrame; cached via ``st.cache_data``.
    """
    csv_path = 'wordlist_chinese.csv'
    chinese_words = pd.read_csv(csv_path)
    return chinese_words

@st.cache_data
def get_wordlist_english():
    """Load ``wordlist_english.csv`` (relative to the working directory).

    Returns:
        The CSV contents as a pandas DataFrame; cached via ``st.cache_data``.
    """
    csv_path = 'wordlist_english.csv'
    english_words = pd.read_csv(csv_path)
    return english_words

def assess_chinese(word, sentence):
    """Score how plausibly ``word`` fits where it appears in a Chinese sentence.

    The first occurrence of ``word`` in ``sentence`` is replaced by the
    pipeline tokenizer's mask token, then the Chinese fill-mask pipeline is
    queried twice: once for the ``TOP_K_WORDS`` most likely fillers and once
    restricted to ``word`` itself.

    Args:
        word: Target word the sentence has to contain.
        sentence: Sentence written by the user.

    Returns:
        A ``(predictions, score)`` tuple. ``predictions`` is the list of
        top-k prediction dicts, with the target word's own prediction appended
        when it is not already among them; ``score`` is the probability the
        model assigns to the target word.

    Raises:
        ValueError: If ``word`` is empty or does not occur in ``sentence``
            (compared case-insensitively).
    """
    import re  # function-scope import: the module does not import ``re``

    pattern = re.escape(word)
    if not word or re.search(pattern, sentence, flags=re.IGNORECASE) is None:
        raise ValueError("Sentence does not contain the target word")

    # Ask the tokenizer for its mask token instead of hard-coding one, so the
    # masked text always matches what this particular model expects.
    mask_token = mask_filler_chinese.tokenizer.mask_token

    # Mask only the first occurrence: with several masks the pipeline returns
    # one result list per mask, which would break the score lookup below.
    # A callable replacement keeps the mask token from being parsed as a
    # regex replacement template.
    text = re.sub(
        pattern, lambda _match: mask_token, sentence,
        count=1, flags=re.IGNORECASE,
    )

    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)

    score = target_word_prediction[0]['score']

    # Make sure the target word is part of the returned predictions so the
    # caller can always display it next to the top candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_english(word, sentence):
    """Score how plausibly ``word`` fits where it appears in an English sentence.

    The first occurrence of ``word`` in ``sentence`` (matched ignoring case) is
    replaced by the pipeline tokenizer's mask token, then the English
    fill-mask pipeline is queried twice: once for the ``TOP_K_WORDS`` most
    likely fillers and once restricted to the target word.

    Args:
        word: Target word the sentence has to contain.
        sentence: Sentence written by the user.

    Returns:
        A ``(predictions, score)`` tuple. ``predictions`` is the list of
        top-k prediction dicts, with the target word's own prediction appended
        when it is not already among them; ``score`` is the probability the
        model assigns to the target word.

    Raises:
        ValueError: If ``word`` is empty or does not occur in ``sentence``
            (compared case-insensitively).
    """
    import re  # function-scope import: the module does not import ``re``

    pattern = re.escape(word)
    if not word or re.search(pattern, sentence, flags=re.IGNORECASE) is None:
        raise ValueError("Sentence does not contain the target word")

    mask_token = mask_filler_english.tokenizer.mask_token

    # Replace case-insensitively so the masking agrees with the containment
    # check above (e.g. a capitalised word at the start of the sentence), and
    # mask only the first occurrence: with several masks the pipeline returns
    # one result list per mask, which would break the score lookup below.
    text = re.sub(
        pattern, lambda _match: mask_token, sentence,
        count=1, flags=re.IGNORECASE,
    )

    top_k_prediction = mask_filler_english(text, top_k=TOP_K_WORDS)
    # chr(9601) is U+2581 ("▁"), prefixed to the word to form the target
    # token passed to the pipeline.
    target_word_prediction = mask_filler_english(text, targets=chr(9601) + word)

    score = target_word_prediction[0]['score']

    # Make sure the target word is part of the returned predictions so the
    # caller can always display it next to the top candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_sentence(language, word, sentence):
    """Grade ``sentence`` for ``word`` with the grader matching ``language``.

    Returns whatever the language-specific grader returns, or ``None`` when
    ``language`` is neither ``ENGLISH_LANG`` nor ``CHINESE_LANG``.
    """
    graders = {
        ENGLISH_LANG: assess_english,
        CHINESE_LANG: assess_chinese,
    }
    grader = graders.get(language)
    if grader is None:
        return None
    return grader(word, sentence)
    
def get_chinese_word():
    """Return a random Chinese practice word.

    The word is drawn uniformly from the fixed ``CHINESE_WORDLIST``. The
    CSV-backed ``wordlist_chinese`` table is not consulted: sampling it and
    then discarding the result only added work and a failure mode when the
    filter matched no rows.

    NOTE(review): confirm whether selection should eventually come from
    ``wordlist_chinese`` (rows flagged ``assess`` with two-character entries)
    instead of this fixed list.

    Returns:
        The chosen word as a plain ``str`` (not ``numpy.str_``).
    """
    return str(np.random.choice(CHINESE_WORDLIST))

def get_english_word():
    """Return a random English practice word.

    The word is drawn uniformly from a small fixed test list. The CSV-backed
    ``wordlist_english`` table is not consulted: sampling it and then
    discarding the result only added work and a failure mode when the filter
    matched no rows.

    NOTE(review): confirm whether selection should eventually come from
    ``wordlist_english`` (rows flagged ``assess``) instead of this test list.

    Returns:
        The chosen word as a plain ``str`` (not ``numpy.str_``).
    """
    test_words = ["independent", "satisfied", "excited"]
    return str(np.random.choice(test_words))

def get_word(language):
    """Pick a random practice word for ``language``.

    Returns the result of the matching per-language picker, or ``None`` when
    ``language`` is neither ``ENGLISH_LANG`` nor ``CHINESE_LANG``.
    """
    pickers = {
        ENGLISH_LANG: get_english_word,
        CHINESE_LANG: get_chinese_word,
    }
    picker = pickers.get(language)
    if picker is None:
        return None
    return picker()

# Module-level handles used by the grading and word-picking functions above.
# The getters are wrapped in Streamlit cache decorators (st.cache_resource for
# the pipelines, st.cache_data for the CSV tables).
mask_filler_chinese = get_model_chinese()
mask_filler_english = get_model_english()
wordlist_chinese = get_wordlist_chinese()
wordlist_english = get_wordlist_english()

def highlight_given_word(row):
    """Row styler: tint the row whose ``Words`` value is the target word.

    Reads the module-level ``target_word``. Returns one CSS
    ``background-color`` declaration per cell in ``row``: light blue for the
    target word's row, white for every other row.
    """
    if row.Words == target_word:
        background = '#ACE5EE'
    else:
        background = 'white'
    cell_style = f'background-color:{background}'
    return [cell_style for _ in range(len(row))]

def get_top_5_results(top_k_prediction, target=None):
    """Build the results table: the five best predictions plus the target word.

    Args:
        top_k_prediction: List of fill-mask prediction dicts carrying the keys
            ``score``, ``token``, ``token_str`` and ``sequence``.
        target: Word that must appear in the table. Defaults to the
            module-level ``target_word`` when omitted.

    Returns:
        A DataFrame with columns ``Probability`` (formatted as a percentage
        string) and ``Words``. It holds the first five predictions; when the
        target word is not among them, its row(s) from the remaining
        predictions are appended.
    """
    if target is None:
        target = target_word

    predictions_df = (
        pd.DataFrame(top_k_prediction)
        .drop(columns=["token", "sequence"])
        .rename(columns={"score": "Probability", "token_str": "Words"})
    )

    top_5_df = predictions_df[:5]
    if not (top_5_df.Words == target).any():
        target_word_df = predictions_df[predictions_df.Words == target]
        top_5_df = pd.concat([top_5_df, target_word_df])

    # Work on an explicit copy: assigning a column on a slice of
    # predictions_df is chained assignment (SettingWithCopyWarning).
    top_5_df = top_5_df.copy()
    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")

    return top_5_df

#### Streamlit Page
st.title("造句 Auto-marking Demo")
language = st.radio("Select your language", (ENGLISH_LANG, CHINESE_LANG))

# Draw a word on first load AND whenever the selected language changes;
# otherwise a word picked for one language would be graded by the other
# language's model after the user flips the radio button.
if (
    'target_word' not in st.session_state
    or st.session_state.get('target_language') != language
):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
# Kept as a module-level name: highlight_given_word and get_top_5_results
# read it as a global.
target_word = st.session_state['target_word']

st.write("Target word: ", target_word)
if st.button("Get new word"):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
    # Prefer st.rerun where it exists; fall back to the older
    # st.experimental_rerun on Streamlit versions that predate it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    # Validate up front so a sentence without the target word shows a
    # friendly message instead of an exception from the grader.
    if not sentence.strip() or target_word.lower() not in sentence.lower():
        st.error("Please enter a sentence that contains the target word.")
    else:
        top_k_prediction, score = assess_sentence(language, target_word, sentence)
        # Debug dump of the raw predictions. Explicit UTF-8 so Chinese tokens
        # can be written regardless of the platform's default encoding.
        with open('./result01.json', 'w', encoding='utf-8') as outfile:
            outfile.write(str(top_k_prediction))

        st.write(f"Probability: {score:.2%}")
        st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
        predictions_df = get_top_5_results(top_k_prediction)
        df_style = predictions_df.style.apply(highlight_given_word, axis=1)

        if score >= WORD_PROBABILITY_THRESHOLD:
            st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
        else:
            st.warning("Hmmm.. maybe try again?")
        # The table is shown for both outcomes, below the verdict message.
        st.table(df_style)