Spaces:
Sleeping
Sleeping
phamson02
committed on
Commit
·
cad5609
1
Parent(s):
1976bba
update
Browse files- app.py +24 -13
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,21 +1,27 @@
|
|
1 |
import csv
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
5 |
from sentence_transformers import SentenceTransformer, util
|
6 |
from underthesea import word_tokenize
|
|
|
|
|
7 |
|
8 |
bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
|
9 |
-
|
|
|
|
|
10 |
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
|
11 |
|
12 |
with open("data/child_passages.tsv", "r") as f:
|
13 |
tsv_reader = csv.reader(f, delimiter="\t")
|
14 |
child_passage_ids = [row[0] for row in tsv_reader]
|
|
|
15 |
|
16 |
with open("data/parent_passages.tsv", "r") as f:
|
17 |
tsv_reader = csv.reader(f, delimiter="\t")
|
18 |
-
|
19 |
|
20 |
|
21 |
def f7(seq):
|
@@ -24,7 +30,7 @@ def f7(seq):
|
|
24 |
return [x for x in seq if not (x in seen or seen_add(x))]
|
25 |
|
26 |
|
27 |
-
def search(query: str,
|
28 |
query = word_tokenize(query, format="text")
|
29 |
|
30 |
print("Top 5 Answer by the NSE:")
|
@@ -36,19 +42,22 @@ def search(query: str, top_k: int = 100, reranking: bool = False):
|
|
36 |
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
37 |
hits = hits[0] # Get the hits for the first query
|
38 |
|
|
|
|
|
|
|
39 |
##### Re-Ranking #####
|
40 |
# Now, score all retrieved passages with the cross_encoder
|
41 |
if reranking:
|
42 |
-
|
43 |
-
|
|
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
|
50 |
|
51 |
-
top_20_hits =
|
52 |
hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
|
53 |
hit_parent_passage_ids = f7(
|
54 |
[
|
@@ -60,7 +69,7 @@ def search(query: str, top_k: int = 100, reranking: bool = False):
|
|
60 |
assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
|
61 |
|
62 |
for hit in hit_parent_passage_ids[:5]:
|
63 |
-
ans.append(
|
64 |
|
65 |
return ans[0], ans[1], ans[2], ans[3], ans[4]
|
66 |
|
@@ -76,6 +85,8 @@ exp = [
|
|
76 |
desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
|
77 |
|
78 |
inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
|
|
|
|
|
79 |
out1 = gr.Textbox(type="text", label="Search result 1")
|
80 |
out2 = gr.Textbox(type="text", label="Search result 2")
|
81 |
out3 = gr.Textbox(type="text", label="Search result 3")
|
@@ -84,7 +95,7 @@ out5 = gr.Textbox(type="text", label="Search result 5")
|
|
84 |
|
85 |
iface = gr.Interface(
|
86 |
fn=search,
|
87 |
-
inputs=inp,
|
88 |
outputs=[out1, out2, out3, out4, out5],
|
89 |
examples=exp,
|
90 |
article=desc,
|
|
|
1 |
import csv
|
2 |
+
from typing import Any
|
3 |
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
6 |
from sentence_transformers import SentenceTransformer, util
|
7 |
from underthesea import word_tokenize
|
8 |
+
from retriever_trainer import PretrainedColBERT
|
9 |
+
|
10 |
|
11 |
bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
|
12 |
+
colbert = PretrainedColBERT(
|
13 |
+
pretrained_model_name="phamson02/colbert2.1_290000",
|
14 |
+
)
|
15 |
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
|
16 |
|
17 |
with open("data/child_passages.tsv", "r") as f:
|
18 |
tsv_reader = csv.reader(f, delimiter="\t")
|
19 |
child_passage_ids = [row[0] for row in tsv_reader]
|
20 |
+
child_passages = [row[1] for row in tsv_reader]
|
21 |
|
22 |
with open("data/parent_passages.tsv", "r") as f:
|
23 |
tsv_reader = csv.reader(f, delimiter="\t")
|
24 |
+
parent_passages_map = {row[0]: row[1] for row in tsv_reader}
|
25 |
|
26 |
|
27 |
def f7(seq):
|
|
|
30 |
return [x for x in seq if not (x in seen or seen_add(x))]
|
31 |
|
32 |
|
33 |
+
def search(query: str, reranking: bool = False, top_k: int = 100):
|
34 |
query = word_tokenize(query, format="text")
|
35 |
|
36 |
print("Top 5 Answer by the NSE:")
|
|
|
42 |
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
43 |
hits = hits[0] # Get the hits for the first query
|
44 |
|
45 |
+
top_k_child_passages = [child_passages[hit["corpus_id"]] for hit in hits]
|
46 |
+
top_k_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in hits]
|
47 |
+
|
48 |
##### Re-Ranking #####
|
49 |
# Now, score all retrieved passages with the cross_encoder
|
50 |
if reranking:
|
51 |
+
colbert_scores: list[dict[str, Any]] = colbert.rerank(
|
52 |
+
query=query, documents=top_k_child_passages, top_k=100
|
53 |
+
)
|
54 |
|
55 |
+
# Reorder child passage ids based on the reranking
|
56 |
+
top_k_child_passage_ids = [
|
57 |
+
top_k_child_passage_ids[score["corpus_id"]] for score in colbert_scores
|
58 |
+
]
|
|
|
59 |
|
60 |
+
top_20_hits = top_k_child_passage_ids[0:20]
|
61 |
hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
|
62 |
hit_parent_passage_ids = f7(
|
63 |
[
|
|
|
69 |
assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
|
70 |
|
71 |
for hit in hit_parent_passage_ids[:5]:
|
72 |
+
ans.append(parent_passages_map[hit])
|
73 |
|
74 |
return ans[0], ans[1], ans[2], ans[3], ans[4]
|
75 |
|
|
|
85 |
desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
|
86 |
|
87 |
inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
|
88 |
+
reranking_checkbox = gr.Checkbox(label="Enable reranking")
|
89 |
+
|
90 |
out1 = gr.Textbox(type="text", label="Search result 1")
|
91 |
out2 = gr.Textbox(type="text", label="Search result 2")
|
92 |
out3 = gr.Textbox(type="text", label="Search result 3")
|
|
|
95 |
|
96 |
iface = gr.Interface(
|
97 |
fn=search,
|
98 |
+
inputs=[inp, reranking_checkbox],
|
99 |
outputs=[out1, out2, out3, out4, out5],
|
100 |
examples=exp,
|
101 |
article=desc,
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ sentence-transformers
|
|
2 |
torch
|
3 |
pandas
|
4 |
gradio
|
5 |
-
underthesea
|
|
|
|
2 |
torch
|
3 |
pandas
|
4 |
gradio
|
5 |
+
underthesea
|
6 |
+
retriever-trainer[colbert] @ git+https://[email protected]/phamson02/retriever_trainer.git@rerank
|