|
from random import Random |
|
from typing import ( |
|
Any, |
|
Dict, |
|
List, |
|
Optional, |
|
Union, |
|
) |
|
|
|
from .operators import FieldOperator |
|
from .random_utils import new_random_generator |
|
from .type_utils import isoftype, parse_type_string, to_type_string |
|
from .types import Text |
|
|
|
|
|
class Augmentor(FieldOperator):
    """Base class for stream operators that augment instance values.

    An augmentation is applied either to the task input fields, before they
    are rendered with the template, or to the input passed to the model,
    after the template has been rendered.
    """
|
|
|
|
|
class TaskInputsAugmentor(Augmentor):
    """Augmentor whose target fields live under the ``input_fields`` path."""

    def set_fields(self, fields: List[str]):
        """Select the task input fields this augmentor operates on.

        Every name in ``fields`` is prefixed with ``input_fields/`` and the
        resulting path is mapped onto itself in ``self.field_to_field``.

        Args:
            fields: names of the task input fields to augment.
        """
        self.field_to_field = {
            path: path for path in ("input_fields/" + name for name in fields)
        }
|
|
|
|
|
class TypeDependentAugmentor(TaskInputsAugmentor):
    """Augmentor that only touches values matching ``augmented_type``.

    Values for which ``isoftype(value, self.augmented_type)`` is false are
    returned untouched; all other values are handed to the parent class.
    """

    # The type a value must have (as judged by ``isoftype``) to be augmented.
    augmented_type: object

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        """Augment ``value`` only when it is of the configured type.

        Args:
            value: the field value taken from the instance.
            instance: the instance the value belongs to.

        Returns:
            The result of the parent's ``process_instance_value`` when the
            value is of ``augmented_type``; otherwise ``value`` itself.
        """
        if isoftype(value, self.augmented_type):
            return super().process_instance_value(value=value, instance=instance)
        return value

    @classmethod
    def process_data_after_load(cls, data):
        """Replace a type string under ``"augmented_type"`` with a parsed type.

        ``data`` is modified in place (when the key is present) and returned.
        """
        if "augmented_type" not in data:
            return data
        data["augmented_type"] = parse_type_string(data["augmented_type"])
        return data

    def process_data_before_dump(self, data):
        """Replace the type under ``"augmented_type"`` with its string form.

        ``data`` is modified in place (when the key is present) and returned.
        """
        if "augmented_type" not in data:
            return data
        data["augmented_type"] = to_type_string(data["augmented_type"])
        return data
|
|
|
|
|
class TextAugmentor(TypeDependentAugmentor):
    """Type-dependent augmentor whose ``augmented_type`` is ``Text``."""

    augmented_type = Text
|
|
|
|
|
class NullAugmentor(TaskInputsAugmentor):
    """Augmentor that leaves its input exactly as it is."""

    def process_value(self, value: Any) -> Any:
        """Return ``value`` unchanged."""
        return value
|
|
|
|
|
class AugmentWhitespace(TextAugmentor):
    """Augments the inputs by replacing existing whitespaces with other whitespaces.

    Currently, each run of whitespace is replaced by a single randomly chosen
    whitespace character (newline, tab or space) repeated 1 to 3 times.
    """

    def process_value(self, value: str) -> str:
        """Return ``value`` with every whitespace run randomly rewritten.

        The random generator is created with ``value`` as its sub-seed.

        Args:
            value: the text to augment.

        Returns:
            The text with non-whitespace parts intact and each whitespace run
            replaced as described in the class docstring.
        """
        import re

        # The capturing group keeps the whitespace separators in the result,
        # so tokens alternate between text and whitespace runs.
        tokens = re.split(r"(\s+)", value)
        rng = new_random_generator(sub_seed=value)

        def _augment(token: str) -> str:
            if not token.isspace():
                return token
            # Draw the character first and the repeat count second.
            char = rng.choice(["\n", "\t", " "])
            return char * rng.randint(1, 3)

        return "".join(_augment(token) for token in tokens)
|
|
|
|
|
class AugmentPrefixSuffix(TextAugmentor):
    r"""Augments the input by prepending and appending randomly selected patterns (typically, whitespace).

    Args:
        prefixes (list or dict or None): the potential patterns (typically, whitespace) to select prefix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None if not needed (i.e., only suffixes are needed).

        suffixes (list or dict or None): the potential patterns (typically, whitespace) to select suffix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None if not needed (i.e., only prefixes are needed).

        prefix_len (positive int): the number of patterns drawn and concatenated to form the prefix. May be None only when no prefixes are given.

        suffix_len (positive int): the number of patterns drawn and concatenated to form the suffix. May be None only when no suffixes are given.

        remove_existing_whitespaces (bool): Clean any existing leading and trailing whitespaces. The selected pattern(s) are then prepended and/or appended to the potentially trimmed input.

    Examples:
        To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
        ``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.

        To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
        at 2:1 ratio, employ
        ``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
        which will append ``\n``-s twice as often as ``\t``-s.

    """

    # NOTE(review): the "\\t" and "\\n" keys below are two-character strings
    # (a backslash followed by a letter), not tab/newline characters.
    # Confirm this is intended before changing them.
    prefixes: Optional[Union[List[str], Dict[str, int]]] = {
        " ": 20,
        "\\t": 10,
        "\\n": 40,
        "": 30,
    }
    prefix_len: Optional[int] = 3
    suffixes: Optional[Union[List[str], Dict[str, int]]] = {
        " ": 20,
        "\\t": 10,
        "\\n": 40,
        "": 30,
    }
    suffix_len: Optional[int] = 3
    remove_existing_whitespaces: Optional[bool] = False

    def verify(self):
        """Validate the configured patterns and lengths.

        Raises:
            AssertionError: if both ``prefixes`` and ``suffixes`` are empty or
                None, if either is not None / a list of strings / a dict
                str->int, or if a length is not positive. A length of None is
                accepted only when the matching patterns are empty or None,
                because the length is never used in that case.
        """
        assert (
            self.prefixes or self.suffixes
        ), "At least one of prefixes/suffixes should be not None."
        for arg, arg_name in zip(
            [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
        ):
            assert (
                arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
            ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
        for patterns, length, length_name in (
            (self.prefixes, self.prefix_len, "prefix_len"),
            (self.suffixes, self.suffix_len, "suffix_len"),
        ):
            # An unset length is fine when this side is disabled; comparing
            # None with 0 would otherwise raise TypeError.
            if length is None and not patterns:
                continue
            assert (
                length is not None and length > 0
            ), f"{length_name} must be positive, got {length}"
        super().verify()

    def _calculate_distributions(self, prefs_or_suffs):
        """Turn a pattern specification into ``(patterns, weights)``.

        Args:
            prefs_or_suffs: None, a list of patterns (uniform weights), or a
                dict mapping each pattern to its relative weight.

        Returns:
            ``(None, None)`` for None; otherwise the list of patterns and a
            parallel list of weights normalised to sum to 1. An empty list or
            dict yields empty patterns and weights.

        Raises:
            ZeroDivisionError: if a non-empty dict has weights summing to zero.
        """
        if prefs_or_suffs is None:
            return None, None

        if isinstance(prefs_or_suffs, list):
            patterns = prefs_or_suffs
            if not patterns:
                # Nothing to choose from; avoid dividing by a zero length.
                return patterns, []
            return patterns, [1.0 / len(patterns)] * len(patterns)

        patterns = list(prefs_or_suffs)
        total_weight = sum(prefs_or_suffs.values())
        weights = [float(prefs_or_suffs[p]) / total_weight for p in patterns]
        return patterns, weights

    def prepare(self):
        """Verify the configuration and precompute both pattern distributions."""
        self.verify()

        prefix_patterns, prefix_weights = self._calculate_distributions(
            self.prefixes
        )
        suffix_patterns, suffix_weights = self._calculate_distributions(
            self.suffixes
        )
        self._prefix_pattern_distribution = {
            "length": self.prefix_len,
            "patterns": prefix_patterns,
            "weights": prefix_weights,
        }
        self._suffix_pattern_distribution = {
            "length": self.suffix_len,
            "patterns": suffix_patterns,
            "weights": suffix_weights,
        }
        super().prepare()

    def _get_random_pattern(
        self, pattern_distribution, random_generator: Random
    ) -> str:
        """Draw ``length`` weighted patterns and concatenate them.

        Args:
            pattern_distribution: dict with ``"patterns"``, ``"weights"`` and
                ``"length"`` entries, as built by ``prepare``.
            random_generator: the generator used for the weighted draw.

        Returns:
            The concatenated patterns, or ``""`` when there are no patterns.
        """
        patterns = pattern_distribution["patterns"]
        if not patterns:
            return ""
        drawn = random_generator.choices(
            patterns,
            pattern_distribution["weights"],
            k=pattern_distribution["length"],
        )
        return "".join(drawn)

    def process_value(self, value: Any) -> Any:
        """Return ``str(value)`` wrapped in a random prefix and suffix.

        The random generator is created with ``value`` as its sub-seed; the
        prefix is drawn before the suffix.

        Args:
            value: the value to augment; must not be None.

        Returns:
            prefix + text + suffix, where text is ``str(value)``, stripped of
            leading/trailing whitespace when ``remove_existing_whitespaces``
            is set.
        """
        assert value is not None, "input value should not be None"
        new_value = str(value)
        if self.remove_existing_whitespaces:
            new_value = new_value.strip()
        random_generator = new_random_generator(sub_seed=value)
        prefix = self._get_random_pattern(
            self._prefix_pattern_distribution, random_generator
        )
        suffix = self._get_random_pattern(
            self._suffix_pattern_distribution, random_generator
        )
        return prefix + new_value + suffix
|
|