File size: 1,982 Bytes
6bf4ad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from datasets import load_dataset
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder
from general_utils import embed_passages, embed_passages_haystack
import faiss
import argparse
import os
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore


# Set the OMP_NUM_THREADS environment variable to 8 for this process.
# NOTE(review): presumably meant to cap the OpenMP threads used by faiss /
# the embedding models. This assignment runs after `import faiss` above; an
# OpenMP runtime that reads the variable when it is loaded would not see it.
# Confirm, and consider moving this above the imports.
os.environ["OMP_NUM_THREADS"] = "8"


def create_faiss_index(args):
    """Embed the Spanish biomedical corpus with DPR and save a Faiss index.

    Steps, in order:

    1. Build a haystack ``DensePassageRetriever`` on top of an
       ``InMemoryDocumentStore``.
    2. Load the ``train`` split of ``IIC/spanish_biomedical_crawled_corpus``
       and keep only the rows whose ``text`` is longer than ``minchars``.
    3. Run ``embed_passages_haystack`` over the remaining rows in batches.
    4. Add a Faiss index on the ``embeddings`` column and write it to
       ``args.index_file_name``.

    Args:
        args: Parsed command-line namespace. Only ``args.index_file_name``
            is read by this function.

    Returns:
        None. The result is the index file written to disk.

    NOTE(review): ``args.ctx_encoder_name`` and ``args.device`` are accepted
    by the command line but are not used here; the model names below are
    hard-coded.
    """
    # Rows whose text is not longer than this are dropped before embedding.
    minchars = 200

    dpr = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        # NOTE(review): the passage model is set to the *question* encoder
        # checkpoint, same as the query model. This looks like a copy-paste
        # slip (a passage/context encoder would be expected), but it is kept
        # unchanged because switching models changes the resulting index and
        # must match whatever encoder is used at query time — confirm.
        passage_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
    )

    dataset = load_dataset(
        "IIC/spanish_biomedical_crawled_corpus", split="train"
    )

    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)

    def embed_passages_retrieval(examples):
        """Bind the retriever so the function fits ``Dataset.map``'s one-argument form."""
        # Presumably returns a mapping that contains an "embeddings" column,
        # since the index below is built on that column — verify against
        # general_utils.embed_passages_haystack.
        return embed_passages_haystack(dpr, examples)

    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)

    dataset.add_faiss_index(
        column="embeddings",
        # Faiss index-factory specification; see the Faiss index factory
        # documentation for the meaning of each comma-separated component.
        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
        # Pass the whole dataset size as the training-set size.
        train_size=len(dataset),
    )
    dataset.save_faiss_index("embeddings", args.index_file_name)

if __name__ == "__main__":
    # Command-line entry point: parse the options, then build and save the index.
    parser = argparse.ArgumentParser(
        description="Creates a Faiss index file for the Spanish biomedical crawled corpus"
    )

    # NOTE(review): parsed but not read by create_faiss_index in this file,
    # which hard-codes its model names.
    parser.add_argument(
        "--ctx_encoder_name",
        default="IIC/dpr-spanish-passage_encoder-squades-base",
        help="Encoding model to use for passage encoding",
    )

    # Path the Faiss index is written to.
    parser.add_argument(
        "--index_file_name",
        default="dpr_index_bio_splitted.faiss",
        help="Faiss index file with passage embeddings",
    )

    # NOTE(review): parsed but not read by create_faiss_index in this file.
    parser.add_argument(
        "--device", default="cuda:0", help="The device to index data on."
    )

    # parse_known_args returns the unrecognised arguments separately; they
    # are discarded here, so extra options do not cause an error.
    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)