Update app.py
app.py
CHANGED
@@ -9,7 +9,6 @@ import gzip
 import os
 import torch
 import pickle
-import yake
 
 ############
 ## Main page
@@ -36,7 +35,7 @@ user_query = st.text_input("Enter a query for the generated text: e.g., gift, ho
 # Add selectbox in streamlit
 option1 = st.sidebar.selectbox(
     'Which transformers model would you like to be selected?',
-    ('multi-qa-MiniLM-L6-cos-v1','
+    ('multi-qa-MiniLM-L6-cos-v1','null','null'))
 
 option2 = st.sidebar.selectbox(
     'Which cross-encoder model would you like to be selected?',
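The sidebar selections feed the retrieval models by name, with 'null' standing in as placeholder entries. A minimal sketch of how the chosen names could be wired up with sentence-transformers (assumed wiring, not shown in this diff; the 'null' guard and the example cross-encoder name are illustrative):

from sentence_transformers import SentenceTransformer, CrossEncoder

# Assumed, not part of this commit: map the sidebar picks to models,
# guarding the 'null' placeholder entries.
bi_encoder = SentenceTransformer(option1 if option1 != 'null' else 'multi-qa-MiniLM-L6-cos-v1')
cross_encoder = CrossEncoder(option2)  # e.g. 'cross-encoder/ms-marco-MiniLM-L-6-v2'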
@@ -65,20 +64,52 @@ with open(embedding_cache_path, "rb") as fIn:
     passages = cache_data['sentences']
     corpus_embeddings = cache_data['embeddings']
 
-
-
-
-
-
-
+from rank_bm25 import BM25Okapi
+from sklearn.feature_extraction import _stop_words
+import string
+from tqdm.autonotebook import tqdm
+import numpy as np
+
+
+# We lower-case our text and remove stop-words from indexing
+def bm25_tokenizer(text):
+    tokenized_doc = []
+    for token in text.lower().split():
+        token = token.strip(string.punctuation)
+
+        if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
+            tokenized_doc.append(token)
+    return tokenized_doc
 
 # This function will search all wikipedia articles for passages that
 # answer the query
 def search(query):
-
+    print("Input query:", query)
+    total_qe = []
+
+    ##### BM25 search (lexical search) #####
+    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
+    top_n = np.argpartition(bm25_scores, -5)[-5:]
+    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
+    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
+
+    #print("Top-10 lexical search (BM25) hits")
+    qe_string = []
+    for hit in bm25_hits[0:1000]:
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+
+    sub_string = []
+    for item in qe_string:
+        for sub_item in item.split(","):
+            sub_string.append(sub_item)
+    #print(sub_string)
+    total_qe.append(sub_string)
+
     ##### Semantic Search #####
     # Encode the query using the bi-encoder and find potentially relevant passages
     query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+    query_embedding = query_embedding.cuda()
     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
     hits = hits[0]  # Get the hits for the first query
 
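search() scores the query with bm25.get_scores(...), but the construction of the bm25 index is outside this hunk. With rank_bm25 it would typically be built once over the cached passages, along these lines (a sketch assuming passages is the list loaded above):

# Assumed setup, not shown in this diff: tokenize each passage once and
# build the BM25 index that search() queries.
tokenized_corpus = [bm25_tokenizer(passage) for passage in tqdm(passages)]
bm25 = BM25Okapi(tokenized_corpus)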
@@ -91,28 +122,33 @@ def search(query):
     for idx in range(len(cross_scores)):
         hits[idx]['cross-score'] = cross_scores[idx]
 
-    # Output of top-
-    #
-    #
-
-
-
-
-
-
-
+    # Output of top-10 hits from bi-encoder
+    #print("\n-------------------------\n")
+    #print("Top-N Bi-Encoder Retrieval hits")
+    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+    qe_string = []
+    for hit in hits[0:1000]:
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+    #print(qe_string)
+    total_qe.append(qe_string)
+
+    # Output of top-10 hits from re-ranker
+    #print("\n-------------------------\n")
+    #print("Top-N Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-
-    # st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
-    hit_res = []
+    qe_string = []
     for hit in hits[0:1000]:
-
-
-
-
-
-
-
+        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
+            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
+    #print(qe_string)
+    total_qe.append(qe_string)
+
+    # Total Results
+    total_qe.append(qe_string)
+    print("E-Commerce Query Expansion Results: \n")
+    print(total_qe)
+
 
 st.write("## Results:")
 if st.button('Generated Expansion'):
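The cross_scores consumed at the top of this hunk are computed outside the shown context. With sentence-transformers, the usual re-scoring step pairs the query with each bi-encoder hit (a sketch, assuming the cross_encoder selected in the sidebar):

# Assumed context, not shown in this diff: score (query, passage) pairs
# so the hits can be re-sorted by 'cross-score'.
cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)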
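One caveat in the added code: query_embedding.cuda() assumes a GPU and raises an error on CPU-only hardware. A device-agnostic alternative (a sketch, not part of the commit):

# Sketch: follow the device of the cached corpus embeddings instead of
# hard-coding CUDA, so the app also runs on CPU-only hosts.
if torch.is_tensor(corpus_embeddings):
    query_embedding = query_embedding.to(corpus_embeddings.device)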
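As committed, total_qe holds the BM25 comma-split fragments, the bi-encoder passages, and the cross-encoder passages, and the final total_qe.append(qe_string) under # Total Results appends the cross-encoder list a second time. If a flat, duplicate-free expansion list is preferred for display, one possible post-processing step (illustrative only):

# Sketch: flatten the collected hit lists and keep first occurrences.
seen = set()
expansions = []
for group in total_qe:
    for term in group:
        if term not in seen:
            seen.add(term)
            expansions.append(term)
print("E-Commerce Query Expansion Results:\n", expansions)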
|