cn91 committed on
Commit
0a6e4e2
1 Parent(s): a6be0b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -11,12 +11,12 @@ if USE_GPU and torch.cuda.is_available():
11
  else:
12
  device = torch.device('cpu')
13
 
14
- #MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
15
- MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"
16
 
17
 
18
  WORD_PROBABILITY_THRESHOLD = 0.02
19
- TOP_K_WORDS = 10
20
 
21
  CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
22
 
@@ -24,8 +24,15 @@ CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','
24
  def get_model_chinese():
25
  return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
26
 
 
 
 
 
 
27
  def assess_chinese(word, sentence):
28
  print("Assessing Chinese")
 
 
29
  if sentence.lower().find(word.lower()) == -1:
30
  print('Sentence does not contain the word!')
31
  return
@@ -35,15 +42,27 @@ def assess_chinese(word, sentence):
35
  top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
36
  target_word_prediction = mask_filler_chinese(text, targets = word)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  score = target_word_prediction[0]['score']
39
 
40
  # append the original word if its not found in the results
41
- top_k_prediction_filtered = [output for output in top_k_prediction if \
42
  output['token_str'] == word]
43
  if len(top_k_prediction_filtered) == 0:
44
- top_k_prediction.extend(target_word_prediction)
45
 
46
- return top_k_prediction, score
47
 
48
  def assess_sentence(word, sentence):
49
  return assess_chinese(word, sentence)
 
11
  else:
12
  device = torch.device('cpu')
13
 
14
+ MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
15
+ #MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"
16
 
17
 
18
  WORD_PROBABILITY_THRESHOLD = 0.02
19
+ TOP_K_WORDS = 200
20
 
21
  CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
22
 
 
24
  def get_model_chinese():
25
  return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
26
 
27
+ @st.cache_resource
28
+ def get_allowed_tokens():
29
+ df = pd.read_csv('allowed_token_ids.csv')
30
+ return set(list(df['token']))
31
+
32
  def assess_chinese(word, sentence):
33
  print("Assessing Chinese")
34
+ allowed_token_ids = get_allowed_tokens()
35
+
36
  if sentence.lower().find(word.lower()) == -1:
37
  print('Sentence does not contain the word!')
38
  return
 
42
  top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
43
  target_word_prediction = mask_filler_chinese(text, targets = word)
44
 
45
+ norm_factor = 0
46
+ for output in top_k_prediction:
47
+ if output['token'] not in allowed_token_ids:
48
+ norm_factor += output['score']
49
+
50
+ top_k_prediction_new = []
51
+ for output in top_k_prediction:
52
+ if output['token'] in allowed_token_ids:
53
+ output['score'] = output['score']/(1-min(0.5,norm_factor))
54
+ top_k_prediction_new.append(output)
55
+
56
+ target_word_prediction[0]['score'] = target_word_prediction[0]['score'] / (1-min(0.5,norm_factor))
57
  score = target_word_prediction[0]['score']
58
 
59
  # append the original word if its not found in the results
60
+ top_k_prediction_filtered = [output for output in top_k_prediction_new if \
61
  output['token_str'] == word]
62
  if len(top_k_prediction_filtered) == 0:
63
+ top_k_prediction_new.extend(target_word_prediction)
64
 
65
+ return top_k_prediction_new, score
66
 
67
  def assess_sentence(word, sentence):
68
  return assess_chinese(word, sentence)