Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,12 +11,12 @@ if USE_GPU and torch.cuda.is_available():
|
|
11 |
else:
|
12 |
device = torch.device('cpu')
|
13 |
|
14 |
-
|
15 |
-
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"
|
16 |
|
17 |
|
18 |
WORD_PROBABILITY_THRESHOLD = 0.02
|
19 |
-
TOP_K_WORDS =
|
20 |
|
21 |
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
|
22 |
|
@@ -24,8 +24,15 @@ CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','
|
|
24 |
def get_model_chinese():
|
25 |
return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
def assess_chinese(word, sentence):
|
28 |
print("Assessing Chinese")
|
|
|
|
|
29 |
if sentence.lower().find(word.lower()) == -1:
|
30 |
print('Sentence does not contain the word!')
|
31 |
return
|
@@ -35,15 +42,27 @@ def assess_chinese(word, sentence):
|
|
35 |
top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
|
36 |
target_word_prediction = mask_filler_chinese(text, targets = word)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
score = target_word_prediction[0]['score']
|
39 |
|
40 |
# append the original word if its not found in the results
|
41 |
-
top_k_prediction_filtered = [output for output in
|
42 |
output['token_str'] == word]
|
43 |
if len(top_k_prediction_filtered) == 0:
|
44 |
-
|
45 |
|
46 |
-
return
|
47 |
|
48 |
def assess_sentence(word, sentence):
|
49 |
return assess_chinese(word, sentence)
|
|
|
11 |
else:
|
12 |
device = torch.device('cpu')
|
13 |
|
14 |
+
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
|
15 |
+
#MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"
|
16 |
|
17 |
|
18 |
WORD_PROBABILITY_THRESHOLD = 0.02
|
19 |
+
TOP_K_WORDS = 200
|
20 |
|
21 |
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
|
22 |
|
|
|
24 |
def get_model_chinese():
|
25 |
return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
|
26 |
|
27 |
+
@st.cache_resource
|
28 |
+
def get_allowed_tokens():
|
29 |
+
df = pd.read_csv('allowed_token_ids.csv')
|
30 |
+
return set(list(df['token']))
|
31 |
+
|
32 |
def assess_chinese(word, sentence):
|
33 |
print("Assessing Chinese")
|
34 |
+
allowed_token_ids = get_allowed_tokens()
|
35 |
+
|
36 |
if sentence.lower().find(word.lower()) == -1:
|
37 |
print('Sentence does not contain the word!')
|
38 |
return
|
|
|
42 |
top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
|
43 |
target_word_prediction = mask_filler_chinese(text, targets = word)
|
44 |
|
45 |
+
norm_factor = 0
|
46 |
+
for output in top_k_prediction:
|
47 |
+
if output['token'] not in allowed_token_ids:
|
48 |
+
norm_factor += output['score']
|
49 |
+
|
50 |
+
top_k_prediction_new = []
|
51 |
+
for output in top_k_prediction:
|
52 |
+
if output['token'] in allowed_token_ids:
|
53 |
+
output['score'] = output['score']/(1-min(0.5,norm_factor))
|
54 |
+
top_k_prediction_new.append(output)
|
55 |
+
|
56 |
+
target_word_prediction[0]['score'] = target_word_prediction[0]['score'] / (1-min(0.5,norm_factor))
|
57 |
score = target_word_prediction[0]['score']
|
58 |
|
59 |
# append the original word if its not found in the results
|
60 |
+
top_k_prediction_filtered = [output for output in top_k_prediction_new if \
|
61 |
output['token_str'] == word]
|
62 |
if len(top_k_prediction_filtered) == 0:
|
63 |
+
top_k_prediction_new.extend(target_word_prediction)
|
64 |
|
65 |
+
return top_k_prediction_new, score
|
66 |
|
67 |
def assess_sentence(word, sentence):
|
68 |
return assess_chinese(word, sentence)
|