HarryLee committed on
Commit
b726bae
·
1 Parent(s): aaf7855

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -9,6 +9,7 @@ import gzip
9
  import os
10
  import torch
11
  import pickle
 
12
 
13
  ############
14
  ## Main page
@@ -64,6 +65,13 @@ with open(embedding_cache_path, "rb") as fIn:
64
  passages = cache_data['sentences']
65
  corpus_embeddings = cache_data['embeddings']
66
 
 
 
 
 
 
 
 
67
  # This function will search all wikipedia articles for passages that
68
  # answer the query
69
  def search(query):
@@ -94,8 +102,17 @@ def search(query):
94
  st.write("\n-------------------------\n")
95
  st.subheader("Top-N Cross-Encoder Re-ranker hits")
96
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
97
- for hit in hits[0:maxtags_sidebar]:
98
- st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
 
 
 
 
 
 
 
 
 
99
 
100
  st.write("## Results:")
101
  if st.button('Generated Expansion'):
 
9
  import os
10
  import torch
11
  import pickle
12
+ import yake
13
 
14
  ############
15
  ## Main page
 
65
  passages = cache_data['sentences']
66
  corpus_embeddings = cache_data['embeddings']
67
 
68
+ kw_extractor = yake.KeywordExtractor()
69
+ language = "en"
70
+ max_ngram_size = 3
71
+ deduplication_threshold = 0.9
72
+ numOfKeywords = 20
73
+ custom_kw_extractor=yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, top=numOfKeywords, features=None)
74
+
75
  # This function will search all wikipedia articles for passages that
76
  # answer the query
77
  def search(query):
 
102
  st.write("\n-------------------------\n")
103
  st.subheader("Top-N Cross-Encoder Re-ranker hits")
104
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
105
+ #for hit in hits[0:maxtags_sidebar]:
106
+ # st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
107
+ hit_res = []
108
+ for hit in hits[0:1000]:
109
+ q = passages[hit['corpus_id']].replace("\n", " ")
110
+ if q not in hit_res:
111
+ hit_res.append(q)
112
+ for res in hit_res[0:maxtags_sidebar]:
113
+ keywords = custom_kw_extractor.extract_keywords(res)
114
+ for kw in keywords:
115
+ print(kw)
116
 
117
  st.write("## Results:")
118
  if st.button('Generated Expansion'):