import csv

import gradio as gr
import pandas as pd
from sentence_transformers import CrossEncoder, SentenceTransformer, util

bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
# The cross-encoder is only needed when search() is called with reranking=True.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
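# Assumption: the pickle holds the pre-computed bi-encoder embeddings as a
# tensor/array of shape (num_child_passages, embed_dim), with rows in the same
# order as data/child_passages.tsv below, so that "corpus_id" indexes align.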
with open("data/child_passages.tsv", "r") as f: | |
tsv_reader = csv.reader(f, delimiter="\t") | |
child_passage_ids = [row[0] for row in tsv_reader] | |
with open("data/parent_passages.tsv", "r") as f: | |
tsv_reader = csv.reader(f, delimiter="\t") | |
parent_passages = {row[0]: row[1] for row in tsv_reader} | |
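# Assumed data layout: child passage ids look like "<parent_id>_<chunk_index>",
# so stripping the final "_" segment (done in search() below) recovers the
# parent id, and parent_passages maps that id to the full parent text.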
def f7(seq):
    """Order-preserving de-duplication (the classic 'f7' recipe)."""
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (x in seen or seen_add(x))]
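# Example: f7([3, 1, 3, 2, 1]) -> [3, 1, 2]; duplicates drop, first-seen order stays.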
def search(query: str, top_k: int = 100, reranking: bool = False):
    print("Top 5 answers by the NSE:")
    print()
    ans: list[str] = []

    ##### Semantic Search #####
    # Encode the query with the bi-encoder and retrieve candidate passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query
    ##### Re-Ranking #####
    # Score every retrieved passage against the query with the cross-encoder
    if reranking:
        cross_inp = [[query, child_passages[hit["corpus_id"]]] for hit in hits]
        cross_scores = cross_encoder.predict(cross_inp)

        # Sort results by the cross-encoder scores
        for idx in range(len(cross_scores)):
            hits[idx]["cross-score"] = cross_scores[idx]
        hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
    top_20_hits = hits[0:20]
    hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
    # Strip the trailing "_<chunk_index>" to map each child hit back to its
    # parent passage, de-duplicating while preserving rank order.
    hit_parent_passage_ids = f7(
        [
            "_".join(hit_child_passage_id.split("_")[:-1])
            for hit_child_passage_id in hit_child_passage_ids
        ]
    )

    assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"

    for hit in hit_parent_passage_ids[:5]:
        ans.append(parent_passages[hit])

    return ans[0], ans[1], ans[2], ans[3], ans[4]
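# A minimal smoke test outside the Gradio UI (hypothetical usage; assumes the
# data files above are present):
#
#   r1, r2, r3, r4, r5 = search("What is a Turing test?", reranking=True)
#   print(r1)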
exp = [
    "Who is Steve Jobs?",
    "What is Coldplay?",
    "What is a Turing test?",
    "What is the most interesting thing about our universe?",
    "What are the most beautiful places on earth?",
]

desc = (
    "This is a semantic search engine powered by SentenceTransformers "
    "(Nils Reimers), with a retrieval-and-reranking pipeline over a Wikipedia "
    "corpus. It returns the top 5 results. So quest on with Transformers."
)
inp = gr.Textbox(lines=1, placeholder=None, label="Search your query here")
out1 = gr.Textbox(type="text", label="Search result 1")
out2 = gr.Textbox(type="text", label="Search result 2")
out3 = gr.Textbox(type="text", label="Search result 3")
out4 = gr.Textbox(type="text", label="Search result 4")
out5 = gr.Textbox(type="text", label="Search result 5")
iface = gr.Interface(
    fn=search,
    inputs=inp,
    outputs=[out1, out2, out3, out4, out5],
    examples=exp,
    article=desc,
    title="Neural Search Engine",
)

iface.launch()