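# 造句 (sentence-construction) auto-marking demo.
# A Streamlit app that masks the target word in the learner's sentence and asks a
# fill-mask model (XLM-V for English, Erlangshen-DeBERTa-v2 for Chinese) how likely
# the target word is in that position.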
from transformers import pipeline, AutoTokenizer
import json
import pandas as pd
import numpy as np
import torch
import streamlit as st
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
MODEL_NAME_ENGLISH = "facebook/xlm-v-base"
#SENTENCE_MODEL_NAME_ENGLISH = 'sentence-transformers/all-MiniLM-L6-v2'
#WORD_MODEL_NAME_ENGLISH = 'vocab-transformers/distilbert-word2vec_256k-MLM_best'

# chinese models
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"

WORD_PROBABILITY_THRESHOLD = 0.02
#WORD_PROBABILITY_THRESHOLD_ENGLISH = 0.02
#WORD_PROBABILITY_THRESHOLD_CHINESE = 0.02
TOP_K_WORDS = 10

ENGLISH_LANG = "English"
CHINESE_LANG = "Chinese"
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
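# Model and wordlist loaders. Streamlit re-executes the whole script on every
# interaction, so as written these reload the models each run; wrapping them in a
# cache decorator (e.g. st.cache_resource) would avoid that, but is not part of the
# original code.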
def get_model_chinese():
    return pipeline("fill-mask", MODEL_NAME_CHINESE, device=device)

def get_model_english():
    return pipeline("fill-mask", MODEL_NAME_ENGLISH, device=device)

def get_wordlist_chinese():
    return pd.read_csv('wordlist_chinese.csv')

def get_wordlist_english():
    return pd.read_csv('wordlist_english.csv')
def assess_chinese(word, sentence):
    print("Assessing Chinese")
    if sentence.lower().find(word.lower()) == -1:
        raise Exception("Sentence does not contain the target word")
    # use the tokenizer's own mask token (it may differ from "<mask>" for this model)
    text = sentence.replace(word.lower(), mask_filler_chinese.tokenizer.mask_token)
    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)
    score = target_word_prediction[0]['score']

    # append the original word if it's not found in the results
    top_k_prediction_filtered = [output for output in top_k_prediction
                                 if output['token_str'] == word]
    if len(top_k_prediction_filtered) == 0:
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score
def assess_english(word, sentence):
    if sentence.lower().find(word.lower()) == -1:
        raise Exception("Sentence does not contain the target word")
    text = sentence.replace(word.lower(), "<mask>")
    top_k_prediction = mask_filler_english(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_english(text, targets=chr(9601) + word)
    score = target_word_prediction[0]['score']

    # append the original word if it's not found in the results
    top_k_prediction_filtered = [output for output in top_k_prediction
                                 if output['token_str'] == word]
    if len(top_k_prediction_filtered) == 0:
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score
def assess_sentence(language, word, sentence):
    if language == ENGLISH_LANG:
        return assess_english(word, sentence)
    elif language == CHINESE_LANG:
        return assess_chinese(word, sentence)
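# Target-word selection: sample from the CSV wordlist (rows flagged assess == True,
# and two-character words for Chinese), then override the sample with a hard-coded
# test list, so the CSV draw is currently unused.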
def get_chinese_word():
    include = (wordlist_chinese.assess == True) & (wordlist_chinese.Chinese.apply(len) == 2)
    possible_words = wordlist_chinese[include]
    word = possible_words.sample(1).iloc[0].Chinese
    test_words = CHINESE_WORDLIST
    word = np.random.choice(test_words)
    return word
def get_english_word():
    include = (wordlist_english.assess == True)
    possible_words = wordlist_english[include]
    word = possible_words.sample(1).iloc[0].word
    test_words = ["independent", "satisfied", "excited"]
    word = np.random.choice(test_words)
    return word
def get_word(language):
    if language == ENGLISH_LANG:
        return get_english_word()
    elif language == CHINESE_LANG:
        return get_chinese_word()
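# Load the models and wordlists at module level (this runs on every script execution).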
mask_filler_chinese = get_model_chinese()
mask_filler_english = get_model_english()
wordlist_chinese = get_wordlist_chinese()
wordlist_english = get_wordlist_english()
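# Display helpers: highlight the row containing the target word and build a top-5
# table that always includes the target word, even if it fell outside the top 5.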
def highlight_given_word(row):
    color = '#ACE5EE' if row.Words == target_word else 'white'
    return [f'background-color:{color}'] * len(row)
def get_top_5_results(top_k_prediction):
    predictions_df = pd.DataFrame(top_k_prediction)
    predictions_df = predictions_df.drop(columns=["token", "sequence"])
    predictions_df = predictions_df.rename(columns={"score": "Probability", "token_str": "Words"})

    if (predictions_df[:5].Words == target_word).sum() == 0:
        print("target word not in top 5")
        top_5_df = predictions_df[:5].copy()
        target_word_df = predictions_df[predictions_df.Words == target_word]
        print(target_word_df)
        top_5_df = pd.concat([top_5_df, target_word_df])
    else:
        top_5_df = predictions_df[:5].copy()

    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")
    return top_5_df
#### Streamlit Page
st.title("造句 Auto-marking Demo")
language = st.radio("Select your language", (ENGLISH_LANG, CHINESE_LANG))
#st.info("You are practising on " + language)

if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word(language)
target_word = st.session_state['target_word']
st.write("Target word: ", target_word)

if st.button("Get new word"):
    st.session_state['target_word'] = get_word(language)
    st.experimental_rerun()

st.subheader("Form your sentence and enter it below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")
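# Grading: mask-fill the sentence, dump the raw predictions for inspection, and pass
# the attempt if the target word's probability clears WORD_PROBABILITY_THRESHOLD.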
if st.button("Grade"): | |
top_k_prediction, score = assess_sentence(language, target_word, sentence) | |
with open('./result01.json', 'w') as outfile: | |
outfile.write(str(top_k_prediction)) | |
st.write(f"Probability: {score:.2%}") | |
st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}") | |
predictions_df = get_top_5_results(top_k_prediction) | |
df_style = predictions_df.style.apply(highlight_given_word, axis=1) | |
if (score >= WORD_PROBABILITY_THRESHOLD): | |
# st.balloons() | |
st.success("Yay good job! 🕺 Practice again with other words", icon="✅") | |
st.table(df_style) | |
else: | |
st.warning("Hmmm.. maybe try again?") | |
st.table(df_style) | |