Spaces:
Sleeping
Sleeping
import os | |
from pathlib import Path | |
from typing import Any, Callable, Dict, Optional, Union | |
import hydra | |
from omegaconf import OmegaConf | |
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel | |
from rich.pretty import pprint | |
from relik.common.log import get_console_logger, get_logger | |
from relik.common.upload import upload | |
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string | |
from relik.inference.data.objects import EntitySpan, RelikOutput | |
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer | |
from relik.inference.data.window.manager import WindowManager | |
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction | |
from relik.reader.relik_reader import RelikReader | |
from relik.retriever.data.utils import batch_generator | |
from relik.retriever.indexers.base import BaseDocumentIndex | |
from relik.retriever.pytorch_modules.model import GoldenRetriever | |
# Module-wide logging handles:
# - `logger` receives the regular log records emitted by this module;
# - `console_logger` is handed to rich's `pprint` (as `console=`) to
#   pretty-print configurations.
logger = get_logger(__name__)
console_logger = get_console_logger()
class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter (it is then required). Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. If a string is given, a `RelikReaderForSpanExtraction`
            is instantiated from it with `reader_kwargs`; any other value is stored
            as-is. Defaults to `None`.
        device (`str`, `optional`, defaults to `cpu`):
            Default device for every component whose specific device is not set.
        retriever_device (`Optional[str]`, `optional`):
            The device to use for the retriever. Falls back to `device`.
        document_index_device (`Optional[str]`, `optional`):
            The device to use for the document index. Falls back to `device`.
        reader_device (`Optional[str]`, `optional`):
            Stored as `self.reader_device`. Falls back to `device`.
        precision (`int`, `optional`, defaults to `32`):
            Default precision for every component whose specific precision is not set.
        retriever_precision (`Optional[int]`, `optional`):
            Precision of the retriever. Falls back to `precision`.
        document_index_precision (`Optional[int]`, `optional`):
            Precision of the document index. Falls back to `precision`.
        reader_precision (`Optional[int]`, `optional`):
            Stored as `self.reader_precision`. Falls back to `precision`.
        reader_kwargs (`Optional[dict]`, `optional`):
            Extra keyword arguments for the reader constructor (used only when
            `reader` is a string).
        retriever_kwargs (`Optional[dict]`, `optional`):
            Keyword arguments that override the defaults passed to `GoldenRetriever`
            when `retriever` is `None`.
        candidates_preprocessing_fn (`Optional[Union[str, Callable]]`, `optional`):
            Function applied to every retrieved candidate label before reading. A
            string is resolved with `get_callable_from_string`. Defaults to identity.
        top_k (`Optional[int]`, `optional`):
            Default number of candidates retrieved per window in `__call__`.
        window_size (`Optional[int]`, `optional`):
            Default window size used by `__call__`.
        window_stride (`Optional[int]`, `optional`):
            Default window stride used by `__call__`.
        **kwargs:
            Accepted for compatibility and ignored.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # retriever: component-specific settings fall back to the global ones
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite default_retriever_kwargs with retriever_kwargs
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
            retriever.training = False
            retriever.eval()
        else:
            # A ready-made retriever was given: still record the user kwargs so
            # that `save_pretrained` can always read `self.retriever_kwargs`.
            self.retriever_kwargs = retriever_kwargs
        self.retriever = retriever

        # reader
        # NOTE(review): reader_device / reader_precision are stored but not
        # passed to the reader constructed below — confirm whether the reader
        # should be moved to this device / precision.
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff (the window manager is created lazily in __call__)
        self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window. If `None`,
                `self.top_k` is used, or 100 when that is not set either.
            window_size (`int`, `optional`, defaults to `None`):
                The size of each window, forwarded to `WindowManager.create_windows`.
                If `None`, `self.window_size` is used. Annotating without any window
                size, or with `"sentence"` windows, is not implemented yet.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, `self.window_stride` is used;
                the resulting value (possibly `None`) is forwarded as-is to the
                window manager.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                Number of windows sent to the retriever per call.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                Passed to the reader as `max_batch_size`.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to attach the merged windows of each document to its output
                as the `windows` attribute.
            **kwargs:
                Currently ignored.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. A string input yields a single `RelikOutput`;
                a list input always yields a list of `RelikOutput` objects.

        Raises:
            NotImplementedError: if no window size is available, or it is `"sentence"`.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        # Remember the input form: only a plain string is unwrapped on return,
        # so a list input (even of length one) always yields a list.
        single_input = isinstance(text, str)
        if single_input:
            text = [text]

        if window_size is None:
            # Without a window size no windows can be built and everything below
            # (retrieval, reading, merging) would fail on undefined state.
            raise NotImplementedError(
                "Annotation without windowing is not implemented yet: "
                "please provide `window_size` at init or call time"
            )
        if self.window_manager is None:
            self.window_manager = WindowManager(self.tokenizer)
        if window_size == "sentence":
            # todo: implement sentence windowizer
            raise NotImplementedError("Sentence windowizer not implemented yet")

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: Move batching inside retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add passage to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # transform predictions into RelikOutput objects
        output = []
        for w in windows:
            doc_text = text[w.doc_id]
            sample_output = RelikOutput(
                text=doc_text,
                labels=sorted(
                    [
                        EntitySpan(start=ss, end=se, label=sl, text=doc_text[ss:se])
                        for ss, se, sl in w.predicted_window_labels_chars
                    ],
                    key=lambda x: x.start,
                ),
            )
            if return_also_windows:
                # match by the window's own doc_id rather than by output position,
                # so the association holds whatever order merge_windows returns
                sample_output.windows = [
                    other for other in windows if other.doc_id == w.doc_id
                ]
            output.append(sample_output)

        # if a single string was provided, return a single RelikOutput object
        if single_input and len(output) == 1:
            return output[0]
        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate a Relik pipeline from a saved configuration.

        Args:
            model_name_or_dir (`Union[str, os.PathLike]`):
                Model name or local directory, resolved through `from_cache`.
            config_kwargs (`Optional[Dict]`, `optional`):
                Values merged on top of the loaded configuration. Defaults to `None`.
            config_file_name (`str`, `optional`, defaults to `CONFIG_NAME`):
                Name of the configuration file inside the model directory.
            *args, **kwargs:
                Forwarded to `hydra.utils.instantiate`. The `cache_dir` and
                `force_download` keyword arguments are removed first and passed to
                `from_cache` instead.

        Returns:
            The object built by `hydra.utils.instantiate` from the configuration's
            `_target_` (a `Relik` for configs written by `save_pretrained`).

        Raises:
            FileNotFoundError: if the configuration file does not exist.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to (created if missing).
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, a configuration is built from
                the current object. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `CONFIG_NAME`.
            save_weights (`bool`, `optional`):
                Whether to also save the retriever and reader weights in
                `<model_id>-retriever` / `<model_id>-reader` subdirectories.
                Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub (also the prefix of the weight
                subdirectories). If `None`, the name of `output_dir` is used.
                Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments forwarded to the retriever's and the
                reader's `save_pretrained` (only used when `save_weights=True`).
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config["question_encoder"] = (
                        self.retriever.question_encoder.name_or_path
                    )
                if self.retriever.passage_encoder is not None:
                    config["passage_encoder"] = (
                        self.retriever.passage_encoder.name_or_path
                    )
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            # NOTE(review): when the retriever was built in `__init__`, these kwargs
            # also contain the encoder/index arguments given there, which may be
            # objects rather than strings — confirm they serialize with OmegaConf.
            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs

            # expand the fn as to be able to save it and load it later
            preprocessing_fn = self.candidates_preprocessing_fn
            preprocessing_fn_name = getattr(preprocessing_fn, "__name__", None)
            if preprocessing_fn_name is None or preprocessing_fn_name == "<lambda>":
                # Lambdas (including the default identity set in `__init__`) and
                # unnamed callables cannot be re-imported from a dotted path;
                # saving `None` makes `__init__` fall back to the identity.
                config["candidates_preprocessing_fn"] = None
            else:
                config["candidates_preprocessing_fn"] = (
                    f"{preprocessing_fn.__module__}.{preprocessing_fn_name}"
                )

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)
def main():
    """
    Small end-to-end demo: build a Relik pipeline from hub checkpoints and
    annotate a short news snippet, printing the result with the stdlib
    `pprint` (which shadows the module-level rich `pprint` inside this function).
    """
    from pprint import pprint

    sample_text = """
Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
The 92-year-old billionaire did not disclose the trust to the government in July 2015.
Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
Ecclestone had been due to go on trial next month.
"""

    # Pipeline settings, gathered in one place before building the annotator.
    pipeline_settings = {
        "question_encoder": "riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        "document_index": "riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        "reader": "riccorl/relik-reader-aida-deberta-small",
        "device": "cuda",
        "precision": 16,
        "top_k": 100,
        "window_size": 32,
        "window_stride": 16,
        "candidates_preprocessing_fn": "relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    }
    annotator = Relik(**pipeline_settings)

    annotations = annotator(sample_text)
    pprint(annotations)
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()