vi_law_search / app.py
phamson02
update
7ce8383
raw
history blame
3.65 kB
import csv
from typing import Any
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from underthesea import word_tokenize
from retriever_trainer import PretrainedColBERT
# First-stage dense retriever: encodes the (word-segmented) query so that
# ``util.semantic_search`` can compare it with the precomputed passage
# embeddings loaded from data/passage_embeds.pkl.
bi_encoder = SentenceTransformer(
    "phamson02/cotmae_biencoder2_170000_sbert",
)

# Second-stage reranker; ``search`` only calls it when reranking is enabled.
colbert = PretrainedColBERT(pretrained_model_name="phamson02/colbert2.1_290000")
# Precomputed embeddings of the child passages. ``search`` uses the
# ``corpus_id`` of each semantic_search hit to index ``child_passages``, so
# row i here is assumed to correspond to line i of child_passages.tsv —
# TODO confirm when regenerating either file.
# NOTE(review): read_pickle unpickles arbitrary objects; only load trusted files.
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")

# Both TSVs are read with newline="" (required by the csv module so quoted
# fields containing newlines are parsed correctly) and an explicit UTF-8
# encoding, so decoding does not depend on the platform's locale. The passages
# are presumably Vietnamese (queries go through underthesea); this assumes the
# files are UTF-8 encoded — confirm.
# Columns: child passage id, child passage text.
with open("data/child_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    child_rows = [(row[0], row[1]) for row in tsv_reader]
if not child_rows:
    # zip(*[]) would otherwise fail with an opaque "not enough values to unpack".
    raise ValueError("data/child_passages.tsv contains no passages")
child_passage_ids, child_passages = zip(*child_rows)

# Columns: parent passage id, parent passage text.
with open("data/parent_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    parent_passages_map = {row[0]: row[1] for row in tsv_reader}
def f7(seq):
    """Return a list of the items in ``seq`` with duplicates dropped.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable.
    """
    # dict keys keep insertion order and fromkeys ignores repeated keys,
    # so this is an order-preserving de-duplication.
    return list(dict.fromkeys(seq))
def search(query: str, reranking: bool = False, top_k: int = 100) -> tuple[str, ...]:
    """Return the texts of the five best-matching parent passages for ``query``.

    Pipeline:
      1. Word-segment the query with underthesea's ``word_tokenize``
         (``format="text"``), presumably matching how the passages were
         segmented before embedding — TODO confirm.
      2. Dense retrieval: encode the query with the bi-encoder and take the
         ``top_k`` nearest child passages from ``corpus_embeddings``. Only the
         first 20 of these are used downstream.
      3. If ``reranking`` is true, reorder those candidates with ColBERT.
      4. Map each child passage id to its parent id, de-duplicate in rank
         order, and look up the parent passage text.

    Args:
        query: Raw user query.
        reranking: Whether to reorder the dense-retrieval candidates with ColBERT.
        top_k: Number of hits requested from ``util.semantic_search``.

    Returns:
        A 5-tuple of parent passage texts, one per Gradio output box. If fewer
        than five distinct parent passages are found, the remaining slots are "".

    Raises:
        KeyError: If a derived parent id is missing from ``parent_passages_map``.
    """
    pool_size = 20  # candidates kept from retrieval / reranking
    num_results = 5  # must match the number of Gradio output boxes

    query = word_tokenize(query, format="text")
    print("Top 5 Answer by the NSE:")
    print()

    ##### Semantic Search #####
    # Encode the query and find the nearest child passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; a single query was encoded.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    candidate_ids = [hit["corpus_id"] for hit in hits[:pool_size]]
    candidate_passages = [child_passages[cid] for cid in candidate_ids]

    ##### Re-Ranking #####
    # Optionally rescore the candidates with ColBERT and adopt its ordering.
    if reranking:
        colbert_scores: list[dict[str, Any]] = colbert.rerank(
            query=query, documents=candidate_passages, top_k=pool_size
        )
        # Each score's "corpus_id" is a position in ``candidate_passages``;
        # translate it back to the corpus row index.
        candidate_ids = [candidate_ids[score["corpus_id"]] for score in colbert_scores]

    ranked_child_ids = [child_passage_ids[cid] for cid in candidate_ids[:pool_size]]
    # The parent id is the child id with its last "_"-separated segment removed
    # (an id without "_" maps to ""). Several children can share one parent, so
    # de-duplicate while keeping the best-ranked occurrence first.
    parent_ids = f7([child_id.rpartition("_")[0] for child_id in ranked_child_ids])

    answers = [parent_passages_map[pid] for pid in parent_ids[:num_results]]
    # Gradio needs exactly one value per output box. Pad rather than raise so
    # that partial results are still shown when fewer parents were found.
    answers.extend([""] * (num_results - len(answers)))
    return tuple(answers)
# Example rows shown beneath the form; each row is [query, reranking enabled]
# and fills the inputs in order.
# NOTE(review): these general-knowledge English questions, like the
# "Wikipedia" wording in ``desc``, look carried over from an English demo,
# while queries are segmented with underthesea (a Vietnamese tokenizer)
# before retrieval — confirm they are meaningful for this corpus.
exp = [
    ["Who is steve jobs?", False],
    ["What is coldplay?", False],
    ["What is a turing test?", False],
    ["What is the most interesting thing about our universe?", False],
    ["What are the most beautiful places on earth?", False],
]
desc = (
    "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) "
    "with a retrieval and reranking system on Wikipedia corpus. "
    "This will return the top 5 results. So Quest on with Transformers."
)

inp = gr.Textbox(lines=1, placeholder=None, label="search your query here")
reranking_checkbox = gr.Checkbox(label="Enable reranking")
# One output box per element of the 5-tuple that ``search`` returns.
out1, out2, out3, out4, out5 = (
    gr.Textbox(type="text", label=f"Search result {i}") for i in range(1, 6)
)

iface = gr.Interface(
    fn=search,
    inputs=[inp, reranking_checkbox],
    outputs=[out1, out2, out3, out4, out5],
    examples=exp,
    article=desc,
    title="Neural Search Engine",
)
iface.launch()