"""Tweaked AllenNLP dataset reader.""" | |
import logging | |
import re | |
from random import random | |
from typing import Dict, List | |
from allennlp.common.file_utils import cached_path | |
from allennlp.data.dataset_readers.dataset_reader import DatasetReader | |
from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field | |
from allennlp.data.instance import Instance | |
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer | |
from allennlp.data.tokenizers import Token | |
from overrides import overrides | |
from utils.helpers import SEQ_DELIMETERS, START_TOKEN | |
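# ``SEQ_DELIMETERS`` maps the keys 'tokens', 'labels' and 'operations' (all used
# below) to their delimiter strings; ``START_TOKEN`` is the sentinel token that
# gets prepended to every sentence in ``_read``.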

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Seq2LabelsDatasetReader(DatasetReader):
    """
    Reads instances from a pretokenised file where each line is in the following format:
    WORD###TAG [TAB] WORD###TAG [TAB] ..... \n
    and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify
    alternative delimiters in the constructor.

    Parameters
    ----------
    delimeters : ``dict``, optional (default=``SEQ_DELIMETERS``)
        The dictionary with all delimiters.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
        Note that the `output` tags will always correspond to single token IDs based on how they
        are pre-tokenised in the data file.
    max_len : ``int``, optional (default=``None``)
        If set, sentences longer than ``max_len`` tokens are truncated.
    """

    # fix broken sentences, mostly in Lang8
    BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')
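    # (The pattern flags a period that is immediately followed by a letter other
    # than an uppercase "S", e.g. "sentence.Next", the usual signature of such
    # broken sentences.)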

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 delimeters: dict = SEQ_DELIMETERS,
                 skip_correct: bool = False,
                 skip_complex: int = 0,
                 lazy: bool = False,
                 max_len: int = None,
                 test_mode: bool = False,
                 tag_strategy: str = "keep_one",
                 tn_prob: float = 0,
                 tp_prob: float = 0,
                 broken_dot_strategy: str = "keep") -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        self._delimeters = delimeters
        self._max_len = max_len
        self._skip_correct = skip_correct
        self._skip_complex = skip_complex
        self._tag_strategy = tag_strategy
        self._broken_dot_strategy = broken_dot_strategy
        self._test_mode = test_mode
        self._tn_prob = tn_prob
        self._tp_prob = tp_prob

    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        with open(file_path, "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line in data_file:
                line = line.strip("\n")
                # skip blank and broken lines
                if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
                                and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
                    continue

                tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
                                   for pair in line.split(self._delimeters['tokens'])]
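                # For illustration (assuming the "###" label delimiter and the
                # single-space token delimiter from the class docstring), the line
                #   "$START###$KEEP He###$KEEP go###$REPLACE_goes home###$KEEP"
                # gives tokens_and_tags ==
                #   [["$START", "$KEEP"], ["He", "$KEEP"], ["go", "$REPLACE_goes"], ["home", "$KEEP"]]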
                try:
                    tokens = [Token(token) for token, tag in tokens_and_tags]
                    tags = [tag for token, tag in tokens_and_tags]
                except ValueError:
                    # a pair could not be split into (token, tag): treat the line as untagged
                    tokens = [Token(token[0]) for token in tokens_and_tags]
                    tags = None

                if tokens and tokens[0] != Token(START_TOKEN):
                    tokens = [Token(START_TOKEN)] + tokens

                words = [x.text for x in tokens]
                if self._max_len is not None:
                    tokens = tokens[:self._max_len]
                    tags = None if tags is None else tags[:self._max_len]
                instance = self.text_to_instance(tokens, tags, words)
                if instance:
                    yield instance

    def extract_tags(self, tags: List[str]):
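        # Example (with a hypothetical operations delimiter "<OP>" and the default
        # tag_strategy "keep_one"):
        #   tags = ["$KEEP", "$REPLACE_the<OP>$DELETE", "$KEEP"]
        # returns
        #   labels            = ["$KEEP", "$REPLACE_the", "$KEEP"]
        #   detect_tags       = ["CORRECT", "INCORRECT", "CORRECT"]
        #   complex_flag_dict = {1: 1, 2: 0, 3: 0, 4: 0, 5: 0}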
        op_del = self._delimeters['operations']
        labels = [x.split(op_del) for x in tags]

        complex_flag_dict = {}
        # count how many tokens carry more than `idx` candidate operations
        for i in range(5):
            idx = i + 1
            complex_flag_dict[idx] = sum([len(x) > idx for x in labels])

        if self._tag_strategy == "keep_one":
            # keep only the first candidate tag for each token
            labels = [x[0] for x in labels]
        elif self._tag_strategy == "merge_all":
            # consider phrases as single words
            pass
        else:
            raise Exception("Incorrect tag strategy")

        detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
        return labels, detect_tags, complex_flag_dict

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
                         words: List[str] = None) -> Instance:  # type: ignore
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
        """
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        sequence = TextField(tokens, self._token_indexers)
        fields["tokens"] = sequence
        fields["metadata"] = MetadataField({"words": words})
        if tags is not None:
            labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
            if self._skip_complex and complex_flag_dict[self._skip_complex] > 0:
                return None
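            # Down-sampling: when ``skip_correct`` is set and every tag is $KEEP,
            # the sentence is kept with probability ``tn_prob``; otherwise it is
            # kept with probability ``tp_prob``.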
            rnd = random()
            # skip TN
            if self._skip_correct and all(x == "CORRECT" for x in detect_tags):
                if rnd > self._tn_prob:
                    return None
            # skip TP
            else:
                if rnd > self._tp_prob:
                    return None

            fields["labels"] = SequenceLabelField(labels, sequence,
                                                  label_namespace="labels")
            fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
                                                  label_namespace="d_tags")
        return Instance(fields)
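
# Illustrative usage sketch (not part of the reader itself; the file path and the
# probability settings are hypothetical, and the input file is expected to follow
# the WORD###TAG format documented on ``Seq2LabelsDatasetReader``):
#
#     reader = Seq2LabelsDatasetReader(max_len=50,
#                                      skip_correct=True,
#                                      tn_prob=0.1,   # keep ~10% of fully correct sentences
#                                      tp_prob=1.0)   # keep every sentence that has an edit
#     for instance in reader.read("data/train.labeled"):
#         print(instance.fields["tokens"], instance.fields["labels"])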