phamson02 commited on
Commit
cad5609
·
1 Parent(s): 1976bba
Files changed (2) hide show
  1. app.py +24 -13
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,21 +1,27 @@
1
  import csv
 
2
 
3
  import gradio as gr
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer, util
6
  from underthesea import word_tokenize
 
 
7
 
8
  bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
9
- # cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 
 
10
  corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
11
 
12
  with open("data/child_passages.tsv", "r") as f:
13
  tsv_reader = csv.reader(f, delimiter="\t")
14
  child_passage_ids = [row[0] for row in tsv_reader]
 
15
 
16
  with open("data/parent_passages.tsv", "r") as f:
17
  tsv_reader = csv.reader(f, delimiter="\t")
18
- parent_passages = {row[0]: row[1] for row in tsv_reader}
19
 
20
 
21
  def f7(seq):
@@ -24,7 +30,7 @@ def f7(seq):
24
  return [x for x in seq if not (x in seen or seen_add(x))]
25
 
26
 
27
- def search(query: str, top_k: int = 100, reranking: bool = False):
28
  query = word_tokenize(query, format="text")
29
 
30
  print("Top 5 Answer by the NSE:")
@@ -36,19 +42,22 @@ def search(query: str, top_k: int = 100, reranking: bool = False):
36
  hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
37
  hits = hits[0] # Get the hits for the first query
38
 
 
 
 
39
  ##### Re-Ranking #####
40
  # Now, score all retrieved passages with the cross_encoder
41
  if reranking:
42
- cross_inp = [[query, corpus[hit["corpus_id"]]] for hit in hits]
43
- cross_scores = cross_encoder.predict(cross_inp)
 
44
 
45
- # Sort results by the cross-encoder scores
46
- for idx in range(len(cross_scores)):
47
- hits[idx]["cross-score"] = cross_scores[idx]
48
-
49
- hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
50
 
51
- top_20_hits = hits[0:20]
52
  hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
53
  hit_parent_passage_ids = f7(
54
  [
@@ -60,7 +69,7 @@ def search(query: str, top_k: int = 100, reranking: bool = False):
60
  assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
61
 
62
  for hit in hit_parent_passage_ids[:5]:
63
- ans.append(parent_passages[hit])
64
 
65
  return ans[0], ans[1], ans[2], ans[3], ans[4]
66
 
@@ -76,6 +85,8 @@ exp = [
76
  desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
77
 
78
  inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
 
 
79
  out1 = gr.Textbox(type="text", label="Search result 1")
80
  out2 = gr.Textbox(type="text", label="Search result 2")
81
  out3 = gr.Textbox(type="text", label="Search result 3")
@@ -84,7 +95,7 @@ out5 = gr.Textbox(type="text", label="Search result 5")
84
 
85
  iface = gr.Interface(
86
  fn=search,
87
- inputs=inp,
88
  outputs=[out1, out2, out3, out4, out5],
89
  examples=exp,
90
  article=desc,
 
1
  import csv
2
+ from typing import Any
3
 
4
  import gradio as gr
5
  import pandas as pd
6
  from sentence_transformers import SentenceTransformer, util
7
  from underthesea import word_tokenize
8
+ from retriever_trainer import PretrainedColBERT
9
+
10
 
11
  bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
12
+ colbert = PretrainedColBERT(
13
+ pretrained_model_name="phamson02/colbert2.1_290000",
14
+ )
15
  corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")
16
 
17
  with open("data/child_passages.tsv", "r") as f:
18
  tsv_reader = csv.reader(f, delimiter="\t")
19
  child_passage_ids = [row[0] for row in tsv_reader]
20
+ child_passages = [row[1] for row in tsv_reader]
21
 
22
  with open("data/parent_passages.tsv", "r") as f:
23
  tsv_reader = csv.reader(f, delimiter="\t")
24
+ parent_passages_map = {row[0]: row[1] for row in tsv_reader}
25
 
26
 
27
  def f7(seq):
 
30
  return [x for x in seq if not (x in seen or seen_add(x))]
31
 
32
 
33
+ def search(query: str, reranking: bool = False, top_k: int = 100):
34
  query = word_tokenize(query, format="text")
35
 
36
  print("Top 5 Answer by the NSE:")
 
42
  hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
43
  hits = hits[0] # Get the hits for the first query
44
 
45
+ top_k_child_passages = [child_passages[hit["corpus_id"]] for hit in hits]
46
+ top_k_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in hits]
47
+
48
  ##### Re-Ranking #####
49
  # Now, score all retrieved passages with the cross_encoder
50
  if reranking:
51
+ colbert_scores: list[dict[str, Any]] = colbert.rerank(
52
+ query=query, documents=top_k_child_passages, top_k=100
53
+ )
54
 
55
+ # Reorder child passage ids based on the reranking
56
+ top_k_child_passage_ids = [
57
+ top_k_child_passage_ids[score["corpus_id"]] for score in colbert_scores
58
+ ]
 
59
 
60
+ top_20_hits = top_k_child_passage_ids[0:20]
61
  hit_child_passage_ids = [child_passage_ids[hit["corpus_id"]] for hit in top_20_hits]
62
  hit_parent_passage_ids = f7(
63
  [
 
69
  assert len(hit_parent_passage_ids) >= 5, "Not enough unique parent passages found"
70
 
71
  for hit in hit_parent_passage_ids[:5]:
72
+ ans.append(parent_passages_map[hit])
73
 
74
  return ans[0], ans[1], ans[2], ans[3], ans[4]
75
 
 
85
  desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."
86
 
87
  inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
88
+ reranking_checkbox = gr.Checkbox(label="Enable reranking")
89
+
90
  out1 = gr.Textbox(type="text", label="Search result 1")
91
  out2 = gr.Textbox(type="text", label="Search result 2")
92
  out3 = gr.Textbox(type="text", label="Search result 3")
 
95
 
96
  iface = gr.Interface(
97
  fn=search,
98
+ inputs=[inp, reranking_checkbox],
99
  outputs=[out1, out2, out3, out4, out5],
100
  examples=exp,
101
  article=desc,
requirements.txt CHANGED
@@ -2,4 +2,5 @@ sentence-transformers
2
  torch
3
  pandas
4
  gradio
5
- underthesea
 
 
2
  torch
3
  pandas
4
  gradio
5
+ underthesea
6
+ retriever-trainer[colbert] @ git+https://[email protected]/phamson02/retriever_trainer.git@rerank