riccorl commited on
Commit
8197b11
·
1 Parent(s): 087c2a2

Upload models

Browse files
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/documents.json filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/documents.json filter=lfs diff=lfs merge=lfs -text
37
+ models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered/documents.json filter=lfs diff=lfs merge=lfs -text
38
+ frequency_blink.txt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Relik
3
- emoji: 📚
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: streamlit
 
1
  ---
2
  title: Relik
3
+ emoji: 🤖
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: streamlit
app.py CHANGED
@@ -181,7 +181,7 @@ def run_client():
181
 
182
  relik = Relik(
183
  question_encoder="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
184
- document_index="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
185
  reader="/home/user/app/models/relik-reader-aida-deberta-small",
186
  top_k=100,
187
  window_size=32,
 
181
 
182
  relik = Relik(
183
  question_encoder="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
184
+ document_index="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
185
  reader="/home/user/app/models/relik-reader-aida-deberta-small",
186
  top_k=100,
187
  window_size=32,
examples/explore_faiss.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # table to store results
2
+
3
+ | Index | nprobe | Recall | Time |
4
+ |----------------|--------|--------|-------|
5
+ | Flat | 1 | 98.7 | 38.64 |
6
+ | IVFx,Flat | 1 | 42.5 | 23.46 |
7
+ | IVFx,Flat | 14 | 88.5 | 133 |
8
+ | IVFx_HNSW,Flat | 1 | 88.5 | 133 |
examples/explore_faiss.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import time
7
+ from typing import Union
8
+
9
+ import torch
10
+ import tqdm
11
+
12
+ from relik.retriever import GoldenRetriever
13
+ from relik.common.log import get_logger
14
+ from relik.retriever.common.model_inputs import ModelInputs
15
+ from relik.retriever.data.base.datasets import BaseDataset
16
+ from relik.retriever.indexers.base import BaseDocumentIndex
17
+ from relik.retriever.indexers.faiss import FaissDocumentIndex
18
+
19
+ logger = get_logger(level=logging.INFO)
20
+
21
+
22
+ def compute_retriever_stats(dataset) -> None:
23
+ correct, total = 0, 0
24
+ for sample in dataset:
25
+ window_candidates = sample["window_candidates"]
26
+ window_candidates = [c.replace("_", " ").lower() for c in window_candidates]
27
+
28
+ for ss, se, label in sample["window_labels"]:
29
+ if label == "--NME--":
30
+ continue
31
+ if label.replace("_", " ").lower() in window_candidates:
32
+ correct += 1
33
+ total += 1
34
+
35
+ recall = correct / total
36
+ print("Recall:", recall)
37
+
38
+
39
+ @torch.no_grad()
40
+ def add_candidates(
41
+ retriever_name_or_path: Union[str, os.PathLike],
42
+ document_index_name_or_path: Union[str, os.PathLike],
43
+ input_path: Union[str, os.PathLike],
44
+ batch_size: int = 128,
45
+ num_workers: int = 4,
46
+ index_type: str = "Flat",
47
+ nprobe: int = 1,
48
+ device: str = "cpu",
49
+ precision: str = "fp32",
50
+ topics: bool = False,
51
+ ):
52
+ document_index = BaseDocumentIndex.from_pretrained(
53
+ document_index_name_or_path,
54
+ # config_kwargs={
55
+ # "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
56
+ # "index_type": index_type,
57
+ # "nprobe": nprobe,
58
+ # },
59
+ device=device,
60
+ precision=precision,
61
+ )
62
+
63
+ retriever = GoldenRetriever(
64
+ question_encoder=retriever_name_or_path,
65
+ document_index=document_index,
66
+ device=device,
67
+ precision=precision,
68
+ index_device=device,
69
+ index_precision=precision,
70
+ )
71
+ retriever.eval()
72
+
73
+ logger.info(f"Loading from {input_path}")
74
+ with open(input_path) as f:
75
+ samples = [json.loads(line) for line in f.readlines()]
76
+
77
+ topics = topics and "doc_topic" in samples[0]
78
+
79
+ # get tokenizer
80
+ tokenizer = retriever.question_tokenizer
81
+ collate_fn = lambda batch: ModelInputs(
82
+ tokenizer(
83
+ [b["text"] for b in batch],
84
+ text_pair=[b["doc_topic"] for b in batch] if topics else None,
85
+ padding=True,
86
+ return_tensors="pt",
87
+ truncation=True,
88
+ )
89
+ )
90
+ logger.info(f"Creating dataloader with batch size {batch_size}")
91
+ dataloader = torch.utils.data.DataLoader(
92
+ BaseDataset(name="passage", data=samples),
93
+ batch_size=batch_size,
94
+ shuffle=False,
95
+ num_workers=num_workers,
96
+ pin_memory=False,
97
+ collate_fn=collate_fn,
98
+ )
99
+
100
+ # we also dump the candidates to a file after a while
101
+ retrieved_accumulator = []
102
+ with torch.inference_mode():
103
+ start = time.time()
104
+ num_completed_docs = 0
105
+
106
+ for documents_batch in tqdm.tqdm(dataloader):
107
+ retrieve_kwargs = {
108
+ **documents_batch,
109
+ "k": 100,
110
+ "precision": precision,
111
+ }
112
+ batch_out = retriever.retrieve(**retrieve_kwargs)
113
+ retrieved_accumulator.extend(batch_out)
114
+
115
+ end = time.time()
116
+
117
+ output_data = []
118
+ # get the correct document from the original dataset
119
+ # the dataloader is not shuffled, so we can just count the number of
120
+ # documents we have seen so far
121
+ for sample, retrieved in zip(
122
+ samples[
123
+ num_completed_docs : num_completed_docs + len(retrieved_accumulator)
124
+ ],
125
+ retrieved_accumulator,
126
+ ):
127
+ candidate_titles = [c.label.split(" <def>", 1)[0] for c in retrieved]
128
+ sample["window_candidates"] = candidate_titles
129
+ sample["window_candidates_scores"] = [c.score for c in retrieved]
130
+ output_data.append(sample)
131
+
132
+ # for sample in output_data:
133
+ # f_out.write(json.dumps(sample) + "\n")
134
+
135
+ num_completed_docs += len(retrieved_accumulator)
136
+ retrieved_accumulator = []
137
+
138
+ compute_retriever_stats(output_data)
139
+ print(f"Retrieval took {end - start:.2f} seconds")
140
+
141
+
142
+ if __name__ == "__main__":
143
+ # arg_parser = argparse.ArgumentParser()
144
+ # arg_parser.add_argument("--retriever_name_or_path", type=str, required=True)
145
+ # arg_parser.add_argument("--document_index_name_or_path", type=str, required=True)
146
+ # arg_parser.add_argument("--input_path", type=str, required=True)
147
+ # arg_parser.add_argument("--output_path", type=str, required=True)
148
+ # arg_parser.add_argument("--batch_size", type=int, default=128)
149
+ # arg_parser.add_argument("--device", type=str, default="cuda")
150
+ # arg_parser.add_argument("--index_device", type=str, default="cpu")
151
+ # arg_parser.add_argument("--precision", type=str, default="fp32")
152
+
153
+ # add_candidates(**vars(arg_parser.parse_args()))
154
+ add_candidates(
155
+ "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
156
+ "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
157
+ "/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
158
+ # index_type="HNSW32",
159
+ # index_type="IVF1024,PQ8",
160
+ # nprobe=1,
161
+ topics=True,
162
+ device="cuda",
163
+ )
frequency_blink.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63bdea194b5c27d8c35547a205c42b4bc2e8933a47f179bc63256cf12a3bd448
3
+ size 95579105
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered/config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ _target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex
2
+ documents:
3
+ _target_: relik.retriever.data.labels.Labels
4
+ embeddings:
5
+ _target_: torch.Tensor
6
+ name_or_dir: null
7
+ device: cpu
8
+ precision: null
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered/documents.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:486ef055dcc484ddd9d445cfc2bac1e2a7c133d79492610de49b72630bd6ce8f
3
+ size 719452975
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered/embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee144610bf744e96091f4f295d350806173703d0960a964444a1c13b248a5c0d
3
+ size 1537987243
relik/inference/annotator.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Optional, Union
4
 
5
  import hydra
6
  from omegaconf import OmegaConf
 
7
  from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
8
  from rich.pretty import pprint
9
 
@@ -395,10 +396,15 @@ class Relik:
395
  def main():
396
  from pprint import pprint
397
 
 
 
 
 
 
398
  relik = Relik(
399
- question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
400
- document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
401
- reader="riccorl/relik-reader-aida-deberta-small",
402
  device="cuda",
403
  precision=16,
404
  top_k=100,
 
4
 
5
  import hydra
6
  from omegaconf import OmegaConf
7
+ from relik.retriever.indexers.faiss import FaissDocumentIndex
8
  from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
9
  from rich.pretty import pprint
10
 
 
396
  def main():
397
  from pprint import pprint
398
 
399
+ document_index = FaissDocumentIndex.from_pretrained(
400
+ "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
401
+ config_kwargs={"_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex", "index_type": "IVFx,Flat"},
402
+ )
403
+
404
  relik = Relik(
405
+ question_encoder="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
406
+ document_index=document_index,
407
+ reader="/root/relik-spaces/models/relik-reader-aida-deberta-small",
408
  device="cuda",
409
  precision=16,
410
  top_k=100,
relik/retriever/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from relik.retriever.pytorch_modules.model import GoldenRetriever
relik/retriever/indexers/base.py CHANGED
@@ -79,6 +79,17 @@ class BaseDocumentIndex:
79
  self.embeddings = embeddings
80
  self.name_or_dir = name_or_dir
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  @property
83
  def config(self) -> Dict[str, Any]:
84
  """
@@ -261,6 +272,7 @@ class BaseDocumentIndex:
261
  config_file_name: Optional[str] = None,
262
  document_file_name: Optional[str] = None,
263
  embedding_file_name: Optional[str] = None,
 
264
  *args,
265
  **kwargs,
266
  ) -> "BaseDocumentIndex":
@@ -285,6 +297,9 @@ class BaseDocumentIndex:
285
  )
286
 
287
  config = OmegaConf.load(config_path)
 
 
 
288
  pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
289
 
290
  # load the documents
 
79
  self.embeddings = embeddings
80
  self.name_or_dir = name_or_dir
81
 
82
+ def __iter__(self):
83
+ # make this class iterable
84
+ for i in range(len(self)):
85
+ yield self[i]
86
+
87
+ def __len__(self):
88
+ return self.documents.get_label_size()
89
+
90
+ def __getitem__(self, index):
91
+ return self.get_passage_from_index(index)
92
+
93
  @property
94
  def config(self) -> Dict[str, Any]:
95
  """
 
272
  config_file_name: Optional[str] = None,
273
  document_file_name: Optional[str] = None,
274
  embedding_file_name: Optional[str] = None,
275
+ config_kwargs: Optional[Dict[str, Any]] = None,
276
  *args,
277
  **kwargs,
278
  ) -> "BaseDocumentIndex":
 
297
  )
298
 
299
  config = OmegaConf.load(config_path)
300
+ # override the config with the kwargs
301
+ if config_kwargs is not None:
302
+ config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
303
  pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
304
 
305
  # load the documents
relik/retriever/indexers/faiss.py CHANGED
@@ -6,8 +6,9 @@ from dataclasses import dataclass
6
  from typing import Callable, List, Optional, Union
7
 
8
  import numpy
 
9
  import torch
10
- from pytorch_modules import RetrievedSample
11
  from torch.utils.data import DataLoader
12
  from tqdm import tqdm
13
 
@@ -44,6 +45,7 @@ class FaissDocumentIndex(BaseDocumentIndex):
44
  embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None,
45
  index=None,
46
  index_type: str = "Flat",
 
47
  metric: int = faiss.METRIC_INNER_PRODUCT,
48
  normalize: bool = False,
49
  device: str = "cpu",
@@ -60,6 +62,8 @@ class FaissDocumentIndex(BaseDocumentIndex):
60
  "The number of documents and embeddings must be the same."
61
  )
62
 
 
 
63
  # device to store the embeddings
64
  self.device = device
65
 
@@ -83,6 +87,7 @@ class FaissDocumentIndex(BaseDocumentIndex):
83
  self.embeddings = self._build_faiss_index(
84
  embeddings=embeddings,
85
  index_type=index_type,
 
86
  normalize=normalize,
87
  metric=metric,
88
  )
@@ -91,6 +96,7 @@ class FaissDocumentIndex(BaseDocumentIndex):
91
  self,
92
  embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
93
  index_type: str,
 
94
  normalize: bool,
95
  metric: int,
96
  ):
@@ -103,11 +109,15 @@ class FaissDocumentIndex(BaseDocumentIndex):
103
  if self.normalize:
104
  index_type = f"L2norm,{index_type}"
105
  faiss_vector_size = embeddings.shape[1]
106
- if self.device == "cpu":
107
- index_type = index_type.replace("x,", "x_HNSW32,")
108
- index_type = index_type.replace(
109
- "x", str(math.ceil(math.sqrt(faiss_vector_size)) * 4)
110
- )
 
 
 
 
111
  self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)
112
 
113
  # convert to GPU
@@ -121,12 +131,24 @@ class FaissDocumentIndex(BaseDocumentIndex):
121
  embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
122
  )
123
 
 
124
  # convert to float32 if embeddings is a torch.Tensor and is float16
125
  if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
126
  embeddings = embeddings.float()
127
 
 
 
 
 
128
  self.embeddings.add(embeddings)
129
 
 
 
 
 
 
 
 
130
  # save parameters for saving/loading
131
  self.index_type = index_type
132
  self.metric = metric
@@ -277,6 +299,7 @@ class FaissDocumentIndex(BaseDocumentIndex):
277
  @torch.no_grad()
278
  @torch.inference_mode()
279
  def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
 
280
  k = min(k, self.embeddings.ntotal)
281
 
282
  if self.normalize:
@@ -292,7 +315,7 @@ class FaissDocumentIndex(BaseDocumentIndex):
292
  batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()
293
  # Retrieve the passages corresponding to the indices
294
  batch_passages = [
295
- [self.documents.get_label_from_index(i) for i in indices]
296
  for indices in batch_top_k
297
  ]
298
  # build the output object
 
6
  from typing import Callable, List, Optional, Union
7
 
8
  import numpy
9
+ import psutil
10
  import torch
11
+ from relik.retriever.pytorch_modules import RetrievedSample
12
  from torch.utils.data import DataLoader
13
  from tqdm import tqdm
14
 
 
45
  embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None,
46
  index=None,
47
  index_type: str = "Flat",
48
+ nprobe: int = 1,
49
  metric: int = faiss.METRIC_INNER_PRODUCT,
50
  normalize: bool = False,
51
  device: str = "cpu",
 
62
  "The number of documents and embeddings must be the same."
63
  )
64
 
65
+ faiss.omp_set_num_threads(psutil.cpu_count(logical=False))
66
+
67
  # device to store the embeddings
68
  self.device = device
69
 
 
87
  self.embeddings = self._build_faiss_index(
88
  embeddings=embeddings,
89
  index_type=index_type,
90
+ nprobe=nprobe,
91
  normalize=normalize,
92
  metric=metric,
93
  )
 
96
  self,
97
  embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
98
  index_type: str,
99
+ nprobe: int,
100
  normalize: bool,
101
  metric: int,
102
  ):
 
109
  if self.normalize:
110
  index_type = f"L2norm,{index_type}"
111
  faiss_vector_size = embeddings.shape[1]
112
+ # if self.device == "cpu":
113
+ # index_type = index_type.replace("x,", "x_HNSW32,")
114
+ # nlist = math.ceil(math.sqrt(faiss_vector_size)) * 4
115
+ # # nlist = 8
116
+ # index_type = index_type.replace(
117
+ # "x", str(nlist)
118
+ # )
119
+ # print("Current nlist:", nlist)
120
+ print("Current index:", index_type)
121
  self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)
122
 
123
  # convert to GPU
 
131
  embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
132
  )
133
 
134
+ self.embeddings.hnsw.efConstruction = 20
135
  # convert to float32 if embeddings is a torch.Tensor and is float16
136
  if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
137
  embeddings = embeddings.float()
138
 
139
+ logger.info("Training the index.")
140
+ self.embeddings.train(embeddings)
141
+
142
+ logger.info("Adding the embeddings to the index.")
143
  self.embeddings.add(embeddings)
144
 
145
+ self.embeddings.nprobe = nprobe
146
+
147
+ # self.embeddings.hnsw.efSearch
148
+ self.embeddings.hnsw.efSearch = 256
149
+
150
+ # self.embeddings.k_factor = 10
151
+
152
  # save parameters for saving/loading
153
  self.index_type = index_type
154
  self.metric = metric
 
299
  @torch.no_grad()
300
  @torch.inference_mode()
301
  def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
302
+
303
  k = min(k, self.embeddings.ntotal)
304
 
305
  if self.normalize:
 
315
  batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()
316
  # Retrieve the passages corresponding to the indices
317
  batch_passages = [
318
+ [self.documents.get_label_from_index(i) for i in indices if i != -1]
319
  for indices in batch_top_k
320
  ]
321
  # build the output object
relik/retriever/indexers/inmemory.py CHANGED
@@ -67,6 +67,18 @@ class InMemoryDocumentIndex(BaseDocumentIndex):
67
  f"Converting to {PRECISION_MAP[precision]}."
68
  )
69
  self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
 
 
 
 
 
 
 
 
 
 
 
 
70
  # move the embeddings to the desired device
71
  if self.embeddings is not None and not self.embeddings.device == device:
72
  self.embeddings = self.embeddings.to(device)
 
67
  f"Converting to {PRECISION_MAP[precision]}."
68
  )
69
  self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
70
+ else:
71
+ if (
72
+ device == "cpu"
73
+ and self.embeddings is not None
74
+ and self.embeddings.dtype != torch.float32
75
+ ):
76
+ logger.info(
77
+ "Index vectors are of type {}. Converting to float32.".format(
78
+ self.embeddings.dtype
79
+ )
80
+ )
81
+ self.embeddings = self.embeddings.to(PRECISION_MAP[32])
82
  # move the embeddings to the desired device
83
  if self.embeddings is not None and not self.embeddings.device == device:
84
  self.embeddings = self.embeddings.to(device)
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  #------- Core dependencies -------
2
  torch>=2.0
3
- transformers[sentencepiece]>=4.34,<4.35
4
  rich>=13.0.0,<14.0.0
5
  scikit-learn
6
  overrides
 
1
  #------- Core dependencies -------
2
  torch>=2.0
3
+ transformers[sentencepiece]>=4.33,<4.34
4
  rich>=13.0.0,<14.0.0
5
  scikit-learn
6
  overrides
scripts/blink_freq.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ import json
3
+
4
+ from tqdm import tqdm
5
+
6
+ if __name__ == "__main__":
7
+ counter = Counter()
8
+
9
+ with open("/media/data/EL/blink/train.alby-format.jsonl") as f_in:
10
+ for line in tqdm(f_in):
11
+ sample = json.loads(line)
12
+ for ss, se, label in sample["doc_annotations"]:
13
+ if label == "--NME--":
14
+ continue
15
+ counter.update([label])
16
+
17
+ with open("frequency_blink.txt", "w") as f_out:
18
+ for k, v in counter.most_common():
19
+ f_out.write(f"{k}\t{v}\n")
scripts/filter_docs.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ import json
3
+ import torch
4
+
5
+ from tqdm import tqdm
6
+ from relik.retriever.data.labels import Labels
7
+
8
+ from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
9
+
10
+ if __name__ == "__main__":
11
+ with open("frequency_blink.txt") as f_in:
12
+ frequencies = [l.strip().split("\t")[0] for l in f_in.readlines()]
13
+
14
+ frequencies = set(frequencies[:1_000_000])
15
+
16
+ with open(
17
+ "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt"
18
+ ) as f_in:
19
+ for line in f_in:
20
+ title = line.strip().split(" <def>")[0].strip()
21
+ frequencies.add(title)
22
+
23
+ document_index = InMemoryDocumentIndex.from_pretrained(
24
+ "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
25
+ )
26
+
27
+ new_doc_index = {}
28
+ new_embeddings = []
29
+
30
+ for i in range(document_index.documents.get_label_size()):
31
+ doc = document_index.documents.get_label_from_index(i)
32
+ title = doc.split(" <def>")[0].strip()
33
+ if title in frequencies:
34
+ new_doc_index[doc] = len(new_doc_index)
35
+ new_embeddings.append(document_index.embeddings[i])
36
+
37
+ print(len(new_doc_index))
38
+ print(len(new_embeddings))
39
+
40
+ new_embeddings = torch.stack(new_embeddings, dim=0)
41
+ new_embeddings = new_embeddings.to(torch.float16)
42
+
43
+ print(new_embeddings.shape)
44
+
45
+ new_label_index = Labels()
46
+ new_label_index.add_labels(new_doc_index)
47
+ new_document_index = InMemoryDocumentIndex(
48
+ documents=new_label_index,
49
+ embeddings=new_embeddings,
50
+ )
51
+
52
+ new_document_index.save_pretrained(
53
+ "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered"
54
+ )