HarryLee committed
Commit c306d7d · 1 Parent(s): 0d2a873

Update app.py

Files changed (1): app.py (+65, -29)
app.py CHANGED
@@ -9,7 +9,6 @@ import gzip
 import os
 import torch
 import pickle
-import yake
 
 ############
 ## Main page
@@ -36,7 +35,7 @@ user_query = st.text_input("Enter a query for the generated text: e.g., gift, ho
 # Add selectbox in streamlit
 option1 = st.sidebar.selectbox(
     'Which transformers model would you like to be selected?',
-    ('multi-qa-MiniLM-L6-cos-v1','louis030195/multi-qa-MiniLM-L6-cos-v1-de-ecommerce','null'))
+    ('multi-qa-MiniLM-L6-cos-v1','null','null'))
 
 option2 = st.sidebar.selectbox(
     'Which cross-encoder model would you like to be selected?',
@@ -65,20 +64,52 @@ with open(embedding_cache_path, "rb") as fIn:
     passages = cache_data['sentences']
     corpus_embeddings = cache_data['embeddings']
 
-kw_extractor = yake.KeywordExtractor()
-language = "en"
-max_ngram_size = 3
-deduplication_threshold = 0.9
-numOfKeywords = 20
-custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, top=numOfKeywords, features=None)
+from rank_bm25 import BM25Okapi
+from sklearn.feature_extraction import _stop_words
+import string
+from tqdm.autonotebook import tqdm
+import numpy as np
+
+
+# We lower case our text and remove stop-words from indexing
+def bm25_tokenizer(text):
+    tokenized_doc = []
+    for token in text.lower().split():
+        token = token.strip(string.punctuation)
+        if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
+            tokenized_doc.append(token)
+    return tokenized_doc
 
 # This function will search all wikipedia articles for passages that
 # answer the query
 def search(query):
-    st.write("Input question:", query)
+    print("Input query:", query)
+    total_qe = []
+
+    ##### BM25 search (lexical search) #####
+    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
+    top_n = np.argpartition(bm25_scores, -5)[-5:]
+    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
+    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
+
+    #print("Top-10 lexical search (BM25) hits")
+    qe_string = []
+    for hit in bm25_hits[0:1000]:
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+
+    sub_string = []
+    for item in qe_string:
+        for sub_item in item.split(","):
+            sub_string.append(sub_item)
+    #print(sub_string)
+    total_qe.append(sub_string)
+
     ##### Semantic Search #####
     # Encode the query using the bi-encoder and find potentially relevant passages
     query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+    query_embedding = query_embedding.cuda()
     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
     hits = hits[0] # Get the hits for the first query
 
@@ -91,28 +122,33 @@ def search(query):
     for idx in range(len(cross_scores)):
         hits[idx]['cross-score'] = cross_scores[idx]
 
-    # Output of top-N hits from bi-encoder
-    #st.write("\n-------------------------\n")
-    #st.subheader("Top-N Bi-Encoder Retrieval hits")
-    #hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-    #for hit in hits[0:maxtags_sidebar]:
-    #    st.write("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
-
-    # Output of top-N hits from re-ranker
-    st.write("\n-------------------------\n")
-    st.subheader("Top-N Cross-Encoder Re-ranker hits")
+    # Output of top-10 hits from bi-encoder
+    #print("\n-------------------------\n")
+    #print("Top-N Bi-Encoder Retrieval hits")
+    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+    qe_string = []
+    for hit in hits[0:1000]:
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+    #print(qe_string)
+    total_qe.append(qe_string)
+
+    # Output of top-10 hits from re-ranker
+    #print("\n-------------------------\n")
+    #print("Top-N Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    #for hit in hits[0:maxtags_sidebar]:
-    #    st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
-    hit_res = []
+    qe_string = []
     for hit in hits[0:1000]:
-        q = passages[hit['corpus_id']].replace("\n", " ")
-        if q not in hit_res:
-            hit_res.append(q)
-    for res in hit_res[0:maxtags_sidebar]:
-        keywords = custom_kw_extractor.extract_keywords(res)
-        for kw in keywords:
-            st.write(kw)
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+    #print(qe_string)
+    total_qe.append(qe_string)
+
+    # Total Results
+    total_qe.append(qe_string)
+    print("E-Commerce Query Expansion Results: \n")
+    print(total_qe)
+
 
 st.write("## Results:")
 if st.button('Generated Expansion'):
 
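The new search() body calls bm25.get_scores(...), but the bm25 index itself is never constructed anywhere in this diff; it presumably happens elsewhere in app.py. A minimal sketch of the setup the code appears to assume, building the index once at startup from the cached passages with rank_bm25's BM25Okapi (the variable tokenized_corpus is illustrative, not from the commit):

    # Assumed setup (not shown in this commit): tokenize every cached passage
    # and build the BM25 index that search() queries via bm25.get_scores().
    tokenized_corpus = [bm25_tokenizer(passage) for passage in passages]
    bm25 = BM25Okapi(tokenized_corpus)

With the index in place, np.argpartition(bm25_scores, -5)[-5:] picks the indices of the five best-scoring passages without sorting the whole score array; only those five hits are then sorted by score.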
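The added query_embedding.cuda() call moves the query embedding to the GPU unconditionally, which fails on CPU-only hosts. A guarded variant (an assumption on my part, not part of this commit) keeps the app runnable either way:

    # Only move the query embedding to the GPU when CUDA is available;
    # corpus_embeddings must live on the same device for semantic_search.
    if torch.cuda.is_available():
        query_embedding = query_embedding.cuda()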
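The same deduplication loop now appears three times (BM25 hits, bi-encoder hits, cross-encoder hits), and the final # Total Results block appends the cross-encoder's qe_string a second time, so that list lands in total_qe twice. Note also that each loop tests membership with replace("\n", " ") but stores with replace("\n", ""), so a multi-line passage never matches its stored form and can be appended repeatedly. A possible consolidation into one helper (the name dedupe_passages is illustrative, not from the commit):

    def dedupe_passages(ranked_hits, passages, limit=1000):
        # Collect each hit's passage text once, preserving rank order.
        # Uses one normalization for both the membership test and the
        # stored value, unlike the inlined loops above.
        seen = []
        for hit in ranked_hits[0:limit]:
            text = passages[hit['corpus_id']].replace("\n", " ")
            if text not in seen:
                seen.append(text)
        return seen

Each of the three blocks would then reduce to a single call, e.g. total_qe.append(dedupe_passages(hits, passages)).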