Spaces:
Sleeping
Sleeping
phamson02
commited on
Commit
·
1ee7f3c
1
Parent(s):
29631f9
init
Browse files- .gitattributes +2 -0
- app.py +90 -0
- data/child_passages.tsv +3 -0
- data/parent_passages.tsv +3 -0
- data/passage_embeds.pkl +3 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/child_passages.tsv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
data/parent_passages.tsv filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import pandas as pd
|
5 |
+
from sentence_transformers import SentenceTransformer, util
|
6 |
+
|
7 |
+
bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
|
8 |
+
# cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
9 |
+
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
|
10 |
+
|
11 |
+
with open("data/child_passages.tsv", "r") as f:
|
12 |
+
tsv_reader = csv.reader(f, delimiter="\t")
|
13 |
+
child_passage_ids = [row[0] for row in tsv_reader]
|
14 |
+
|
15 |
+
with open("data/parent_passages.tsv", "r") as f:
|
16 |
+
tsv_reader = csv.reader(f, delimiter="\t")
|
17 |
+
parent_passages = {row[0]: row[1] for row in tsv_reader}
|
18 |
+
|
19 |
+
|
20 |
+
def f7(seq):
|
21 |
+
seen = set()
|
22 |
+
seen_add = seen.add
|
23 |
+
return [x for x in seq if not (x in seen or seen_add(x))]
|
24 |
+
|
25 |
+
|
26 |
+
def search(query: str, top_k: int = 100, reranking: bool = False):
|
27 |
+
print("Top 5 Answer by the NSE:")
|
28 |
+
print()
|
29 |
+
ans: list[str] = []
|
30 |
+
##### Sematic Search #####
|
31 |
+
# Encode the query using the bi-encoder and find potentially relevant passages
|
32 |
+
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
33 |
+
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
34 |
+
hits = hits[0] # Get the hits for the first query
|
35 |
+
|
36 |
+
##### Re-Ranking #####
|
37 |
+
# Now, score all retrieved passages with the cross_encoder
|
38 |
+
if reranking:
|
39 |
+
cross_inp = [[query, corpus[hit["corpus_id"]]] for hit in hits]
|
40 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
41 |
+
|
42 |
+
# Sort results by the cross-encoder scores
|
43 |
+
for idx in range(len(cross_scores)):
|
44 |
+
hits[idx]["cross-score"] = cross_scores[idx]
|
45 |
+
|
46 |
+
hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
|
47 |
+
|
48 |
+
top_20_hits = hits[0:20]
|
49 |
+
hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
|
50 |
+
hit_parent_passage_ids = f7(
|
51 |
+
[
|
52 |
+
"_".join(hit_child_passage_id.split("_")[:-1])
|
53 |
+
for hit_child_passage_id in hit_child_passage_ids
|
54 |
+
]
|
55 |
+
)
|
56 |
+
|
57 |
+
assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
|
58 |
+
|
59 |
+
for hit in hit_parent_passage_ids[:5]:
|
60 |
+
ans.append(parent_passages[hit])
|
61 |
+
|
62 |
+
return ans[0], ans[1], ans[2], ans[3], ans[4]
|
63 |
+
|
64 |
+
|
65 |
+
exp = [
|
66 |
+
"Who is steve jobs?",
|
67 |
+
"What is coldplay?",
|
68 |
+
"What is a turing test?",
|
69 |
+
"What is the most interesting thing about our universe?",
|
70 |
+
"What are the most beautiful places on earth?",
|
71 |
+
]
|
72 |
+
|
73 |
+
desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
|
74 |
+
|
75 |
+
inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
|
76 |
+
out1 = gr.Textbox(type="text", label="Search result 1")
|
77 |
+
out2 = gr.Textbox(type="text", label="Search result 2")
|
78 |
+
out3 = gr.Textbox(type="text", label="Search result 3")
|
79 |
+
out4 = gr.Textbox(type="text", label="Search result 4")
|
80 |
+
out5 = gr.Textbox(type="text", label="Search result 5")
|
81 |
+
|
82 |
+
iface = gr.Interface(
|
83 |
+
fn=search,
|
84 |
+
inputs=inp,
|
85 |
+
outputs=[out1, out2, out3, out4, out5],
|
86 |
+
examples=exp,
|
87 |
+
article=desc,
|
88 |
+
title="Neural Search Engine",
|
89 |
+
)
|
90 |
+
iface.launch()
|
data/child_passages.tsv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ab5beca4e38074457dd397a7b56f30679db2a92716bd31fa11a4c303db3dea3
|
3 |
+
size 428522185
|
data/parent_passages.tsv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b8e35cd1a742779cbfb2ff5b7809ff92663a90ee92bd72c2d8b1d66f375e6535
|
3 |
+
size 352871807
|
data/passage_embeds.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3f4f423e0c7ef2021f17b8ca0dd78b1a32c5c75d8eaca06238ac12e8661ea0e
|
3 |
+
size 1060860322
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sentence-transformers
|
2 |
+
torch
|
3 |
+
pandas
|
4 |
+
gradio
|