# metric/augmentors.py
# Uploaded by Elron using huggingface_hub (commit 357b16c, verified; 7.54 kB).
from random import Random
from typing import (
Any,
Dict,
List,
Optional,
Union,
)
from .operators import FieldOperator
from .random_utils import new_random_generator
from .type_utils import isoftype, parse_type_string, to_type_string
from .types import Text
class Augmentor(FieldOperator):
    """Base class for stream operators that augment field values.

    An augmentor perturbs either the task input fields (before they are
    rendered with the template) or the input passed to the model (after the
    template has been rendered).
    """
class TaskInputsAugmentor(Augmentor):
    """Augmentor whose targets are fields located under ``input_fields/``."""

    def set_fields(self, fields: List[str]):
        """Select which task input fields this augmentor operates on.

        Each given name is prefixed with ``input_fields/`` and mapped onto
        itself in ``field_to_field``, so source and destination are the same
        field.
        """
        prefixed_names = ("input_fields/" + name for name in fields)
        self.field_to_field = {path: path for path in prefixed_names}
class TypeDependentAugmentor(TaskInputsAugmentor):
    """Task-input augmentor that only acts on values of ``augmented_type``.

    Values for which ``isoftype(value, self.augmented_type)`` is false are
    returned untouched; matching values are delegated to the parent class.
    """

    augmented_type: object

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        if isoftype(value, self.augmented_type):
            return super().process_instance_value(value=value, instance=instance)
        # Not the type we augment: pass through unchanged.
        return value

    @classmethod
    def process_data_after_load(cls, data):
        """Replace the string form of ``augmented_type`` using ``parse_type_string``."""
        key = "augmented_type"
        if key in data:
            data[key] = parse_type_string(data[key])
        return data

    def process_data_before_dump(self, data):
        """Replace ``augmented_type`` with its ``to_type_string`` representation."""
        key = "augmented_type"
        if key in data:
            data[key] = to_type_string(data[key])
        return data
class TextAugmentor(TypeDependentAugmentor):
    """Type-dependent augmentor that only processes ``Text`` values."""

    augmented_type = Text
class NullAugmentor(TaskInputsAugmentor):
    """Identity augmentor: every value is passed through exactly as received."""

    def process_value(self, value: Any) -> Any:
        # Intentionally a no-op.
        return value
class AugmentWhitespace(TextAugmentor):
    """Augments the inputs by replacing existing whitespaces with other whitespaces.

    Every maximal run of whitespace in the input is replaced by a single
    character drawn at random from newline, tab and space, repeated a random
    number of times between 1 and 3 (inclusive). Non-whitespace text is kept
    as is. The random generator is obtained from
    ``new_random_generator(sub_seed=value)``, i.e. it is derived from the
    input value.
    """

    def process_value(self, value: str) -> str:
        import re

        random_generator = new_random_generator(sub_seed=value)
        pieces = []
        # The capturing group makes re.split keep each whitespace run as its
        # own item, interleaved with the non-whitespace text.
        for token in re.split(r"(\s+)", value):
            if token.isspace():
                # Draw the character first and the repeat count second; this
                # call order determines the generated output for a given seed.
                char = random_generator.choice(["\n", "\t", " "])
                count = random_generator.randint(1, 3)
                pieces.append(char * count)
            else:
                pieces.append(token)
        # join avoids the quadratic cost of repeated string concatenation.
        return "".join(pieces)
class AugmentPrefixSuffix(TextAugmentor):
    r"""Augments the input by prepending and appending randomly selected patterns (typically, whitespace).

    Args:
        prefixes (list or dict or None): the potential patterns (typically, whitespace) to select prefix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None (or empty) if not needed (i.e., only suffixes are needed).
        suffixes (list or dict or None): the potential patterns (typically, whitespace) to select suffix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None (or empty) if not needed (i.e., only prefixes are needed).
        prefix_len (positive int): the number of patterns drawn to build the prefix. Only validated and used when ``prefixes`` is non-empty.
        suffix_len (positive int): the number of patterns drawn to build the suffix. Only validated and used when ``suffixes`` is non-empty.
        remove_existing_whitespaces (bool): Clean any existing leading and trailing whitespaces. The selected pattern(s) are then prepended and/or appended to the potentially trimmed input.

    Examples:
        To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
        ``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.

        To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
        at 2:1 ratio, employ
        ``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
        which will append ``\n``-s twice as often as ``\t``-s.
    """

    # NOTE(review): the default keys "\\t" and "\\n" are two-character strings
    # (a backslash followed by "t"/"n"), not tab/newline characters. This looks
    # unintended given the "whitespace" wording above; confirm before changing,
    # since altering the defaults changes the output of existing pipelines.
    prefixes: Optional[Union[List[str], Dict[str, int]]] = {
        " ": 20,
        "\\t": 10,
        "\\n": 40,
        "": 30,
    }
    prefix_len: Optional[int] = 3
    suffixes: Optional[Union[List[str], Dict[str, int]]] = {
        " ": 20,
        "\\t": 10,
        "\\n": 40,
        "": 30,
    }
    suffix_len: Optional[int] = 3
    remove_existing_whitespaces: Optional[bool] = False

    def verify(self):
        """Validate the pattern specifications and the lengths of the sides in use."""
        assert (
            self.prefixes or self.suffixes
        ), "At least one of prefixes/suffixes should be not None."
        for arg, arg_name in zip(
            [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
        ):
            assert (
                arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
            ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
            # A non-empty weight dict is normalized by its total, so the total
            # must be positive (otherwise prepare() would divide by zero).
            if isinstance(arg, dict) and arg:
                assert (
                    sum(arg.values()) > 0
                ), f"The weights in {arg_name} must sum to a positive number, got {arg}."
        # A length only matters for a side that has patterns to draw from, so an
        # unused side's length is not checked (it may legitimately be None).
        if self.prefixes:
            assert (
                isinstance(self.prefix_len, int) and self.prefix_len > 0
            ), f"prefix_len must be positive, got {self.prefix_len}"
        if self.suffixes:
            assert (
                isinstance(self.suffix_len, int) and self.suffix_len > 0
            ), f"suffix_len must be positive, got {self.suffix_len}"
        super().verify()

    def _calculate_distributions(self, prefs_or_suffs):
        """Turn a list/dict pattern spec into ``(patterns, weights)``.

        A list yields uniform weights; a dict's values are normalized by their
        sum. ``None``, ``[]`` and ``{}`` all yield ``(None, None)``, meaning
        there is nothing to add on that side.
        """
        if not prefs_or_suffs:
            return None, None
        if isinstance(prefs_or_suffs, list):
            # Copy so later mutation of the config cannot desync patterns and weights.
            patterns = list(prefs_or_suffs)
            weights = [1.0 / len(patterns)] * len(patterns)
        else:
            patterns = list(prefs_or_suffs.keys())
            total_weight = sum(prefs_or_suffs.values())
            weights = [float(prefs_or_suffs[p]) / total_weight for p in patterns]
        return patterns, weights

    def prepare(self):
        # Being an artifact, prepare is invoked before verify. Here we need verify before the actions
        self.verify()
        self._prefix_pattern_distribution = {"length": self.prefix_len}
        self._suffix_pattern_distribution = {"length": self.suffix_len}
        (
            self._prefix_pattern_distribution["patterns"],
            self._prefix_pattern_distribution["weights"],
        ) = self._calculate_distributions(self.prefixes)
        (
            self._suffix_pattern_distribution["patterns"],
            self._suffix_pattern_distribution["weights"],
        ) = self._calculate_distributions(self.suffixes)
        super().prepare()

    def _get_random_pattern(
        self, pattern_distribution, random_generator: Random
    ) -> str:
        """Draw ``length`` weighted patterns and concatenate them ("" if there are no patterns)."""
        string_to_add = ""
        if pattern_distribution["patterns"]:
            string_to_add = "".join(
                random_generator.choices(
                    pattern_distribution["patterns"],
                    pattern_distribution["weights"],
                    k=pattern_distribution["length"],
                )
            )
        return string_to_add

    def process_value(self, value: Any) -> Any:
        """Return ``prefix + str(value) + suffix``, optionally stripping ``value`` first.

        The generator is seeded from the original (unstripped) ``value``; the
        prefix is drawn before the suffix.
        """
        assert value is not None, "input value should not be None"
        new_value = str(value)
        if self.remove_existing_whitespaces:
            new_value = new_value.strip()
        random_generator = new_random_generator(sub_seed=value)
        prefix = self._get_random_pattern(
            self._prefix_pattern_distribution, random_generator
        )
        suffix = self._get_random_pattern(
            self._suffix_pattern_distribution, random_generator
        )
        return prefix + new_value + suffix