Spaces:
Sleeping
Sleeping
import collections | |
import contextlib | |
import logging | |
from typing import Any, Dict, Iterator, List | |
import torch | |
import transformers as tr | |
from lightning_fabric.utilities import move_data_to_device | |
from torch.utils.data import DataLoader, IterableDataset | |
from tqdm import tqdm | |
from relik.common.log import get_console_logger, get_logger | |
from relik.common.utils import get_callable_from_string | |
from relik.reader.data.relik_reader_sample import RelikReaderSample | |
from relik.reader.pytorch_modules.base import RelikReaderBase | |
from relik.reader.utils.special_symbols import get_special_symbols | |
from relik.retriever.pytorch_modules import PRECISION_MAP | |
# Module-level loggers. NOTE(review): neither is referenced by the class
# defined in this file; kept presumably for use by importers / debugging.
console_logger = get_console_logger()
# Named logger for this module, created at INFO level.
logger = get_logger(__name__, level=logging.INFO)
class RelikReaderForSpanExtraction(RelikReaderBase):
    """
    A class for the RelikReader model for span extraction.

    Args:
        transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The transformer model to use. If `None`, the default model is used.
        additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
            The number of additional special symbols to add to the tokenizer.
        num_layers (:obj:`int`, `optional`):
            The number of layers to use. If `None`, all layers are used.
        activation (:obj:`str`, `optional`, defaults to "gelu"):
            The activation function to use.
        linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
            The hidden size of the linears.
        use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
            The number of last layers to use.
        training (:obj:`bool`, `optional`, defaults to False):
            Whether the model is in training mode.
        device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
            The device to use. If `None`, the default device is used.
        tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
            The tokenizer to use. If `None`, the default tokenizer is used.
        dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
            The dataset used to tokenize and batch samples at prediction time.
            If `None`, an instance of :attr:`default_data_class` is built from
            the tokenizer and the model's special symbols.
        dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
            Keyword arguments that override the defaults used to build the
            dataset. Only used when `dataset` is `None`.
        default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The default reader class to use. If `None`, the default reader class is used.
        **kwargs:
            Keyword arguments forwarded to :class:`RelikReaderBase`.
    """

    # Dotted path of the underlying HF model class.
    default_reader_class: str = (
        "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
    )
    # Dotted path of the dataset class built when no dataset is supplied.
    default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        dataset: IterableDataset | str | None = None,
        dataset_kwargs: Dict[str, Any] | None = None,
        default_reader_class: tr.PreTrainedModel | str | None = None,
        **kwargs,
    ):
        super().__init__(
            transformer_model=transformer_model,
            additional_special_symbols=additional_special_symbols,
            num_layers=num_layers,
            activation=activation,
            linears_hidden_size=linears_hidden_size,
            use_last_k_layers=use_last_k_layers,
            training=training,
            device=device,
            tokenizer=tokenizer,
            dataset=dataset,
            default_reader_class=default_reader_class,
            **kwargs,
        )
        # Use the caller's dataset if given; otherwise build an inference-only
        # dataset from the default data class.
        self.dataset = dataset
        if self.dataset is None:
            default_data_kwargs = dict(
                dataset_path=None,
                materialize_samples=False,
                transformer_model=self.tokenizer,
                special_symbols=get_special_symbols(
                    self.relik_reader_model.config.additional_special_symbols
                ),
                for_inference=True,
            )
            # Caller-supplied kwargs take precedence over the defaults above.
            default_data_kwargs.update(dataset_kwargs or {})
            self.dataset = get_callable_from_string(self.default_data_class)(
                **default_data_kwargs
            )

    def _read(
        self,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        max_length: int = 1000,
        max_batch_size: int = 128,
        token_batch_size: int = 2048,
        precision: str | int = 32,
        annotation_type: str = "char",
        progress_bar: bool = False,
        *args: object,
        **kwargs: object,
    ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.

        Two input modes are supported:

        * ``samples`` given: the samples are fed through ``self.dataset`` (which
          tokenizes and batches them), predictions are collected back in the
          original sample order, then patch predictions are merged and token
          annotations are converted to character annotations.
        * ``samples`` is `None`: the pre-tokenized tensors are passed straight to
          :meth:`_batch_predict` together with ``*args`` / ``**kwargs``.

        Args:
            samples (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, the tensor arguments are ignored.
                Each sample must not already carry a ``_mixin_prediction_position``
                (this method sets it and does not reset it).
            input_ids (:obj:`torch.Tensor`, `optional`):
                The input ids of the text. If `samples` is provided, this is ignored.
            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask of the text. If `samples` is provided, this is ignored.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text. If `samples` is provided, this is ignored.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text. If `samples` is provided, this is ignored.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text. If `samples` is provided, this is ignored.
            max_length (:obj:`int`, `optional`, defaults to 1000):
                The maximum length of the text (set as the dataset's ``model_max_length``).
            max_batch_size (:obj:`int`, `optional`, defaults to 128):
                The maximum batch size.
            token_batch_size (:obj:`int`, `optional`, defaults to 2048):
                The token budget per batch (set as the dataset's ``tokens_per_batch``).
            precision (:obj:`str` or :obj:`int`, `optional`, defaults to 32):
                Key into ``PRECISION_MAP`` selecting the autocast dtype. Falls back
                to ``self.precision`` when falsy. Autocast is not used on CPU.
            annotation_type (:obj:`str`, `optional`, defaults to "char"):
                The requested annotation type ("char", "token" or "word").
                NOTE: currently not read by this method; character-level
                conversion is always applied to the outputs.
            progress_bar (:obj:`bool`, `optional`, defaults to False):
                Whether to show a progress bar.
            *args:
                Positional arguments forwarded to :meth:`_batch_predict` (tensor mode only).
            **kwargs:
                Keyword arguments forwarded to :meth:`_batch_predict` (tensor mode only).

        Returns:
            :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
                The samples with their predictions attached, in input order.

        Raises:
            ValueError: If ``samples`` is given but no dataset is available.
        """
        precision = precision or self.precision

        if samples is None:
            # Tensor mode: the caller already tokenized and batched the input.
            return list(
                self._batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    **kwargs,
                )
            )

        # Fail fast, before any sample is touched.
        if self.dataset is None:
            raise ValueError(
                "You need to pass a dataset to the model in order to predict"
            )

        def _read_iterator():
            def samples_it():
                # Tag each sample with its input position so predictions can be
                # re-emitted in input order even if the dataset reorders them.
                for i, sample in enumerate(samples):
                    assert sample._mixin_prediction_position is None
                    sample._mixin_prediction_position = i
                    yield sample

            next_prediction_position = 0
            position2predicted_sample = {}

            # Configure the (shared, mutable) dataset for this call.
            self.dataset.samples = samples_it()
            self.dataset.model_max_length = max_length
            self.dataset.tokens_per_batch = token_batch_size
            self.dataset.max_batch_size = max_batch_size

            # batch_size=None: the dataset yields ready-made batches itself.
            iterator = DataLoader(
                self.dataset, batch_size=None, num_workers=0, shuffle=False
            )
            if progress_bar:
                iterator = tqdm(iterator, desc="Predicting with RelikReader")

            # torch.autocast expects a bare device type ('cpu', 'cuda', ...),
            # so strip any device index, e.g. 'cuda:0' -> 'cuda'.
            device_type_for_autocast = str(self.device).split(":")[0]
            # Autocast is skipped on CPU, where only a restricted set of
            # low-precision dtypes is supported.
            autocast_mngr = (
                contextlib.nullcontext()
                if device_type_for_autocast == "cpu"
                else torch.autocast(
                    device_type=device_type_for_autocast,
                    dtype=PRECISION_MAP[precision],
                )
            )

            with autocast_mngr:
                for batch in iterator:
                    batch = move_data_to_device(batch, self.device)
                    batch_out = self._batch_predict(**batch)
                    for sample in batch_out:
                        # Positions already emitted are ignored (the same
                        # sample object may come back once per patch).
                        if (
                            sample._mixin_prediction_position
                            >= next_prediction_position
                        ):
                            position2predicted_sample[
                                sample._mixin_prediction_position
                            ] = sample
                    # Emit every sample whose turn has come, in input order.
                    while next_prediction_position in position2predicted_sample:
                        yield position2predicted_sample[next_prediction_position]
                        del position2predicted_sample[next_prediction_position]
                        next_prediction_position += 1

        # Fully drain the iterator first so that every patch of every sample
        # has been predicted before patches are merged.
        outputs = list(_read_iterator())
        for sample in outputs:
            self.dataset.merge_patches_predictions(sample)
            self.dataset.convert_tokens_to_char_annotations(sample)
        return outputs

    def _batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.
        It also adds the predicted labels to the samples.

        This is a generator: the model is only run once it is iterated. For each
        sample ``ts`` it (re)creates ``ts._d["patches"][patch_offset]`` holding:

        * ``predicted_window_labels``: candidate title -> list of
          ``[start, end]`` token indices of the spans predicted for it;
        * ``span_title_probabilities``: ``(start, end)`` -> list of
          ``(candidate title, probability)`` for the ``top_k`` most probable classes;
        * ``predictable_candidates``: the candidates available to that sample.

        Span indices are relative to sequence position 1 (position 0 is dropped
        before searching for span boundaries).

        Args:
            input_ids (:obj:`torch.Tensor`):
                The input ids of the text.
            attention_mask (:obj:`torch.Tensor`):
                The attention mask of the text.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text.
            sample (:obj:`List[RelikReaderSample]`, `optional`):
                The samples of the batch, aligned with the batch dimension of the
                tensors. Despite the `None` default it is iterated
                unconditionally, so it must be provided.
            top_k (:obj:`int`, `optional`, defaults to 5):
                The amount of top-k most probable entities to keep per span.
                ``-1`` keeps all of them.
            *args:
                Positional arguments (unused).
            **kwargs:
                Must contain ``predictable_candidates`` (per-sample candidate
                lists) and ``patch_offset`` (per-sample patch keys); both are
                read unconditionally.

        Returns:
            An iterator over the samples, with predictions attached.
        """
        forward_output = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )
        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]

        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # Position 0 is skipped (presumably a leading special token such as
            # CLS — TODO confirm); indices below are relative to position 1.
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            # The i-th predicted start is paired with the i-th predicted end.
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # "+ 1" undoes the [1:] offset above; class ids are shifted by
                # one relative to the candidate list.
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # Candidate probabilities for this span, most probable first.
                classes_probabilities = edpr[start_token_index + 1]
                best_indices = classes_probabilities.argsort()[::-1]
                num_classes = len(best_indices)
                # Use a local so the caller's top_k is never overwritten.
                k = num_classes if top_k == -1 else min(top_k, num_classes)
                # NOTE(review): class index 0 maps to pred_cands[-1] here (Python
                # negative indexing) — confirm what class 0 represents.
                titles_2_probs = [
                    (
                        pred_cands[best_indices[i] - 1],
                        classes_probabilities[best_indices[i]].item(),
                    )
                    for i in range(k)
                ]
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()
            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts