import json
import os
import shutil
import sys
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch

import faiss
import numpy as np
from datasets import Dataset

from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.integrations import is_ray_available
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_ray


sys.path.append(os.path.join(os.getcwd()))  # noqa: E402  # isort:skip

if is_torch_available():
    from distributed_pytorch_retriever import RagPyTorchDistributedRetriever  # noqa: E402  # isort:skip
else:
    RagPyTorchDistributedRetriever = None

if is_ray_available():
    import ray  # noqa: E402  # isort:skip
    from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever  # noqa: E402  # isort:skip
else:
    ray = None
    RagRayDistributedRetriever = None
    RayRetriever = None


def require_distributed_retrieval(test_case):
    """
    Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    These tests are skipped when respective libraries are not installed.
    """
    if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
        test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
    return test_case


@require_distributed_retrieval
class RagRetrieverTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
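
    # The dummy index below holds two passages whose embeddings are all ones and all twos.
    # With retrieval_vector_size=8, an all-ones query scores 8 against doc "0" and 16 against
    # doc "1" under inner-product search, so doc "1" is expected to rank first for that query.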
    def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        return dataset
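
    # The retriever factories below either patch `transformers.models.rag.retrieval_rag.load_dataset`
    # or pass an explicit CustomHFIndex, so they always build against the in-memory dummy dataset
    # above and never download a real index.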
    def get_dummy_pytorch_distributed_retriever(
        self, init_retrieval: bool, port=12345
    ) -> RagPyTorchDistributedRetriever:
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
        )
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagPyTorchDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
            )
        if init_retrieval:
            retriever.init_retrieval(port)
        return retriever

    def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
        # Have to run in local mode because sys.path modifications at top of
        # file are not propagated to remote workers.
        # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
        ray.init(local_mode=True)
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
        )
        remote_cls = ray.remote(RayRetriever)
        workers = [remote_cls.remote() for _ in range(1)]
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = self.get_dummy_dataset()
            retriever = RagRayDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
                retrieval_workers=workers,
            )
        if init_retrieval:
            retriever.init_retrieval()
        return retriever

    def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="custom",
        )
        if from_disk:
            config.passages_path = os.path.join(self.tmpdirname, "dataset")
            config.index_path = os.path.join(self.tmpdirname, "index.faiss")
            dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
            dataset.drop_index("embeddings")
            dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
            del dataset
            retriever = RagPyTorchDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
            )
        else:
            retriever = RagPyTorchDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
                index=CustomHFIndex(config.retrieval_vector_size, dataset),
            )
        if init_retrieval:
            retriever.init_retrieval(port)
        return retriever

    def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
        # Have to run in local mode because sys.path modifications at top of
        # file are not propagated to remote workers.
        # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
        ray.init(local_mode=True)
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="custom",
        )
        remote_cls = ray.remote(RayRetriever)
        workers = [remote_cls.remote() for _ in range(1)]
        if from_disk:
            config.passages_path = os.path.join(self.tmpdirname, "dataset")
            config.index_path = os.path.join(self.tmpdirname, "index.faiss")
            dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
            dataset.drop_index("embeddings")
            dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
            del dataset
            retriever = RagRayDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
                retrieval_workers=workers,
                index=CustomHFIndex.load_from_disk(
                    vector_size=config.retrieval_vector_size,
                    dataset_path=config.passages_path,
                    index_path=config.index_path,
                ),
            )
        else:
            retriever = RagRayDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
                retrieval_workers=workers,
                index=CustomHFIndex(config.retrieval_vector_size, dataset),
            )
        if init_retrieval:
            retriever.init_retrieval()
        return retriever
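
    # Shared assertions: retrieve with a batch of two queries (all ones / all minus ones) and check
    # the embedding shapes, the returned doc dicts, and that the expected passage ranks first for each query.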
    def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.ndarray, n_docs: int) -> None:
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

    def test_pytorch_distributed_retriever_retrieve(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

        self.distributed_retriever_check(
            self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
        )

    def test_custom_hf_index_pytorch_retriever_retrieve(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

        self.distributed_retriever_check(
            self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
            hidden_states,
            n_docs,
        )

    def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

        self.distributed_retriever_check(
            self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
            hidden_states,
            n_docs,
        )

    @require_ray
    def test_ray_distributed_retriever_retrieve(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

        self.distributed_retriever_check(
            self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
        )
        ray.shutdown()

    @require_ray
    def test_custom_hf_index_ray_retriever_retrieve(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        with self.assertRaises(ValueError):
            self.distributed_retriever_check(
                self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
                hidden_states,
                n_docs,
            )
        ray.shutdown()

    @require_ray
    def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
        n_docs = 1
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

        self.distributed_retriever_check(
            self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
        )
        ray.shutdown()