from transformers import pipeline
import pandas as pd
import numpy as np
import torch
import streamlit as st
import json

USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')

#MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"


# Minimum probability the model must assign to the target word for the sentence to pass.
WORD_PROBABILITY_THRESHOLD = 0.02
# Number of alternative predictions to retrieve for the comparison table.
TOP_K_WORDS = 10

CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']

@st.cache_resource
def get_model_chinese():
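    """Load the fill-mask pipeline once; @st.cache_resource reuses it across Streamlit reruns."""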
    return pipeline("fill-mask", MODEL_NAME_CHINESE, device=device)

def assess_chinese(word, sentence):
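    """Mask the target word in the sentence and score it with the fill-mask model.

    Returns a (top_k_prediction, score) tuple, where score is the probability the
    model assigns to the original word at the masked position, or None if the
    sentence does not contain the word.
    """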
    print("Assessing Chinese")
    if sentence.lower().find(word.lower()) == -1:
        print('Sentence does not contain the word!')
        return None

    # Use the tokenizer's mask token rather than a hard-coded "<mask>",
    # since the mask token differs between checkpoints (e.g. "[MASK]").
    text = sentence.replace(word.lower(), mask_filler_chinese.tokenizer.mask_token)

    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)

    score = target_word_prediction[0]['score']

    # append the original word if it's not found in the top-k results
    top_k_prediction_filtered = [output for output in top_k_prediction
                                 if output['token_str'] == word]
    if len(top_k_prediction_filtered) == 0:
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_sentence(word, sentence):
    return assess_chinese(word, sentence)
    
def get_chinese_word():
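    """Pick a random target word from the Chinese wordlist."""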
    possible_words = CHINESE_WORDLIST
    word = np.random.choice(possible_words)
    return word

def get_word():
    return get_chinese_word()

mask_filler_chinese = get_model_chinese()
#wordlist_chinese = get_wordlist_chinese()

def highlight_given_word(row):
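    """Shade the results-table row whose word matches the current target word."""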
    color = '#ACE5EE' if row.Words == target_word else 'white'
    return [f'background-color:{color}'] * len(row)

def get_top_5_results(top_k_prediction):
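    """Build a table of the top 5 predicted words, appending the target word's
    row if it did not make the top 5."""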
    predictions_df = pd.DataFrame(top_k_prediction)
    predictions_df = predictions_df.drop(columns=["token", "sequence"])
    predictions_df = predictions_df.rename(columns={"score": "Probability", "token_str": "Words"})

    if (predictions_df[:5].Words == target_word).sum() == 0:
        print("target word not in top 5")
        top_5_df = predictions_df[:5]
        target_word_df = predictions_df[predictions_df.Words == target_word]
        print(target_word_df)
        top_5_df = pd.concat([top_5_df, target_word_df])
    else:
        top_5_df = predictions_df[:5]

    # Copy before formatting to avoid pandas' SettingWithCopyWarning on a slice.
    top_5_df = top_5_df.copy()
    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")

    return top_5_df

#### Streamlit Page
st.title("造句 Auto-marking Demo")

if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word()
target_word = st.session_state['target_word']

st.write("Target word: ", target_word)
if st.button("Get new word"):
    st.session_state['target_word'] = get_word()
    st.experimental_rerun()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    result = assess_sentence(target_word, sentence)
    if result is None:
        st.error("Your sentence does not contain the target word!")
        st.stop()
    top_k_prediction, score = result

    # Save the raw predictions for inspection, written as valid JSON.
    with open('./result01.json', 'w', encoding='utf-8') as outfile:
        json.dump(top_k_prediction, outfile, ensure_ascii=False)

    st.write(f"Probability: {score:.2%}")
    st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
    predictions_df = get_top_5_results(top_k_prediction)
    df_style = predictions_df.style.apply(highlight_given_word, axis=1)

    if score >= WORD_PROBABILITY_THRESHOLD:
#        st.balloons()
        st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
        st.table(df_style)
    else:
        st.warning("Hmmm.. maybe try again?")
        st.table(df_style)