import logging
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    batchify_matrices,
    batchify_tensor,
    chunks,
    flatten,
)
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)
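
# Bundles the per-sample tensors produced by tokenization. All tensors share
# the sequence dimension; the two special-symbols masks flag the candidate
# marker tokens for relations and (optionally) entity types respectively.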
class TokenizationOutput(NamedTuple):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    token_type_ids: torch.Tensor
    prediction_mask: torch.Tensor
    special_symbols_mask: torch.Tensor
    special_symbols_mask_entities: torch.Tensor

class RelikREDataset(IterableDataset):
    def __init__(
        self,
        dataset_path: str,
        materialize_samples: bool,
        transformer_model: str,
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]],
        flip_candidates: Optional[Union[bool, float]],
        relations_definitions: Union[str, Dict[str, str]],
        for_inference: bool,
        entities_definitions: Optional[Union[str, Dict[str, str]]] = None,
        special_symbols_entities: Optional[List[str]] = None,
        noise_param: float = 0.1,
        sorting_fields: Optional[str] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 50_000,
        prebatch: bool = True,
        max_candidates: int = 0,
        add_gold_candidates: bool = True,
        use_nme: bool = True,
        min_length: int = 5,
        max_length: int = 2048,
        model_max_length: int = 1000,
        skip_empty_training_samples: bool = True,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = samples
        if self.materialize_samples and self.samples is None:
            self.samples = []
        self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
        self.special_symbols = special_symbols
        self.special_symbols_entities = special_symbols_entities
        self.shuffle_candidates = shuffle_candidates
        self.flip_candidates = flip_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )
        # load the relations definitions from a TSV file if a path is given
        if isinstance(relations_definitions, str):
            with open(relations_definitions) as f:
                relations_definitions = {
                    line.strip().split("\t")[0]: line.strip().split("\t")[1]
                    for line in f
                }
        self.max_candidates = max_candidates
        self.relations_definitions = relations_definitions
        self.entities_definitions = entities_definitions
        self.add_gold_candidates = add_gold_candidates
        self.use_nme = use_nme
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = min(model_max_length, self.tokenizer.model_max_length)
        self.transformer_model = transformer_model
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.section_size = section_size
        self.prebatch = prebatch

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=list(special_symbols),
            add_prefix_space=True,
        )
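
    # NOTE: add_prefix_space=True is required by byte-level BPE tokenizers
    # (e.g. RoBERTa) because the input is passed pre-tokenized with
    # is_split_into_words=True in dataset_iterator_func below.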
    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify_matrices(x, padding_value=-100),
            "disambiguation_labels": lambda x: batchify(x, padding_value=-100),
            "relation_labels": lambda x: batchify_tensor(x, padding_value=-100),
            "predictable_candidates": None,
        }
        if "roberta" in self.transformer_model:
            del fields_batchers["token_type_ids"]
        return fields_batchers
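
    # Label fields are padded with -100, the default ignore_index of
    # torch.nn.CrossEntropyLoss, so padding never contributes to the loss.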
    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )
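
    # Resulting layout:
    #   [CLS] <sentence bpes> [SEP] <cand_0 bpes> ... <cand_n bpes> [SEP]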
    def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        num_special = len(self.special_symbols) + len(
            self.special_symbols_entities or []
        )
        special_symbols_mask = input_ids >= (len(self.tokenizer) - num_special)
        special_symbols_mask[0] = True
        return special_symbols_mask
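
    # This relies on AutoTokenizer appending additional special tokens at the
    # end of the vocabulary, so candidate markers are exactly the ids >=
    # len(tokenizer) - num_special. Position 0 ([CLS]) is force-flagged,
    # presumably so it can act as the NME placeholder.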
    def _build_tokenizer_essentials(
        self, input_ids, original_sequence
    ) -> TokenizationOutput:
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        total_sequence_len = len(input_ids)
        predictable_sentence_len = len(original_sequence)
        # token type ids
        token_type_ids = torch.cat(
            [
                input_ids.new_zeros(
                    predictable_sentence_len + 2
                ),  # original sentence bpes + CLS and SEP
                input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
            ]
        )
        # prediction mask -> boolean on tokens that are predictable
        prediction_mask = torch.tensor(
            [1]
            + ([0] * predictable_sentence_len)
            + ([1] * (total_sequence_len - predictable_sentence_len - 1))
        )
        assert len(prediction_mask) == len(input_ids)
        # special symbols mask (relation candidate markers)
        special_symbols_mask = input_ids >= (
            len(self.tokenizer) - len(self.special_symbols)
        )
        if self.entities_definitions is not None:
            # select only the first N true values where N is len(entities_definitions)
            special_symbols_mask_entities = special_symbols_mask.clone()
            special_symbols_mask_entities[
                special_symbols_mask_entities.cumsum(0) > len(self.entities_definitions)
            ] = False
            special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities
        else:
            special_symbols_mask_entities = special_symbols_mask.clone()
        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
        )
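
    # Toy example with a 4-bpe sentence and two candidate markers:
    #   tokens:           [CLS] w1 w2 w3 w4 [SEP] <R-0> d1 <R-1> d2 [SEP]
    #   token_type_ids:     0   0  0  0  0    0     1   1    1   1    1
    #   prediction_mask:    1   0  0  0  0    1     1   1    1   1    1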
    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = []
        sample.entities.sort(key=lambda x: (x[0], x[1]))
        prev_start_bpe = -1
        num_repeat_start = 0
        if self.entities_definitions:
            sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities]
            sample.entity_candidates = list(self.entities_definitions.keys())
            disambiguation_labels = torch.zeros(
                len(sample.entities),
                len(sample.entity_candidates) + len(sample.candidates),
            )
        else:
            sample.entities = [(ce[0], ce[1], "") for ce in sample.entities]
            disambiguation_labels = torch.zeros(
                len(sample.entities), len(sample.candidates)
            )
        ignored_labels_indices = tokenization_output.prediction_mask == 1
        for idx, c_ent in enumerate(sample.entities):
            start_bpe = sample.word2token[c_ent[0]][0] + 1
            end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1
            class_index = idx
            start_labels[start_bpe] = class_index + 1  # +1 for the NONE class
            if start_bpe != prev_start_bpe:
                end_labels.append([0] * len(tokenization_output.input_ids))
                end_labels[-1][end_bpe] = class_index + 1
            else:
                # entities sharing a start bpe share an end-labels row
                end_labels[-1][end_bpe] = class_index + 1
                num_repeat_start += 1
            if self.entities_definitions:
                entity_type_idx = sample.entity_candidates.index(c_ent[2])
                disambiguation_labels[idx, entity_type_idx] = 1
            prev_start_bpe = start_bpe
        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100
        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100
        relation_labels = torch.zeros(
            len(sample.entities), len(sample.entities), len(sample.candidates)
        )
        for triplet in sample.triplets:
            if triplet["relation"]["name"] not in sample.candidates:
                # unseen relations are mapped onto the last candidate slot
                re_class_index = len(sample.candidates) - 1
            else:
                re_class_index = sample.candidates.index(triplet["relation"]["name"])
            if self.entities_definitions:
                subject_class_index = sample.entities.index(
                    (
                        triplet["subject"]["start"],
                        triplet["subject"]["end"],
                        triplet["subject"]["type"],
                    )
                )
                object_class_index = sample.entities.index(
                    (
                        triplet["object"]["start"],
                        triplet["object"]["end"],
                        triplet["object"]["type"],
                    )
                )
            else:
                subject_class_index = sample.entities.index(
                    (triplet["subject"]["start"], triplet["subject"]["end"], "")
                )
                object_class_index = sample.entities.index(
                    (triplet["object"]["start"], triplet["object"]["end"], "")
                )
            relation_labels[subject_class_index, object_class_index, re_class_index] = 1
            if self.entities_definitions:
                disambiguation_labels[
                    subject_class_index, re_class_index + len(sample.entity_candidates)
                ] = 1
                disambiguation_labels[
                    object_class_index, re_class_index + len(sample.entity_candidates)
                ] = 1
            else:
                disambiguation_labels[subject_class_index, re_class_index] = 1
                disambiguation_labels[object_class_index, re_class_index] = 1
        return start_labels, end_labels, disambiguation_labels, relation_labels
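
    # Returned shapes:
    #   start_labels:          (seq_len,)
    #   end_labels:            (num_distinct_start_bpes, seq_len)
    #   disambiguation_labels: (num_entities, [num_entity_types +] num_candidates)
    #   relation_labels:       (num_entities, num_entities, num_candidates)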
    def __iter__(self):
        dataset_iterator = self.dataset_iterator_func()
        current_dataset_elements = []
        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                for batch in self.materialize_batches(current_dataset_elements):
                    yield batch
                current_dataset_elements = []
            current_dataset_elements.append(dataset_elem)
            if i % 50_000 == 0:
                logger.info(f"Processed {i} elements")
        if len(current_dataset_elements) != 0:
            for batch in self.materialize_batches(current_dataset_elements):
                yield batch
        if i is not None:
            logger.info(f"Dataset finished: {i} elements processed")
        else:
            logger.warning("Dataset empty")
    def dataset_iterator_func(self):
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            else self.samples
        )
        for sample in data_samples:
            # input sentence tokenization
            input_tokenized = self.tokenizer(
                sample.tokens,
                return_offsets_mapping=True,
                add_special_tokens=False,
                is_split_into_words=True,
            )
            input_subwords = input_tokenized["input_ids"]
            offsets = input_tokenized["offset_mapping"]
            # align bpes and words: offset[0] == 0 marks the first bpe of a word
            token2word = []
            word2token = {}
            count = 0
            for i, offset in enumerate(offsets):
                if offset[0] == 0:
                    token2word.append(i - count)
                    word2token[i - count] = [i]
                else:
                    token2word.append(token2word[-1])
                    word2token[token2word[-1]].append(i)
                    count += 1
            sample.token2word = token2word
            sample.word2token = word2token
            candidates_symbols = self.special_symbols
            if self.max_candidates > 0:
                # truncate candidates
                sample.candidates = sample.candidates[: self.max_candidates]
            # add NME as a possible candidate
            if self.use_nme:
                sample.candidates.insert(0, NME_SYMBOL)
            # training-time sample modifications
            if not self.for_inference:
                # skip the sample if it has no labels
                if (
                    sample.triplets is None or len(sample.triplets) == 0
                ) and self.skip_empty_training_samples:
                    logger.warning(
                        "Sample {} has no labels, skipping".format(sample.sample_id)
                    )
                    continue
                # add gold candidates if missing
                if self.add_gold_candidates:
                    candidates_set = set(sample.candidates)
                    candidates_to_add = [
                        triplet["relation"]["name"]
                        for triplet in sample.triplets
                        if triplet["relation"]["name"] not in candidates_set
                    ]
                    if len(candidates_to_add) > 0:
                        # replace the last candidates with the gold ones,
                        # preserving the ordering of the remaining candidates
                        added_gold_candidates = 0
                        gold_candidates_titles_set = set(
                            ct["relation"]["name"] for ct in sample.triplets
                        )
                        for i in reversed(range(len(sample.candidates))):
                            if (
                                sample.candidates[i] not in gold_candidates_titles_set
                                and sample.candidates[i] != NME_SYMBOL
                            ):
                                sample.candidates[i] = candidates_to_add[
                                    added_gold_candidates
                                ]
                                added_gold_candidates += 1
                                if len(candidates_to_add) == added_gold_candidates:
                                    break
                        candidates_still_to_add = (
                            len(candidates_to_add) - added_gold_candidates
                        )
                        while (
                            len(sample.candidates) <= len(candidates_symbols)
                            and candidates_still_to_add != 0
                        ):
                            sample.candidates.append(
                                candidates_to_add[added_gold_candidates]
                            )
                            added_gold_candidates += 1
                            candidates_still_to_add -= 1
                # shuffle candidates
                if (
                    isinstance(self.shuffle_candidates, bool)
                    and self.shuffle_candidates
                ) or (
                    isinstance(self.shuffle_candidates, float)
                    and np.random.uniform() < self.shuffle_candidates
                ):
                    np.random.shuffle(sample.candidates)
                    if NME_SYMBOL in sample.candidates:
                        sample.candidates.remove(NME_SYMBOL)
                        sample.candidates.insert(0, NME_SYMBOL)
                # flip neighbouring candidates, each pair with probability 0.5
                if (
                    isinstance(self.flip_candidates, bool) and self.flip_candidates
                ) or (
                    isinstance(self.flip_candidates, float)
                    and np.random.uniform() < self.flip_candidates
                ):
                    for i in range(len(sample.candidates) - 1):
                        if np.random.uniform() < 0.5:
                            sample.candidates[i], sample.candidates[i + 1] = (
                                sample.candidates[i + 1],
                                sample.candidates[i],
                            )
                    if NME_SYMBOL in sample.candidates:
                        sample.candidates.remove(NME_SYMBOL)
                        sample.candidates.insert(0, NME_SYMBOL)
            # candidates encoding
            candidates_symbols = candidates_symbols[: len(sample.candidates)]
            relations_defs = [
                "{} {}".format(cs, self.relations_definitions[ct])
                if ct != NME_SYMBOL
                else NME_SYMBOL
                for cs, ct in zip(candidates_symbols, sample.candidates)
            ]
            if self.entities_definitions is not None:
                candidates_entities_symbols = list(self.special_symbols_entities)
                candidates_entities_symbols = candidates_entities_symbols[
                    : len(self.entities_definitions)
                ]
                entity_defs = [
                    "{} {}".format(cs, self.entities_definitions[ct])
                    for cs, ct in zip(
                        candidates_entities_symbols, self.entities_definitions.keys()
                    )
                ]
                relations_defs = (
                    entity_defs + [self.tokenizer.sep_token] + relations_defs
                )
            candidates_encoding_result = self.tokenizer.batch_encode_plus(
                relations_defs,
                add_special_tokens=False,
            ).input_ids
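
            # each candidate is encoded separately so its bpe length is known
            # when deciding below how many candidates fit in model_max_length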
            # drop candidates if the number of input tokens is too long for the model
            if (
                sum(map(len, candidates_encoding_result))
                + len(input_subwords)
                + 20  # + 20 special tokens
                > self.model_max_length
            ):
                if self.for_inference:
                    acceptable_tokens_from_candidates = (
                        self.model_max_length - 20 - len(input_subwords)
                    )
                    i = 0
                    cum_len = 0
                    while (
                        i < len(candidates_encoding_result)
                        and cum_len + len(candidates_encoding_result[i])
                        < acceptable_tokens_from_candidates
                    ):
                        cum_len += len(candidates_encoding_result[i])
                        i += 1
                    candidates_encoding_result = candidates_encoding_result[:i]
                    if self.entities_definitions is not None:
                        candidates_symbols = candidates_symbols[
                            : i - len(self.entities_definitions)
                        ]
                        sample.candidates = sample.candidates[
                            : i - len(self.entities_definitions)
                        ]
                    else:
                        candidates_symbols = candidates_symbols[:i]
                        sample.candidates = sample.candidates[:i]
                else:
                    gold_candidates_set = set(
                        [wl["relation"]["name"] for wl in sample.triplets]
                    )
                    gold_candidates_indices = [
                        i
                        for i, wc in enumerate(sample.candidates)
                        if wc in gold_candidates_set
                    ]
                    if self.entities_definitions is not None:
                        gold_candidates_indices = [
                            i + len(self.entities_definitions)
                            for i in gold_candidates_indices
                        ]
                        # add entities indices
                        gold_candidates_indices = gold_candidates_indices + list(
                            range(len(self.entities_definitions))
                        )
                    necessary_taken_tokens = sum(
                        map(
                            len,
                            [
                                candidates_encoding_result[i]
                                for i in gold_candidates_indices
                            ],
                        )
                    )
                    acceptable_tokens_from_candidates = (
                        self.model_max_length
                        - 20
                        - len(input_subwords)
                        - necessary_taken_tokens
                    )
                    assert acceptable_tokens_from_candidates > 0
                    i = 0
                    cum_len = 0
                    while (
                        i < len(candidates_encoding_result)
                        and cum_len + len(candidates_encoding_result[i])
                        < acceptable_tokens_from_candidates
                    ):
                        if i not in gold_candidates_indices:
                            cum_len += len(candidates_encoding_result[i])
                        i += 1
                    new_indices = sorted(set(list(range(i)) + gold_candidates_indices))
                    np.random.shuffle(new_indices)
                    candidates_encoding_result = [
                        candidates_encoding_result[i] for i in new_indices
                    ]
                    if self.entities_definitions is not None:
                        sample.candidates = [
                            sample.candidates[i - len(self.entities_definitions)]
                            for i in new_indices
                        ]
                        candidates_symbols = candidates_symbols[
                            : i - len(self.entities_definitions)
                        ]
                    else:
                        candidates_symbols = [
                            candidates_symbols[i] for i in new_indices
                        ]
                        sample.window_candidates = [
                            sample.window_candidates[i] for i in new_indices
                        ]
                if len(sample.candidates) == 0:
                    logger.warning(
                        "Sample {} has no candidates after truncation due to max length".format(
                            sample.sample_id
                        )
                    )
                    continue
            # final input_ids build
            input_ids = self._build_input_ids(
                sentence_input_ids=input_subwords,
                candidates_input_ids=candidates_encoding_result,
            )
            # complete input building (e.g. attention / prediction mask)
            tokenization_output = self._build_tokenizer_essentials(
                input_ids, input_subwords
            )
            # labels creation
            start_labels, end_labels, disambiguation_labels, relation_labels = (
                None,
                None,
                None,
                None,
            )
            if sample.entities is not None and len(sample.entities) > 0:
                (
                    start_labels,
                    end_labels,
                    disambiguation_labels,
                    relation_labels,
                ) = self._build_labels(
                    sample,
                    tokenization_output,
                )
            yield {
                "input_ids": tokenization_output.input_ids,
                "attention_mask": tokenization_output.attention_mask,
                "token_type_ids": tokenization_output.token_type_ids,
                "prediction_mask": tokenization_output.prediction_mask,
                "special_symbols_mask": tokenization_output.special_symbols_mask,
                "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities,
                "sample": sample,
                "start_labels": start_labels,
                "end_labels": end_labels,
                "disambiguation_labels": disambiguation_labels,
                "relation_labels": relation_labels,
                "predictable_candidates": candidates_symbols,
            }
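
    # Each yielded element is a flat dict of per-sample tensors plus
    # bookkeeping fields; padding and stacking into batch tensors happens
    # later in materialize_batches via fields_batcher.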
    def preshuffle_elements(self, dataset_elements: List):
        # This shuffling is done so that, when the sorting function is
        # deterministic given a collection and its order, the whole
        # operation becomes non-deterministic.
        # Basically, the aim is not to build the same batches every time.
        if not self.for_inference:
            dataset_elements = np.random.permutation(dataset_elements)
        sorting_fn = (
            lambda elem: add_noise_to_value(
                sum(len(elem[k]) for k in self.sorting_fields),
                noise_param=self.noise_param,
            )
            if not self.for_inference
            else sum(len(elem[k]) for k in self.sorting_fields)
        )
        dataset_elements = sorted(dataset_elements, key=sorting_fn)
        if self.for_inference:
            return dataset_elements
        ds = list(chunks(dataset_elements, 64))
        np.random.shuffle(ds)
        return flatten(ds)
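
    # Length-sorting minimizes padding; the added noise plus the shuffle of
    # 64-element chunks keeps batch composition varying across epochs while
    # batches stay roughly length-homogeneous.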
    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)
        current_batch = []

        # function that creates a batch from the 'current_batch' list
        def output_batch() -> Dict[str, Any]:
            assert (
                len(
                    set([len(elem["predictable_candidates"]) for elem in current_batch])
                )
                == 1
            ), " ".join(
                map(
                    str, [len(elem["predictable_candidates"]) for elem in current_batch]
                )
            )
            batch_dict = dict()
            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in self.fields_batcher
            }
            # in case you provide fields batchers but the batch
            # has no elements for that field
            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }
            assert len(set([len(v) for v in de_values_by_field.values()])) == 1
            # todo: maybe we should warn the user about possible
            # field filtering due to "None" instances
            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all([fv is not None for fv in fvs])
            }
            for field_name, field_values in de_values_by_field.items():
                field_batch = (
                    self.fields_batcher[field_name](field_values)
                    if self.fields_batcher[field_name] is not None
                    else field_values
                )
                batch_dict[field_name] = field_batch
            return batch_dict

        max_len_discards, min_len_discards = 0, 0
        should_token_batch = self.batch_size is None
        curr_pred_elements = -1
        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1
            # todo: support max length (and min length) as dicts
            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue
            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue
            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)
                future_max_len = max(
                    de_len,
                    max(
                        [
                            sum(len(bde[k]) for k in self.batching_fields)
                            for bde in current_batch
                        ],
                        default=0,
                    ),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
                num_predictable_candidates = len(de["predictable_candidates"])
                if len(current_batch) > 0 and (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or (
                        num_predictable_candidates != curr_pred_elements
                        and curr_pred_elements != -1
                    )
                ):
                    yield output_batch()
                    current_batch = []
            current_batch.append(de)
            curr_pred_elements = len(de["predictable_candidates"])
        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()
        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. These samples will be DISCARDED. If you are running an evaluation, "
                    f"this can INVALIDATE results. This might happen if max length was not set to -1 "
                    f"or if a sample exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since longer than max length {self.max_length}"
                )
        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. These samples will be DISCARDED. If you are running an evaluation, "
                    f"this can INVALIDATE results. This might happen if min length was not set to -1 "
                    f"or if a sample is shorter than the minimum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since shorter than min length {self.min_length}"
                )
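
# A minimal usage sketch (hypothetical helper, not part of the original
# module): RelikREDataset yields fully padded batches itself, so it should be
# wrapped in a DataLoader with batch_size=None to disable automatic batching.
def example_dataloader_usage(dataset: RelikREDataset) -> None:
    from torch.utils.data import DataLoader

    # batch_size=None passes the dataset's own batches through unchanged
    loader = DataLoader(dataset, batch_size=None, num_workers=0)
    first_batch = next(iter(loader))
    print({k: getattr(v, "shape", type(v).__name__) for k, v in first_batch.items()})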
def main():
    special_symbols = [NME_SYMBOL] + [f"R-{i}" for i in range(50)]
    relik_dataset = RelikREDataset(
        "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl",
        materialize_samples=False,
        transformer_model="microsoft/deberta-v3-base",
        special_symbols=special_symbols,
        shuffle_candidates=False,
        flip_candidates=False,
        # placeholder: supply a {relation_name: definition} dict or a TSV path
        relations_definitions={},
        for_inference=True,
    )
    for batch in relik_dataset:
        print(batch)
        break


if __name__ == "__main__":
    main()