Upload folder using huggingface_hub
Browse files- README.md +5 -5
- augmentors.py +195 -0
- dataset.py +3 -0
- deprecation_utils.py +37 -0
- dialog_operators.py +151 -4
- formats.py +9 -7
- image_operators.py +58 -9
- inference.py +195 -18
- llm_as_judge.py +12 -4
- loaders.py +10 -2
- metric.py +3 -0
- metrics.py +15 -7
- operators.py +1 -225
- schema.py +26 -13
- serializers.py +142 -0
- settings_utils.py +1 -0
- standard.py +21 -6
- struct_data_operators.py +50 -52
- templates.py +140 -151
- type_utils.py +152 -59
- types.py +36 -0
- utils.py +21 -1
- version.py +1 -1
README.md
CHANGED
@@ -40,11 +40,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f
|
|
40 |
|
41 |
### 🦄 Currently on Unitxt Catalog
|
42 |
|
43 |
-
![NLP Tasks](https://img.shields.io/badge/NLP_tasks-
|
44 |
-
![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-
|
45 |
-
![Templates](https://img.shields.io/badge/Templates-
|
46 |
-
![Formats](https://img.shields.io/badge/Formats-
|
47 |
-
![Metrics](https://img.shields.io/badge/Metrics-
|
48 |
|
49 |
### 🦄 Run Unitxt Exploration Dashboard
|
50 |
|
|
|
40 |
|
41 |
### 🦄 Currently on Unitxt Catalog
|
42 |
|
43 |
+
![NLP Tasks](https://img.shields.io/badge/NLP_tasks-48-blue)
|
44 |
+
![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-537-blue)
|
45 |
+
![Templates](https://img.shields.io/badge/Templates-265-blue)
|
46 |
+
![Formats](https://img.shields.io/badge/Formats-23-blue)
|
47 |
+
![Metrics](https://img.shields.io/badge/Metrics-136-blue)
|
48 |
|
49 |
### 🦄 Run Unitxt Exploration Dashboard
|
50 |
|
augmentors.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from random import Random
|
2 |
+
from typing import (
|
3 |
+
Any,
|
4 |
+
Dict,
|
5 |
+
List,
|
6 |
+
Optional,
|
7 |
+
Union,
|
8 |
+
)
|
9 |
+
|
10 |
+
from .operators import FieldOperator
|
11 |
+
from .random_utils import new_random_generator
|
12 |
+
from .type_utils import isoftype
|
13 |
+
|
14 |
+
|
15 |
+
class Augmentor(FieldOperator):
|
16 |
+
"""A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template."""
|
17 |
+
|
18 |
+
operator: FieldOperator
|
19 |
+
|
20 |
+
def process_value(self, value: Any) -> Any:
|
21 |
+
return self.operator.process_value(value)
|
22 |
+
|
23 |
+
|
24 |
+
class TaskInputsAugmentor(Augmentor):
|
25 |
+
def set_fields(self, fields: List[str]):
|
26 |
+
fields = ["input_fields/" + field for field in fields]
|
27 |
+
self.field_to_field = {field: field for field in fields}
|
28 |
+
|
29 |
+
|
30 |
+
class FinalStateInputsAugmentor(Augmentor):
|
31 |
+
pass
|
32 |
+
|
33 |
+
|
34 |
+
class ModelInputAugmentor(FinalStateInputsAugmentor):
|
35 |
+
field = "source"
|
36 |
+
|
37 |
+
|
38 |
+
class ImagesAugmentor(FinalStateInputsAugmentor):
|
39 |
+
field = "media/images"
|
40 |
+
process_every_value = True
|
41 |
+
|
42 |
+
|
43 |
+
class Identity(FieldOperator):
|
44 |
+
def process_value(self, value: Any) -> Any:
|
45 |
+
return value
|
46 |
+
|
47 |
+
|
48 |
+
class NullAugmentor(Augmentor):
|
49 |
+
"""Does not change the input string."""
|
50 |
+
|
51 |
+
operator = Identity()
|
52 |
+
|
53 |
+
|
54 |
+
class AugmentWhitespace(FieldOperator):
|
55 |
+
"""Augments the inputs by replacing existing whitespaces with other whitespaces.
|
56 |
+
|
57 |
+
Currently, each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
|
58 |
+
"""
|
59 |
+
|
60 |
+
def process_value(self, value: str) -> str:
|
61 |
+
import re
|
62 |
+
|
63 |
+
words = re.split(r"(\s+)", value)
|
64 |
+
new_value = ""
|
65 |
+
|
66 |
+
random_generator = new_random_generator(sub_seed=value)
|
67 |
+
for word in words:
|
68 |
+
if word.isspace():
|
69 |
+
new_value += random_generator.choice(
|
70 |
+
["\n", "\t", " "]
|
71 |
+
) * random_generator.randint(1, 3)
|
72 |
+
else:
|
73 |
+
new_value += word
|
74 |
+
return new_value
|
75 |
+
|
76 |
+
|
77 |
+
class AugmentPrefixSuffix(FieldOperator):
|
78 |
+
r"""Augments the input by prepending and appending randomly selected (typically, whitespace) patterns.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
|
82 |
+
The dictionary version allows the specification relative weights for the different patterns.
|
83 |
+
prefix_len, suffix_len (positive int) : The added prefix or suffix will be of a certain length.
|
84 |
+
remove_existing_whitespaces : Clean any existing leading and trailing whitespaces.
|
85 |
+
The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
|
86 |
+
trimmed input.
|
87 |
+
If only either just prefixes or just suffixes are needed, set the other to None.
|
88 |
+
|
89 |
+
Examples:
|
90 |
+
To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
|
91 |
+
AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
|
92 |
+
To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
|
93 |
+
being preferred over triple '\t', at 2:1 ratio, employ
|
94 |
+
AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
|
95 |
+
which will append '\n'-s twice as often as '\t'-s.
|
96 |
+
|
97 |
+
"""
|
98 |
+
|
99 |
+
prefixes: Optional[Union[List[str], Dict[str, int]]] = {
|
100 |
+
" ": 20,
|
101 |
+
"\\t": 10,
|
102 |
+
"\\n": 40,
|
103 |
+
"": 30,
|
104 |
+
}
|
105 |
+
prefix_len: Optional[int] = 3
|
106 |
+
suffixes: Optional[Union[List[str], Dict[str, int]]] = {
|
107 |
+
" ": 20,
|
108 |
+
"\\t": 10,
|
109 |
+
"\\n": 40,
|
110 |
+
"": 30,
|
111 |
+
}
|
112 |
+
suffix_len: Optional[int] = 3
|
113 |
+
remove_existing_whitespaces: Optional[bool] = False
|
114 |
+
|
115 |
+
def verify(self):
|
116 |
+
assert (
|
117 |
+
self.prefixes or self.suffixes
|
118 |
+
), "At least one of prefixes/suffixes should be not None."
|
119 |
+
for arg, arg_name in zip(
|
120 |
+
[self.prefixes, self.suffixes], ["prefixes", "suffixes"]
|
121 |
+
):
|
122 |
+
assert (
|
123 |
+
arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
|
124 |
+
), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
|
125 |
+
assert (
|
126 |
+
self.prefix_len > 0
|
127 |
+
), f"prefix_len must be positive, got {self.prefix_len}"
|
128 |
+
assert (
|
129 |
+
self.suffix_len > 0
|
130 |
+
), f"suffix_len must be positive, got {self.suffix_len}"
|
131 |
+
super().verify()
|
132 |
+
|
133 |
+
def _calculate_distributions(self, prefs_or_suffs):
|
134 |
+
if prefs_or_suffs is None:
|
135 |
+
return None, None
|
136 |
+
patterns = (
|
137 |
+
prefs_or_suffs
|
138 |
+
if isinstance(prefs_or_suffs, list)
|
139 |
+
else [k for k, v in prefs_or_suffs.items()]
|
140 |
+
)
|
141 |
+
total_weight = (
|
142 |
+
len(patterns)
|
143 |
+
if isinstance(prefs_or_suffs, list)
|
144 |
+
else sum([v for k, v in prefs_or_suffs.items()])
|
145 |
+
)
|
146 |
+
weights = (
|
147 |
+
[1.0 / total_weight] * len(patterns)
|
148 |
+
if isinstance(prefs_or_suffs, list)
|
149 |
+
else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
|
150 |
+
)
|
151 |
+
return patterns, weights
|
152 |
+
|
153 |
+
def prepare(self):
|
154 |
+
# Being an artifact, prepare is invoked before verify. Here we need verify before the actions
|
155 |
+
self.verify()
|
156 |
+
self._prefix_pattern_distribution = {"length": self.prefix_len}
|
157 |
+
self._suffix_pattern_distribution = {"length": self.suffix_len}
|
158 |
+
|
159 |
+
(
|
160 |
+
self._prefix_pattern_distribution["patterns"],
|
161 |
+
self._prefix_pattern_distribution["weights"],
|
162 |
+
) = self._calculate_distributions(self.prefixes)
|
163 |
+
(
|
164 |
+
self._suffix_pattern_distribution["patterns"],
|
165 |
+
self._suffix_pattern_distribution["weights"],
|
166 |
+
) = self._calculate_distributions(self.suffixes)
|
167 |
+
super().prepare()
|
168 |
+
|
169 |
+
def _get_random_pattern(
|
170 |
+
self, pattern_distribution, random_generator: Random
|
171 |
+
) -> str:
|
172 |
+
string_to_add = ""
|
173 |
+
if pattern_distribution["patterns"]:
|
174 |
+
string_to_add = "".join(
|
175 |
+
random_generator.choices(
|
176 |
+
pattern_distribution["patterns"],
|
177 |
+
pattern_distribution["weights"],
|
178 |
+
k=pattern_distribution["length"],
|
179 |
+
)
|
180 |
+
)
|
181 |
+
return string_to_add
|
182 |
+
|
183 |
+
def process_value(self, value: Any) -> Any:
|
184 |
+
assert value is not None, "input value should not be None"
|
185 |
+
new_value = str(value)
|
186 |
+
if self.remove_existing_whitespaces:
|
187 |
+
new_value = new_value.strip()
|
188 |
+
random_generator = new_random_generator(sub_seed=value)
|
189 |
+
prefix = self._get_random_pattern(
|
190 |
+
self._prefix_pattern_distribution, random_generator
|
191 |
+
)
|
192 |
+
suffix = self._get_random_pattern(
|
193 |
+
self._suffix_pattern_distribution, random_generator
|
194 |
+
)
|
195 |
+
return prefix + new_value + suffix
|
dataset.py
CHANGED
@@ -4,6 +4,7 @@ import datasets
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
|
|
7 |
from .benchmark import __file__ as _
|
8 |
from .blocks import __file__ as _
|
9 |
from .card import __file__ as _
|
@@ -43,6 +44,7 @@ from .random_utils import __file__ as _
|
|
43 |
from .recipe import __file__ as _
|
44 |
from .register import __file__ as _
|
45 |
from .schema import __file__ as _
|
|
|
46 |
from .settings_utils import __file__ as _
|
47 |
from .settings_utils import get_constants
|
48 |
from .span_lableing_operators import __file__ as _
|
@@ -58,6 +60,7 @@ from .task import __file__ as _
|
|
58 |
from .templates import __file__ as _
|
59 |
from .text_utils import __file__ as _
|
60 |
from .type_utils import __file__ as _
|
|
|
61 |
from .utils import __file__ as _
|
62 |
from .utils import is_package_installed
|
63 |
from .validate import __file__ as _
|
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
+
from .augmentors import __file__ as _
|
8 |
from .benchmark import __file__ as _
|
9 |
from .blocks import __file__ as _
|
10 |
from .card import __file__ as _
|
|
|
44 |
from .recipe import __file__ as _
|
45 |
from .register import __file__ as _
|
46 |
from .schema import __file__ as _
|
47 |
+
from .serializers import __file__ as _
|
48 |
from .settings_utils import __file__ as _
|
49 |
from .settings_utils import get_constants
|
50 |
from .span_lableing_operators import __file__ as _
|
|
|
60 |
from .templates import __file__ as _
|
61 |
from .text_utils import __file__ as _
|
62 |
from .type_utils import __file__ as _
|
63 |
+
from .types import __file__ as _
|
64 |
from .utils import __file__ as _
|
65 |
from .utils import is_package_installed
|
66 |
from .validate import __file__ as _
|
deprecation_utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import functools
|
2 |
import warnings
|
3 |
|
|
|
4 |
from .settings_utils import get_constants, get_settings
|
5 |
|
6 |
constants = get_constants()
|
@@ -98,3 +99,39 @@ def deprecation(version, alternative=None, msg=None):
|
|
98 |
return depraction_wrapper(func, version, alt_text)
|
99 |
|
100 |
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import functools
|
2 |
import warnings
|
3 |
|
4 |
+
from .error_utils import UnitxtWarning
|
5 |
from .settings_utils import get_constants, get_settings
|
6 |
|
7 |
constants = get_constants()
|
|
|
99 |
return depraction_wrapper(func, version, alt_text)
|
100 |
|
101 |
return decorator
|
102 |
+
|
103 |
+
|
104 |
+
def init_warning(msg=""):
|
105 |
+
# Decorator that raises warning when class is initialized
|
106 |
+
def decorator(initiated_class):
|
107 |
+
UnitxtWarning(msg)
|
108 |
+
return initiated_class
|
109 |
+
|
110 |
+
return decorator
|
111 |
+
|
112 |
+
|
113 |
+
def warn_on_call(warning_type=UserWarning, msg=""):
|
114 |
+
def decorator(obj):
|
115 |
+
if isinstance(obj, type):
|
116 |
+
original_init = obj.__init__
|
117 |
+
|
118 |
+
@functools.wraps(original_init)
|
119 |
+
def new_init(self, *args, **kwargs):
|
120 |
+
warnings.warn(msg, warning_type, stacklevel=2)
|
121 |
+
original_init(self, *args, **kwargs)
|
122 |
+
|
123 |
+
obj.__init__ = new_init
|
124 |
+
return obj
|
125 |
+
|
126 |
+
if callable(obj):
|
127 |
+
|
128 |
+
@functools.wraps(obj)
|
129 |
+
def wrapper(*args, **kwargs):
|
130 |
+
warnings.warn(msg, warning_type, stacklevel=2)
|
131 |
+
return obj(*args, **kwargs)
|
132 |
+
|
133 |
+
return wrapper
|
134 |
+
|
135 |
+
raise TypeError("This decorator can only be applied to classes or functions.")
|
136 |
+
|
137 |
+
return decorator
|
dialog_operators.py
CHANGED
@@ -11,7 +11,6 @@ dialog = [
|
|
11 |
{"user": "kkk", "system": ""},
|
12 |
]
|
13 |
"""
|
14 |
-
|
15 |
from typing import Any, Dict, List, Optional
|
16 |
|
17 |
from .formats import SystemFormat
|
@@ -34,10 +33,11 @@ class SerializeDialog(InstanceFieldOperator):
|
|
34 |
context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
|
35 |
"""
|
36 |
|
37 |
-
format:
|
38 |
last_response_to_field: Optional[str] = None
|
39 |
context_field: Optional[str] = None
|
40 |
context_separator: str = " "
|
|
|
41 |
|
42 |
def standardize_format(self, demo_format):
|
43 |
turn_format = demo_format.replace("{source}", "{user}")
|
@@ -54,10 +54,11 @@ class SerializeDialog(InstanceFieldOperator):
|
|
54 |
return turn_format[: turn_format.index("{user}") + len("{user}")]
|
55 |
|
56 |
def get_turn_format(self, turn_format, step, length):
|
57 |
-
if step == 0:
|
58 |
turn_format = self.slice_first_turn(turn_format)
|
59 |
if step == length - 1:
|
60 |
-
|
|
|
61 |
if self.last_response_to_field is not None:
|
62 |
turn_format = self.slice_last_response(turn_format)
|
63 |
return turn_format
|
@@ -87,3 +88,149 @@ class SerializeDialog(InstanceFieldOperator):
|
|
87 |
if self.last_response_to_field is not None:
|
88 |
instance[self.last_response_to_field] = turn["system"]
|
89 |
return dialog
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
{"user": "kkk", "system": ""},
|
12 |
]
|
13 |
"""
|
|
|
14 |
from typing import Any, Dict, List, Optional
|
15 |
|
16 |
from .formats import SystemFormat
|
|
|
33 |
context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
|
34 |
"""
|
35 |
|
36 |
+
format: SystemFormat = None
|
37 |
last_response_to_field: Optional[str] = None
|
38 |
context_field: Optional[str] = None
|
39 |
context_separator: str = " "
|
40 |
+
slice_first_and_last_turns_format: bool = True
|
41 |
|
42 |
def standardize_format(self, demo_format):
|
43 |
turn_format = demo_format.replace("{source}", "{user}")
|
|
|
54 |
return turn_format[: turn_format.index("{user}") + len("{user}")]
|
55 |
|
56 |
def get_turn_format(self, turn_format, step, length):
|
57 |
+
if step == 0 and self.slice_first_and_last_turns_format:
|
58 |
turn_format = self.slice_first_turn(turn_format)
|
59 |
if step == length - 1:
|
60 |
+
if self.slice_first_and_last_turns_format:
|
61 |
+
turn_format = self.slice_last_turn(turn_format)
|
62 |
if self.last_response_to_field is not None:
|
63 |
turn_format = self.slice_last_response(turn_format)
|
64 |
return turn_format
|
|
|
88 |
if self.last_response_to_field is not None:
|
89 |
instance[self.last_response_to_field] = turn["system"]
|
90 |
return dialog
|
91 |
+
|
92 |
+
|
93 |
+
class SerializeOpenAiFormatDialog(SerializeDialog):
|
94 |
+
"""Serializes dialog data for feeding into a model.
|
95 |
+
|
96 |
+
This class takes structured dialog data in the OpenAi format, and converts it into a text format
|
97 |
+
according to a specified template. It allows for the inclusion or exclusion
|
98 |
+
of system responses and can operate on a per-turn basis or aggregate the entire
|
99 |
+
dialog.
|
100 |
+
|
101 |
+
Attributes:
|
102 |
+
field (str): The field in the input data that contains the dialog.
|
103 |
+
to_field (Optional[str]): The field in the output data where the serialized dialog will be stored.
|
104 |
+
last_user_turn_to_field (Optional[str]): Field to store the last user turn.
|
105 |
+
last_system_turn_to_field (Optional[str]): Field to store the last system turn.
|
106 |
+
context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
|
107 |
+
"""
|
108 |
+
|
109 |
+
is_last_turn_user_only: bool = True
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
def validate_openai_dialog_format(dialog: List[Dict[str, str]]) -> None:
|
113 |
+
"""Validates that the given dialog follows the correct OpenAI format.
|
114 |
+
|
115 |
+
The function checks that:
|
116 |
+
1. The dialog is a list of dictionaries.
|
117 |
+
2. Each dictionary contains the keys 'role' and 'content'.
|
118 |
+
3. The 'role' value is either 'user' or 'assistant'.
|
119 |
+
4. Both 'role' and 'content' values are strings.
|
120 |
+
5. The first 'role' is 'user'
|
121 |
+
|
122 |
+
If the dialog does not conform to the expected format, a descriptive
|
123 |
+
ValueError is raised indicating the issue.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
dialog (List[Dict[str, str]]): The dialog to validate.
|
127 |
+
|
128 |
+
Raises:
|
129 |
+
ValueError: If the dialog does not meet the format requirements.
|
130 |
+
"""
|
131 |
+
if not isinstance(dialog, list):
|
132 |
+
raise ValueError("Dialog must be a list of dictionaries.")
|
133 |
+
|
134 |
+
for i, entry in enumerate(dialog):
|
135 |
+
if not isinstance(entry, dict):
|
136 |
+
raise ValueError(
|
137 |
+
f"Entry {i} is not a dictionary: {entry}. Each entry in the dialog must be a dictionary."
|
138 |
+
)
|
139 |
+
|
140 |
+
if "role" not in entry:
|
141 |
+
raise ValueError(
|
142 |
+
f"Entry {i} is missing the 'role' key: {entry}. Each dictionary must have a 'role' key."
|
143 |
+
)
|
144 |
+
|
145 |
+
if "content" not in entry:
|
146 |
+
raise ValueError(
|
147 |
+
f"Entry {i} is missing the 'content' key: {entry}. Each dictionary must have a 'content' key."
|
148 |
+
)
|
149 |
+
|
150 |
+
if not isinstance(entry["role"], str):
|
151 |
+
raise ValueError(
|
152 |
+
f"Entry {i} has a non-string 'role': {entry['role']}. The 'role' value must be a string."
|
153 |
+
)
|
154 |
+
|
155 |
+
if not isinstance(entry["content"], str):
|
156 |
+
raise ValueError(
|
157 |
+
f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
|
158 |
+
)
|
159 |
+
|
160 |
+
if entry["role"] not in {"user", "assistant"}:
|
161 |
+
raise ValueError(
|
162 |
+
f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
|
163 |
+
)
|
164 |
+
|
165 |
+
first_entry = dialog[0]
|
166 |
+
if first_entry["role"] != "user":
|
167 |
+
raise ValueError(
|
168 |
+
f"First entry role is expected to be 'user' It is {first_entry['role']}."
|
169 |
+
)
|
170 |
+
|
171 |
+
@staticmethod
|
172 |
+
def merge_dialog_entries(dialog: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
173 |
+
"""Merges consecutive dialog entries with the same role.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
dialog (List[Dict[str, str]]): The input dialog list where each dictionary has a 'role' and 'content'.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
List[Dict[str, str]]: A new list where consecutive entries with the same role are merged.
|
180 |
+
"""
|
181 |
+
if len(dialog) == 0:
|
182 |
+
return []
|
183 |
+
|
184 |
+
merged_dialog = [dialog[0]]
|
185 |
+
|
186 |
+
for entry in dialog[1:]:
|
187 |
+
if entry["role"] == merged_dialog[-1]["role"]:
|
188 |
+
merged_dialog[-1]["content"] += " " + entry["content"]
|
189 |
+
else:
|
190 |
+
merged_dialog.append(entry)
|
191 |
+
|
192 |
+
return merged_dialog
|
193 |
+
|
194 |
+
def transform_dialog_to_standard_format(
|
195 |
+
self, dialog: List[Dict[str, str]]
|
196 |
+
) -> List[Dict[str, str]]:
|
197 |
+
"""Transforms a dialog from OpenAI format to a simplified format.
|
198 |
+
|
199 |
+
Each dictionary
|
200 |
+
contains 'user' and 'system' keys with their respective contents. Consecutive entries
|
201 |
+
with the same role are merged. Entries with invalid roles raise an error.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
dialog (List[Dict[str, str]]): The input dialog in OpenAI format.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
List[Dict[str, str]]: The transformed dialog.
|
208 |
+
|
209 |
+
Raises:
|
210 |
+
ValueError: If an invalid role is detected.
|
211 |
+
"""
|
212 |
+
SerializeOpenAiFormatDialog.validate_openai_dialog_format(dialog)
|
213 |
+
merged_dialog = SerializeOpenAiFormatDialog.merge_dialog_entries(dialog)
|
214 |
+
# self.validate_dialog_have_complete_pairs(merged_dialog)
|
215 |
+
|
216 |
+
result = []
|
217 |
+
for i in range(0, len(merged_dialog) - 1, 2):
|
218 |
+
user_entry = merged_dialog[i]
|
219 |
+
system_entry = merged_dialog[i + 1]
|
220 |
+
|
221 |
+
result.append(
|
222 |
+
{"user": user_entry["content"], "system": system_entry["content"]}
|
223 |
+
)
|
224 |
+
if len(merged_dialog) % 2 != 0:
|
225 |
+
user_entry = merged_dialog[-1]
|
226 |
+
result.append({"user": user_entry["content"], "system": ""})
|
227 |
+
|
228 |
+
return result
|
229 |
+
|
230 |
+
def process_instance_value(
|
231 |
+
self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
|
232 |
+
):
|
233 |
+
standard_format_dialog = self.transform_dialog_to_standard_format(
|
234 |
+
structured_dialog
|
235 |
+
)
|
236 |
+
return super().process_instance_value(standard_format_dialog, instance)
|
formats.py
CHANGED
@@ -164,7 +164,7 @@ class SystemFormat(BaseFormat):
|
|
164 |
demos is not None and isoftype(demos, List[Dict[str, Any]])
|
165 |
), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
|
166 |
demo_instances = demos
|
167 |
-
instance.pop(self.demos_field)
|
168 |
|
169 |
demos_string = ""
|
170 |
for demo_instance in demo_instances:
|
@@ -226,14 +226,16 @@ class HFSystemFormat(BaseFormat):
|
|
226 |
"""
|
227 |
|
228 |
model_name: str
|
|
|
229 |
|
230 |
-
def
|
231 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
232 |
-
) -> Dict[str, Any]:
|
233 |
from transformers import AutoTokenizer
|
234 |
|
235 |
-
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
236 |
|
|
|
|
|
|
|
237 |
assert (
|
238 |
"source" in instance
|
239 |
), f"field 'source' is expected to be in the input instance. Received instance: {instance}"
|
@@ -267,7 +269,7 @@ class HFSystemFormat(BaseFormat):
|
|
267 |
demos is not None and isoftype(demos, List[Dict[str, Any]])
|
268 |
), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
|
269 |
demo_instances = demos
|
270 |
-
instance.pop(self.demos_field)
|
271 |
|
272 |
for demo_instance in demo_instances:
|
273 |
messages.extend(
|
@@ -280,7 +282,7 @@ class HFSystemFormat(BaseFormat):
|
|
280 |
]
|
281 |
)
|
282 |
messages.extend([{"role": "user", "content": source}])
|
283 |
-
tokenized_chat = tokenizer.apply_chat_template(
|
284 |
messages, tokenize=False, add_generation_prompt=True
|
285 |
)
|
286 |
|
|
|
164 |
demos is not None and isoftype(demos, List[Dict[str, Any]])
|
165 |
), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
|
166 |
demo_instances = demos
|
167 |
+
# instance.pop(self.demos_field)
|
168 |
|
169 |
demos_string = ""
|
170 |
for demo_instance in demo_instances:
|
|
|
226 |
"""
|
227 |
|
228 |
model_name: str
|
229 |
+
_requirements_list = ["transformers"]
|
230 |
|
231 |
+
def prepare(self):
|
|
|
|
|
232 |
from transformers import AutoTokenizer
|
233 |
|
234 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
235 |
|
236 |
+
def process(
|
237 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
238 |
+
) -> Dict[str, Any]:
|
239 |
assert (
|
240 |
"source" in instance
|
241 |
), f"field 'source' is expected to be in the input instance. Received instance: {instance}"
|
|
|
269 |
demos is not None and isoftype(demos, List[Dict[str, Any]])
|
270 |
), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
|
271 |
demo_instances = demos
|
272 |
+
# instance.pop(self.demos_field)
|
273 |
|
274 |
for demo_instance in demo_instances:
|
275 |
messages.extend(
|
|
|
282 |
]
|
283 |
)
|
284 |
messages.extend([{"role": "user", "content": source}])
|
285 |
+
tokenized_chat = self.tokenizer.apply_chat_template(
|
286 |
messages, tokenize=False, add_generation_prompt=True
|
287 |
)
|
288 |
|
image_operators.py
CHANGED
@@ -1,8 +1,25 @@
|
|
|
|
|
|
1 |
import re
|
|
|
2 |
from typing import Any, Dict
|
3 |
|
|
|
|
|
4 |
from .dict_utils import dict_get
|
5 |
-
from .operators import InstanceFieldOperator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
def extract_images(text, instance):
|
@@ -15,12 +32,44 @@ def extract_images(text, instance):
|
|
15 |
return images
|
16 |
|
17 |
|
18 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def process_instance_value(self, value: Any, instance: Dict[str, Any]):
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
import re
|
4 |
+
from abc import abstractmethod
|
5 |
from typing import Any, Dict
|
6 |
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
from .dict_utils import dict_get
|
10 |
+
from .operators import FieldOperator, InstanceFieldOperator, PackageRequirementsMixin
|
11 |
+
|
12 |
+
|
13 |
+
class PillowMixin(PackageRequirementsMixin):
|
14 |
+
_requirements_list = {"PIL": "pip install pillow"}
|
15 |
+
|
16 |
+
def prepare(self):
|
17 |
+
super().prepare()
|
18 |
+
import PIL
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
self.pil = PIL
|
22 |
+
self.image = Image
|
23 |
|
24 |
|
25 |
def extract_images(text, instance):
|
|
|
32 |
return images
|
33 |
|
34 |
|
35 |
+
class DecodeImage(FieldOperator, PillowMixin):
|
36 |
+
def decode_base64_to_image(self, base64_string):
|
37 |
+
image_data = base64.b64decode(base64_string)
|
38 |
+
return self.image.open(io.BytesIO(image_data))
|
39 |
+
|
40 |
+
def process_value(self, value: Any) -> Any:
|
41 |
+
return {"image": self.decode_base64_to_image(value)}
|
42 |
+
|
43 |
+
|
44 |
+
class ToImage(InstanceFieldOperator):
|
45 |
def process_instance_value(self, value: Any, instance: Dict[str, Any]):
|
46 |
+
return {"image": value}
|
47 |
+
|
48 |
+
|
49 |
+
class ImageFieldOperator(FieldOperator, PillowMixin):
|
50 |
+
@abstractmethod
|
51 |
+
def process_image(self, image):
|
52 |
+
pass
|
53 |
+
|
54 |
+
def process_value(self, value: Any) -> Any:
|
55 |
+
if not isinstance(value, self.image.Image):
|
56 |
+
raise ValueError(f"ImageFieldOperator requires image, got {type(value)}.")
|
57 |
+
return self.process_image(value)
|
58 |
+
|
59 |
+
|
60 |
+
class GrayScale(ImageFieldOperator):
|
61 |
+
def process_image(self, image):
|
62 |
+
# Convert the image to grayscale
|
63 |
+
grayscale_image = image.convert("L")
|
64 |
+
|
65 |
+
# Convert the grayscale image to a NumPy array
|
66 |
+
grayscale_array = np.array(grayscale_image)
|
67 |
+
|
68 |
+
# Add a dummy channel dimension to make it (height, width, 1)
|
69 |
+
grayscale_array = np.expand_dims(grayscale_array, axis=-1)
|
70 |
+
|
71 |
+
# Repeat the channel to have (height, width, 3) if needed for compatibility
|
72 |
+
grayscale_array = np.repeat(grayscale_array, 3, axis=-1)
|
73 |
+
|
74 |
+
# Convert back to a PIL image with 3 channels
|
75 |
+
return self.image.fromarray(grayscale_array)
|
inference.py
CHANGED
@@ -5,12 +5,15 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|
5 |
|
6 |
from tqdm import tqdm
|
7 |
|
8 |
-
from .artifact import Artifact
|
9 |
from .dataclass import InternalField, NonPositionalField
|
10 |
from .deprecation_utils import deprecation
|
11 |
from .image_operators import extract_images
|
12 |
from .logging_utils import get_logger
|
13 |
from .operator import PackageRequirementsMixin
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
class InferenceEngine(abc.ABC, Artifact):
|
@@ -21,9 +24,20 @@ class InferenceEngine(abc.ABC, Artifact):
|
|
21 |
"""Perform inference on the input dataset."""
|
22 |
pass
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def infer(self, dataset) -> str:
|
25 |
"""Verifies instances of a dataset and performs inference."""
|
26 |
[self.verify_instance(instance) for instance in dataset]
|
|
|
|
|
27 |
return self._infer(dataset)
|
28 |
|
29 |
@deprecation(version="2.0.0")
|
@@ -122,7 +136,7 @@ class HFPipelineBasedInferenceEngine(
|
|
122 |
model=self.model_name, trust_remote_code=True, **model_args
|
123 |
)
|
124 |
|
125 |
-
def
|
126 |
if not self.lazy_load:
|
127 |
self._prepare_pipeline()
|
128 |
|
@@ -144,13 +158,17 @@ class HFPipelineBasedInferenceEngine(
|
|
144 |
class MockInferenceEngine(InferenceEngine):
|
145 |
model_name: str
|
146 |
|
147 |
-
def
|
148 |
return
|
149 |
|
150 |
def _infer(self, dataset):
|
151 |
return ["[[10]]" for instance in dataset]
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
154 |
class IbmGenAiInferenceEngineParamsMixin(Artifact):
|
155 |
beam_width: Optional[int] = None
|
156 |
decoding_method: Optional[Literal["greedy", "sample"]] = None
|
@@ -190,6 +208,57 @@ class IbmGenAiInferenceEngineParams(Artifact):
|
|
190 |
typical_p: Optional[float] = None
|
191 |
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
class IbmGenAiInferenceEngine(
|
194 |
InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
|
195 |
):
|
@@ -201,11 +270,12 @@ class IbmGenAiInferenceEngine(
|
|
201 |
data_classification_policy = ["public", "proprietary"]
|
202 |
parameters: Optional[IbmGenAiInferenceEngineParams] = None
|
203 |
|
204 |
-
def
|
205 |
from genai import Client, Credentials
|
206 |
|
207 |
api_key_env_var_name = "GENAI_KEY"
|
208 |
api_key = os.environ.get(api_key_env_var_name)
|
|
|
209 |
assert api_key is not None, (
|
210 |
f"Error while trying to run IbmGenAiInferenceEngine."
|
211 |
f" Please set the environment param '{api_key_env_var_name}'."
|
@@ -242,9 +312,9 @@ class OpenAiInferenceEngineParamsMixin(Artifact):
|
|
242 |
top_p: Optional[float] = None
|
243 |
top_logprobs: Optional[int] = 20
|
244 |
logit_bias: Optional[Dict[str, int]] = None
|
245 |
-
logprobs: Optional[bool] =
|
246 |
n: Optional[int] = None
|
247 |
-
parallel_tool_calls: bool = None
|
248 |
service_tier: Optional[Literal["auto", "default"]] = None
|
249 |
|
250 |
|
@@ -259,9 +329,9 @@ class OpenAiInferenceEngineParams(Artifact):
|
|
259 |
top_p: Optional[float] = None
|
260 |
top_logprobs: Optional[int] = 20
|
261 |
logit_bias: Optional[Dict[str, int]] = None
|
262 |
-
logprobs: Optional[bool] =
|
263 |
n: Optional[int] = None
|
264 |
-
parallel_tool_calls: bool = None
|
265 |
service_tier: Optional[Literal["auto", "default"]] = None
|
266 |
|
267 |
|
@@ -279,7 +349,7 @@ class OpenAiInferenceEngine(
|
|
279 |
data_classification_policy = ["public"]
|
280 |
parameters: Optional[OpenAiInferenceEngineParams] = None
|
281 |
|
282 |
-
def
|
283 |
from openai import OpenAI
|
284 |
|
285 |
api_key_env_var_name = "OPENAI_API_KEY"
|
@@ -293,6 +363,13 @@ class OpenAiInferenceEngine(
|
|
293 |
|
294 |
self._set_inference_parameters()
|
295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
def _infer(self, dataset):
|
297 |
outputs = []
|
298 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
@@ -308,7 +385,7 @@ class OpenAiInferenceEngine(
|
|
308 |
}
|
309 |
],
|
310 |
model=self.model_name,
|
311 |
-
**self.
|
312 |
)
|
313 |
output = response.choices[0].message.content
|
314 |
|
@@ -331,7 +408,7 @@ class OpenAiInferenceEngine(
|
|
331 |
}
|
332 |
],
|
333 |
model=self.model_name,
|
334 |
-
**self.
|
335 |
)
|
336 |
top_logprobs_response = response.choices[0].logprobs.content
|
337 |
output = [
|
@@ -347,6 +424,96 @@ class OpenAiInferenceEngine(
|
|
347 |
return outputs
|
348 |
|
349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
class WMLInferenceEngineParamsMixin(Artifact):
|
351 |
decoding_method: Optional[Literal["greedy", "sample"]] = None
|
352 |
length_penalty: Optional[Dict[str, Union[int, float]]] = None
|
@@ -400,6 +567,7 @@ class WMLInferenceEngine(
|
|
400 |
parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
|
401 |
which defines inference parameters and their values. Deprecated attribute, please
|
402 |
pass respective parameters directly to the WMLInferenceEngine class instead.
|
|
|
403 |
|
404 |
Examples:
|
405 |
from .api import load_dataset
|
@@ -433,7 +601,7 @@ class WMLInferenceEngine(
|
|
433 |
}
|
434 |
data_classification_policy = ["public", "proprietary"]
|
435 |
parameters: Optional[WMLInferenceEngineParams] = None
|
436 |
-
|
437 |
_client: Any = InternalField(default=None, name="WML client")
|
438 |
|
439 |
def verify(self):
|
@@ -490,7 +658,7 @@ class WMLInferenceEngine(
|
|
490 |
client.set.default_project(self.credentials["project_id"])
|
491 |
return client
|
492 |
|
493 |
-
def
|
494 |
self._client = self._initialize_wml_client()
|
495 |
|
496 |
self._set_inference_parameters()
|
@@ -504,10 +672,19 @@ class WMLInferenceEngine(
|
|
504 |
api_client=self._client,
|
505 |
)
|
506 |
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
511 |
|
512 |
|
513 |
class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
|
@@ -541,7 +718,7 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
|
|
541 |
|
542 |
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
543 |
|
544 |
-
def
|
545 |
if not self.lazy_load:
|
546 |
self._prepare_engine()
|
547 |
|
|
|
5 |
|
6 |
from tqdm import tqdm
|
7 |
|
8 |
+
from .artifact import Artifact, fetch_artifact
|
9 |
from .dataclass import InternalField, NonPositionalField
|
10 |
from .deprecation_utils import deprecation
|
11 |
from .image_operators import extract_images
|
12 |
from .logging_utils import get_logger
|
13 |
from .operator import PackageRequirementsMixin
|
14 |
+
from .settings_utils import get_settings
|
15 |
+
|
16 |
+
settings = get_settings()
|
17 |
|
18 |
|
19 |
class InferenceEngine(abc.ABC, Artifact):
|
|
|
24 |
"""Perform inference on the input dataset."""
|
25 |
pass
|
26 |
|
27 |
+
@abc.abstractmethod
|
28 |
+
def prepare_engine(self):
|
29 |
+
"""Perform inference on the input dataset."""
|
30 |
+
pass
|
31 |
+
|
32 |
+
def prepare(self):
|
33 |
+
if not settings.mock_inference_mode:
|
34 |
+
self.prepare_engine()
|
35 |
+
|
36 |
def infer(self, dataset) -> str:
|
37 |
"""Verifies instances of a dataset and performs inference."""
|
38 |
[self.verify_instance(instance) for instance in dataset]
|
39 |
+
if settings.mock_inference_mode:
|
40 |
+
return [instance["source"] for instance in dataset]
|
41 |
return self._infer(dataset)
|
42 |
|
43 |
@deprecation(version="2.0.0")
|
|
|
136 |
model=self.model_name, trust_remote_code=True, **model_args
|
137 |
)
|
138 |
|
139 |
+
def prepare_engine(self):
|
140 |
if not self.lazy_load:
|
141 |
self._prepare_pipeline()
|
142 |
|
|
|
158 |
class MockInferenceEngine(InferenceEngine):
|
159 |
model_name: str
|
160 |
|
161 |
+
def prepare_engine(self):
|
162 |
return
|
163 |
|
164 |
def _infer(self, dataset):
|
165 |
return ["[[10]]" for instance in dataset]
|
166 |
|
167 |
|
168 |
+
class MockModeMixin(Artifact):
|
169 |
+
mock_mode: bool = False
|
170 |
+
|
171 |
+
|
172 |
class IbmGenAiInferenceEngineParamsMixin(Artifact):
|
173 |
beam_width: Optional[int] = None
|
174 |
decoding_method: Optional[Literal["greedy", "sample"]] = None
|
|
|
208 |
typical_p: Optional[float] = None
|
209 |
|
210 |
|
211 |
+
class GenericInferenceEngine(InferenceEngine):
|
212 |
+
default: Optional[str] = None
|
213 |
+
|
214 |
+
def prepare_engine(self):
|
215 |
+
if "UNITXT_INFERENCE_ENGINE" in os.environ:
|
216 |
+
engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"]
|
217 |
+
else:
|
218 |
+
assert self.default is not None, (
|
219 |
+
"GenericInferenceEngine could not be initialized"
|
220 |
+
'\nThis is since both the "UNITXT_INFERENCE_ENGINE" environmental variable is not set and no default engine was not inputted.'
|
221 |
+
"\nFor example, you can fix it by setting"
|
222 |
+
"\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
|
223 |
+
"\nto your ~/.bashrc"
|
224 |
+
"\nor passing a similar required engine in the default argument"
|
225 |
+
)
|
226 |
+
engine_reference = self.default
|
227 |
+
self.engine, _ = fetch_artifact(engine_reference)
|
228 |
+
|
229 |
+
def _infer(self, dataset):
|
230 |
+
return self.engine._infer(dataset)
|
231 |
+
|
232 |
+
|
233 |
+
class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
234 |
+
label: str = "ollama"
|
235 |
+
model_name: str
|
236 |
+
_requirements_list = {
|
237 |
+
"ollama": "Install ollama package using 'pip install --upgrade ollama"
|
238 |
+
}
|
239 |
+
data_classification_policy = ["public", "proprietary"]
|
240 |
+
|
241 |
+
def prepare_engine(self):
|
242 |
+
pass
|
243 |
+
|
244 |
+
def _infer(self, dataset):
|
245 |
+
import ollama
|
246 |
+
|
247 |
+
result = [
|
248 |
+
ollama.chat(
|
249 |
+
model="llama2",
|
250 |
+
messages=[
|
251 |
+
{
|
252 |
+
"role": "user",
|
253 |
+
"content": instance["source"],
|
254 |
+
},
|
255 |
+
],
|
256 |
+
)
|
257 |
+
for instance in dataset
|
258 |
+
]
|
259 |
+
return [element["message"]["content"] for element in result]
|
260 |
+
|
261 |
+
|
262 |
class IbmGenAiInferenceEngine(
|
263 |
InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
|
264 |
):
|
|
|
270 |
data_classification_policy = ["public", "proprietary"]
|
271 |
parameters: Optional[IbmGenAiInferenceEngineParams] = None
|
272 |
|
273 |
+
def prepare_engine(self):
|
274 |
from genai import Client, Credentials
|
275 |
|
276 |
api_key_env_var_name = "GENAI_KEY"
|
277 |
api_key = os.environ.get(api_key_env_var_name)
|
278 |
+
|
279 |
assert api_key is not None, (
|
280 |
f"Error while trying to run IbmGenAiInferenceEngine."
|
281 |
f" Please set the environment param '{api_key_env_var_name}'."
|
|
|
312 |
top_p: Optional[float] = None
|
313 |
top_logprobs: Optional[int] = 20
|
314 |
logit_bias: Optional[Dict[str, int]] = None
|
315 |
+
logprobs: Optional[bool] = True
|
316 |
n: Optional[int] = None
|
317 |
+
parallel_tool_calls: Optional[bool] = None
|
318 |
service_tier: Optional[Literal["auto", "default"]] = None
|
319 |
|
320 |
|
|
|
329 |
top_p: Optional[float] = None
|
330 |
top_logprobs: Optional[int] = 20
|
331 |
logit_bias: Optional[Dict[str, int]] = None
|
332 |
+
logprobs: Optional[bool] = True
|
333 |
n: Optional[int] = None
|
334 |
+
parallel_tool_calls: Optional[bool] = None
|
335 |
service_tier: Optional[Literal["auto", "default"]] = None
|
336 |
|
337 |
|
|
|
349 |
data_classification_policy = ["public"]
|
350 |
parameters: Optional[OpenAiInferenceEngineParams] = None
|
351 |
|
352 |
+
def prepare_engine(self):
|
353 |
from openai import OpenAI
|
354 |
|
355 |
api_key_env_var_name = "OPENAI_API_KEY"
|
|
|
363 |
|
364 |
self._set_inference_parameters()
|
365 |
|
366 |
+
def _get_completion_kwargs(self):
|
367 |
+
return {
|
368 |
+
k: v
|
369 |
+
for k, v in self.to_dict([OpenAiInferenceEngineParamsMixin]).items()
|
370 |
+
if v is not None
|
371 |
+
}
|
372 |
+
|
373 |
def _infer(self, dataset):
|
374 |
outputs = []
|
375 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
|
|
385 |
}
|
386 |
],
|
387 |
model=self.model_name,
|
388 |
+
**self._get_completion_kwargs(),
|
389 |
)
|
390 |
output = response.choices[0].message.content
|
391 |
|
|
|
408 |
}
|
409 |
],
|
410 |
model=self.model_name,
|
411 |
+
**self._get_completion_kwargs(),
|
412 |
)
|
413 |
top_logprobs_response = response.choices[0].logprobs.content
|
414 |
output = [
|
|
|
424 |
return outputs
|
425 |
|
426 |
|
427 |
+
class TogetherAiInferenceEngineParamsMixin(Artifact):
|
428 |
+
max_tokens: Optional[int] = None
|
429 |
+
stop: Optional[List[str]] = None
|
430 |
+
temperature: Optional[float] = None
|
431 |
+
top_p: Optional[float] = None
|
432 |
+
top_k: Optional[int] = None
|
433 |
+
repetition_penalty: Optional[float] = None
|
434 |
+
logprobs: Optional[int] = None
|
435 |
+
echo: Optional[bool] = None
|
436 |
+
n: Optional[int] = None
|
437 |
+
min_p: Optional[float] = None
|
438 |
+
presence_penalty: Optional[float] = None
|
439 |
+
frequency_penalty: Optional[float] = None
|
440 |
+
|
441 |
+
|
442 |
+
class TogetherAiInferenceEngine(
|
443 |
+
InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin
|
444 |
+
):
|
445 |
+
label: str = "together"
|
446 |
+
model_name: str
|
447 |
+
_requirements_list = {
|
448 |
+
"together": "Install together package using 'pip install --upgrade together"
|
449 |
+
}
|
450 |
+
data_classification_policy = ["public"]
|
451 |
+
parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None
|
452 |
+
|
453 |
+
def prepare_engine(self):
|
454 |
+
from together import Together
|
455 |
+
from together.types.models import ModelType
|
456 |
+
|
457 |
+
api_key_env_var_name = "TOGETHER_API_KEY"
|
458 |
+
api_key = os.environ.get(api_key_env_var_name)
|
459 |
+
assert api_key is not None, (
|
460 |
+
f"Error while trying to run TogetherAiInferenceEngine."
|
461 |
+
f" Please set the environment param '{api_key_env_var_name}'."
|
462 |
+
)
|
463 |
+
self.client = Together(api_key=api_key)
|
464 |
+
self._set_inference_parameters()
|
465 |
+
|
466 |
+
# Get model type from Together List Models API
|
467 |
+
together_models = self.client.models.list()
|
468 |
+
together_model_id_to_type = {
|
469 |
+
together_model.id: together_model.type for together_model in together_models
|
470 |
+
}
|
471 |
+
model_type = together_model_id_to_type.get(self.model_name)
|
472 |
+
assert model_type is not None, (
|
473 |
+
f"Could not find model {self.model_name} " "in Together AI model list"
|
474 |
+
)
|
475 |
+
assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], (
|
476 |
+
f"Together AI model type {model_type} is not supported; "
|
477 |
+
"supported types are 'chat', 'language' and 'code'."
|
478 |
+
)
|
479 |
+
self.model_type = model_type
|
480 |
+
|
481 |
+
def _get_infer_kwargs(self):
|
482 |
+
return {
|
483 |
+
k: v
|
484 |
+
for k, v in self.to_dict([TogetherAiInferenceEngineParamsMixin]).items()
|
485 |
+
if v is not None
|
486 |
+
}
|
487 |
+
|
488 |
+
def _infer_chat(self, prompt: str) -> str:
|
489 |
+
response = self.client.chat.completions.create(
|
490 |
+
model=self.model_name,
|
491 |
+
messages=[{"role": "user", "content": prompt}],
|
492 |
+
**self._get_infer_kwargs(),
|
493 |
+
)
|
494 |
+
return response.choices[0].message.content
|
495 |
+
|
496 |
+
def _infer_text(self, prompt: str) -> str:
|
497 |
+
response = self.client.completions.create(
|
498 |
+
model=self.model_name,
|
499 |
+
prompt=prompt,
|
500 |
+
**self._get_infer_kwargs(),
|
501 |
+
)
|
502 |
+
return response.choices[0].text
|
503 |
+
|
504 |
+
def _infer(self, dataset):
|
505 |
+
from together.types.models import ModelType
|
506 |
+
|
507 |
+
outputs = []
|
508 |
+
if self.model_type == ModelType.CHAT:
|
509 |
+
for instance in tqdm(dataset, desc="Inferring with Together AI Chat API"):
|
510 |
+
outputs.append(self._infer_chat(instance["source"]))
|
511 |
+
else:
|
512 |
+
for instance in tqdm(dataset, desc="Inferring with Together AI Text API"):
|
513 |
+
outputs.append(self._infer_text(instance["source"]))
|
514 |
+
return outputs
|
515 |
+
|
516 |
+
|
517 |
class WMLInferenceEngineParamsMixin(Artifact):
|
518 |
decoding_method: Optional[Literal["greedy", "sample"]] = None
|
519 |
length_penalty: Optional[Dict[str, Union[int, float]]] = None
|
|
|
567 |
parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
|
568 |
which defines inference parameters and their values. Deprecated attribute, please
|
569 |
pass respective parameters directly to the WMLInferenceEngine class instead.
|
570 |
+
concurrency_limit (int): number of requests that will be sent in parallel, max is 10.
|
571 |
|
572 |
Examples:
|
573 |
from .api import load_dataset
|
|
|
601 |
}
|
602 |
data_classification_policy = ["public", "proprietary"]
|
603 |
parameters: Optional[WMLInferenceEngineParams] = None
|
604 |
+
concurrency_limit: int = 10
|
605 |
_client: Any = InternalField(default=None, name="WML client")
|
606 |
|
607 |
def verify(self):
|
|
|
658 |
client.set.default_project(self.credentials["project_id"])
|
659 |
return client
|
660 |
|
661 |
+
def prepare_engine(self):
|
662 |
self._client = self._initialize_wml_client()
|
663 |
|
664 |
self._set_inference_parameters()
|
|
|
672 |
api_client=self._client,
|
673 |
)
|
674 |
|
675 |
+
# the class was previously used with a dataset that is a single instance
|
676 |
+
dataset = dataset if isinstance(dataset, list) else [dataset]
|
677 |
+
|
678 |
+
result = [
|
679 |
+
model.generate_text(
|
680 |
+
prompt=instance["source"],
|
681 |
+
params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
|
682 |
+
)
|
683 |
+
for instance in dataset
|
684 |
+
]
|
685 |
+
|
686 |
+
# the class was previously used with a dataset that is a single instance
|
687 |
+
return result[0] if not isinstance(dataset, list) else result
|
688 |
|
689 |
|
690 |
class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
|
|
|
718 |
|
719 |
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
720 |
|
721 |
+
def prepare_engine(self):
|
722 |
if not self.lazy_load:
|
723 |
self._prepare_engine()
|
724 |
|
llm_as_judge.py
CHANGED
@@ -144,13 +144,13 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
144 |
)
|
145 |
|
146 |
if isinstance(self.inference_model, OpenAiInferenceEngine):
|
147 |
-
if self.format:
|
148 |
raise ValueError(
|
149 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
150 |
"not support formatting. Please remove the format definition from the recipe"
|
151 |
" (OpenAi Chat API take care of the formatting automatically)."
|
152 |
)
|
153 |
-
if self.system_prompt:
|
154 |
raise ValueError(
|
155 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
156 |
"not support system prompt. Please remove the system_prompt definition from the recipe"
|
@@ -181,9 +181,17 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
181 |
results = []
|
182 |
for instance in outputs:
|
183 |
if self.task == "pairwise_comparative_rating.single_turn":
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
)
|
|
|
|
|
187 |
if is_model_b_the_baseline:
|
188 |
model_a_preference_score = instance["prediction"]
|
189 |
else:
|
|
|
144 |
)
|
145 |
|
146 |
if isinstance(self.inference_model, OpenAiInferenceEngine):
|
147 |
+
if self.format and type(self.format) is not SystemFormat:
|
148 |
raise ValueError(
|
149 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
150 |
"not support formatting. Please remove the format definition from the recipe"
|
151 |
" (OpenAi Chat API take care of the formatting automatically)."
|
152 |
)
|
153 |
+
if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
|
154 |
raise ValueError(
|
155 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
156 |
"not support system prompt. Please remove the system_prompt definition from the recipe"
|
|
|
181 |
results = []
|
182 |
for instance in outputs:
|
183 |
if self.task == "pairwise_comparative_rating.single_turn":
|
184 |
+
import json
|
185 |
+
|
186 |
+
# seems like the task data sometimes comes as a string, not a dict
|
187 |
+
# this fixes it
|
188 |
+
task_data = (
|
189 |
+
json.loads(instance["task_data"])
|
190 |
+
if isinstance(instance["task_data"], str)
|
191 |
+
else instance["task_data"]
|
192 |
)
|
193 |
+
|
194 |
+
is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
|
195 |
if is_model_b_the_baseline:
|
196 |
model_a_preference_score = instance["prediction"]
|
197 |
else:
|
loaders.py
CHANGED
@@ -151,6 +151,7 @@ class LoadHF(Loader):
|
|
151 |
data_dir: Optional directory to store downloaded data.
|
152 |
split: Optional specification of which split to load.
|
153 |
data_files: Optional specification of particular data files to load.
|
|
|
154 |
streaming: Bool indicating if streaming should be used.
|
155 |
filtering_lambda: A lambda function for filtering the data after loading.
|
156 |
num_proc: Optional integer to specify the number of processes to use for parallel dataset loading.
|
@@ -170,6 +171,7 @@ class LoadHF(Loader):
|
|
170 |
data_files: Optional[
|
171 |
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
172 |
] = None
|
|
|
173 |
streaming: bool = True
|
174 |
filtering_lambda: Optional[str] = None
|
175 |
num_proc: Optional[int] = None
|
@@ -199,6 +201,7 @@ class LoadHF(Loader):
|
|
199 |
name=self.name,
|
200 |
data_dir=self.data_dir,
|
201 |
data_files=self.data_files,
|
|
|
202 |
streaming=self.streaming,
|
203 |
cache_dir=None if self.streaming else dir_to_be_deleted,
|
204 |
split=self.split,
|
@@ -488,6 +491,7 @@ class LoadFromIBMCloud(Loader):
|
|
488 |
bucket_name: Name of the S3 bucket from which to load data.
|
489 |
data_dir: Optional directory path within the bucket.
|
490 |
data_files: Union type allowing either a list of file names or a mapping of splits to file names.
|
|
|
491 |
caching: Bool indicating if caching is enabled to avoid re-downloading data.
|
492 |
|
493 |
Example:
|
@@ -511,6 +515,7 @@ class LoadFromIBMCloud(Loader):
|
|
511 |
data_dir: str = None
|
512 |
|
513 |
data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
|
|
514 |
caching: bool = True
|
515 |
data_classification_policy = ["proprietary"]
|
516 |
|
@@ -636,10 +641,13 @@ class LoadFromIBMCloud(Loader):
|
|
636 |
)
|
637 |
|
638 |
if isinstance(self.data_files, list):
|
639 |
-
dataset = hf_load_dataset(local_dir, streaming=False)
|
640 |
else:
|
641 |
dataset = hf_load_dataset(
|
642 |
-
local_dir,
|
|
|
|
|
|
|
643 |
)
|
644 |
|
645 |
return MultiStream.from_iterables(dataset)
|
|
|
151 |
data_dir: Optional directory to store downloaded data.
|
152 |
split: Optional specification of which split to load.
|
153 |
data_files: Optional specification of particular data files to load.
|
154 |
+
revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
|
155 |
streaming: Bool indicating if streaming should be used.
|
156 |
filtering_lambda: A lambda function for filtering the data after loading.
|
157 |
num_proc: Optional integer to specify the number of processes to use for parallel dataset loading.
|
|
|
171 |
data_files: Optional[
|
172 |
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
173 |
] = None
|
174 |
+
revision: Optional[str] = None
|
175 |
streaming: bool = True
|
176 |
filtering_lambda: Optional[str] = None
|
177 |
num_proc: Optional[int] = None
|
|
|
201 |
name=self.name,
|
202 |
data_dir=self.data_dir,
|
203 |
data_files=self.data_files,
|
204 |
+
revision=self.revision,
|
205 |
streaming=self.streaming,
|
206 |
cache_dir=None if self.streaming else dir_to_be_deleted,
|
207 |
split=self.split,
|
|
|
491 |
bucket_name: Name of the S3 bucket from which to load data.
|
492 |
data_dir: Optional directory path within the bucket.
|
493 |
data_files: Union type allowing either a list of file names or a mapping of splits to file names.
|
494 |
+
data_field: The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
|
495 |
caching: Bool indicating if caching is enabled to avoid re-downloading data.
|
496 |
|
497 |
Example:
|
|
|
515 |
data_dir: str = None
|
516 |
|
517 |
data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
518 |
+
data_field: str = None
|
519 |
caching: bool = True
|
520 |
data_classification_policy = ["proprietary"]
|
521 |
|
|
|
641 |
)
|
642 |
|
643 |
if isinstance(self.data_files, list):
|
644 |
+
dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
|
645 |
else:
|
646 |
dataset = hf_load_dataset(
|
647 |
+
local_dir,
|
648 |
+
streaming=False,
|
649 |
+
data_files=self.data_files,
|
650 |
+
field=self.data_field,
|
651 |
)
|
652 |
|
653 |
return MultiStream.from_iterables(dataset)
|
metric.py
CHANGED
@@ -4,6 +4,7 @@ import evaluate
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
|
|
7 |
from .benchmark import __file__ as _
|
8 |
from .blocks import __file__ as _
|
9 |
from .card import __file__ as _
|
@@ -42,6 +43,7 @@ from .random_utils import __file__ as _
|
|
42 |
from .recipe import __file__ as _
|
43 |
from .register import __file__ as _
|
44 |
from .schema import __file__ as _
|
|
|
45 |
from .settings_utils import __file__ as _
|
46 |
from .settings_utils import get_constants
|
47 |
from .span_lableing_operators import __file__ as _
|
@@ -57,6 +59,7 @@ from .task import __file__ as _
|
|
57 |
from .templates import __file__ as _
|
58 |
from .text_utils import __file__ as _
|
59 |
from .type_utils import __file__ as _
|
|
|
60 |
from .utils import __file__ as _
|
61 |
from .utils import is_package_installed
|
62 |
from .validate import __file__ as _
|
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
+
from .augmentors import __file__ as _
|
8 |
from .benchmark import __file__ as _
|
9 |
from .blocks import __file__ as _
|
10 |
from .card import __file__ as _
|
|
|
43 |
from .recipe import __file__ as _
|
44 |
from .register import __file__ as _
|
45 |
from .schema import __file__ as _
|
46 |
+
from .serializers import __file__ as _
|
47 |
from .settings_utils import __file__ as _
|
48 |
from .settings_utils import get_constants
|
49 |
from .span_lableing_operators import __file__ as _
|
|
|
59 |
from .templates import __file__ as _
|
60 |
from .text_utils import __file__ as _
|
61 |
from .type_utils import __file__ as _
|
62 |
+
from .types import __file__ as _
|
63 |
from .utils import __file__ as _
|
64 |
from .utils import is_package_installed
|
65 |
from .validate import __file__ as _
|
metrics.py
CHANGED
@@ -421,7 +421,7 @@ class MetricWithConfidenceInterval(Metric):
|
|
421 |
full_score_name = ci_score_prefix + score_name
|
422 |
result[f"{full_score_name}_ci_low"] = ci.low
|
423 |
result[f"{full_score_name}_ci_high"] = ci.high
|
424 |
-
if score_name == self.main_score:
|
425 |
result["score_ci_low"] = ci.low
|
426 |
result["score_ci_high"] = ci.high
|
427 |
return result
|
@@ -1183,7 +1183,11 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1183 |
return instances
|
1184 |
|
1185 |
def get_group_scores(
|
1186 |
-
self,
|
|
|
|
|
|
|
|
|
1187 |
):
|
1188 |
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
1189 |
|
@@ -1193,6 +1197,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1193 |
group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
|
1194 |
or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
|
1195 |
callable function returns a single score for the group
|
|
|
|
|
1196 |
|
1197 |
Returns:
|
1198 |
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
@@ -1222,7 +1228,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1222 |
)
|
1223 |
for score_name in score_names:
|
1224 |
group_to_instance_scores[group_key][score_name][subgroup_type].append(
|
1225 |
-
instance["score"]["instance"][
|
|
|
|
|
1226 |
)
|
1227 |
|
1228 |
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
@@ -1230,7 +1238,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1230 |
{
|
1231 |
"score": {
|
1232 |
"instance": {
|
1233 |
-
|
|
|
1234 |
score_dict
|
1235 |
if uses_subgroups
|
1236 |
else score_dict[default_subgroup_name]
|
@@ -1268,7 +1277,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1268 |
group_aggregation_func=group_aggregation_func,
|
1269 |
):
|
1270 |
group_scores = self.get_group_scores(
|
1271 |
-
instances, [field_name], group_aggregation_func
|
1272 |
)
|
1273 |
return nan_mean(
|
1274 |
[group["score"]["instance"][field_name] for group in group_scores]
|
@@ -4565,8 +4574,7 @@ class NormalizedSacrebleu(HuggingfaceMetric):
|
|
4565 |
scaled_fields = ["sacrebleu", "precisions"]
|
4566 |
hf_additional_input_fields_pass_one_value = ["tokenize"]
|
4567 |
_requirements_list = {
|
4568 |
-
"
|
4569 |
-
"mecab_ko_dic": KO_ERROR_MESSAGE,
|
4570 |
}
|
4571 |
|
4572 |
|
|
|
421 |
full_score_name = ci_score_prefix + score_name
|
422 |
result[f"{full_score_name}_ci_low"] = ci.low
|
423 |
result[f"{full_score_name}_ci_high"] = ci.high
|
424 |
+
if score_name == self.score_prefix + self.main_score:
|
425 |
result["score_ci_low"] = ci.low
|
426 |
result["score_ci_high"] = ci.high
|
427 |
return result
|
|
|
1183 |
return instances
|
1184 |
|
1185 |
def get_group_scores(
|
1186 |
+
self,
|
1187 |
+
instances: List[dict],
|
1188 |
+
score_names: List[str],
|
1189 |
+
group_aggregation_func,
|
1190 |
+
prepend_score_prefix: bool = True,
|
1191 |
):
|
1192 |
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
1193 |
|
|
|
1197 |
group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
|
1198 |
or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
|
1199 |
callable function returns a single score for the group
|
1200 |
+
prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
|
1201 |
+
if down the stream such a prepending is expected.
|
1202 |
|
1203 |
Returns:
|
1204 |
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
|
|
1228 |
)
|
1229 |
for score_name in score_names:
|
1230 |
group_to_instance_scores[group_key][score_name][subgroup_type].append(
|
1231 |
+
instance["score"]["instance"][
|
1232 |
+
(self.score_prefix if prepend_score_prefix else "") + score_name
|
1233 |
+
]
|
1234 |
)
|
1235 |
|
1236 |
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
|
|
1238 |
{
|
1239 |
"score": {
|
1240 |
"instance": {
|
1241 |
+
(self.score_prefix if prepend_score_prefix else "")
|
1242 |
+
+ score_name: group_aggregation_func(
|
1243 |
score_dict
|
1244 |
if uses_subgroups
|
1245 |
else score_dict[default_subgroup_name]
|
|
|
1277 |
group_aggregation_func=group_aggregation_func,
|
1278 |
):
|
1279 |
group_scores = self.get_group_scores(
|
1280 |
+
instances, [field_name], group_aggregation_func, False
|
1281 |
)
|
1282 |
return nan_mean(
|
1283 |
[group["score"]["instance"][field_name] for group in group_scores]
|
|
|
4574 |
scaled_fields = ["sacrebleu", "precisions"]
|
4575 |
hf_additional_input_fields_pass_one_value = ["tokenize"]
|
4576 |
_requirements_list = {
|
4577 |
+
"sacrebleu": "Additional dependencies required. To install them, run: `pip install sacrebleu`."
|
|
|
4578 |
}
|
4579 |
|
4580 |
|
operators.py
CHANGED
@@ -531,230 +531,6 @@ class AddConstant(FieldOperator):
|
|
531 |
return self.add + value
|
532 |
|
533 |
|
534 |
-
class Augmentor(InstanceOperator):
|
535 |
-
"""A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
|
536 |
-
|
537 |
-
Args:
|
538 |
-
augment_model_input: Whether to augment the input to the model.
|
539 |
-
augment_task_input: Whether to augment the task input fields. The specific fields are defined in the Task operator.
|
540 |
-
|
541 |
-
"""
|
542 |
-
|
543 |
-
augment_task_input: bool = False
|
544 |
-
augment_model_input: bool = False
|
545 |
-
|
546 |
-
def verify(self):
|
547 |
-
assert not (
|
548 |
-
self.augment_task_input and self.augment_model_input
|
549 |
-
), "Augmentor must set either 'augment_task_input' and 'augment_model_input' but not both"
|
550 |
-
assert (
|
551 |
-
self.augment_task_input or self.augment_model_input
|
552 |
-
), "Augmentor must set either 'augment_task_input' or 'augment_model_input'"
|
553 |
-
|
554 |
-
super().verify()
|
555 |
-
|
556 |
-
@abstractmethod
|
557 |
-
def process_value(self, value: Any) -> Any:
|
558 |
-
pass
|
559 |
-
|
560 |
-
def prepare(self):
|
561 |
-
pass
|
562 |
-
|
563 |
-
def set_task_input_fields(self, task_input_fields: List[str]):
|
564 |
-
self._task_input_fields = [
|
565 |
-
"input_fields/" + task_input_field for task_input_field in task_input_fields
|
566 |
-
]
|
567 |
-
|
568 |
-
def process(
|
569 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
570 |
-
) -> Dict[str, Any]:
|
571 |
-
if self.augment_task_input:
|
572 |
-
assert (
|
573 |
-
len(self._task_input_fields) > 0
|
574 |
-
), "No augmentable input fields were defined in Task, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the Task."
|
575 |
-
fields = self._task_input_fields
|
576 |
-
assert not self.augment_model_input
|
577 |
-
|
578 |
-
if self.augment_model_input:
|
579 |
-
fields = ["source"]
|
580 |
-
assert not self.augment_task_input
|
581 |
-
|
582 |
-
for field_name in fields:
|
583 |
-
try:
|
584 |
-
old_value = dict_get(
|
585 |
-
instance,
|
586 |
-
field_name,
|
587 |
-
default="",
|
588 |
-
not_exist_ok=False,
|
589 |
-
)
|
590 |
-
except ValueError as e:
|
591 |
-
raise TypeError(f"Failed to get {field_name} from {instance}") from e
|
592 |
-
|
593 |
-
try:
|
594 |
-
new_value = self.process_value(old_value)
|
595 |
-
except Exception as e:
|
596 |
-
raise RuntimeError(
|
597 |
-
f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
|
598 |
-
) from e
|
599 |
-
dict_set(instance, field_name, new_value, not_exist_ok=True)
|
600 |
-
return instance
|
601 |
-
|
602 |
-
|
603 |
-
class NullAugmentor(Augmentor):
|
604 |
-
"""Does not change the input string."""
|
605 |
-
|
606 |
-
def verify(self):
|
607 |
-
pass
|
608 |
-
|
609 |
-
def process_value(self, value: Any) -> Any:
|
610 |
-
return value
|
611 |
-
|
612 |
-
|
613 |
-
class AugmentWhitespace(Augmentor):
|
614 |
-
"""Augments the inputs by replacing existing whitespaces with other whitespaces.
|
615 |
-
|
616 |
-
Currently, each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
|
617 |
-
"""
|
618 |
-
|
619 |
-
def process_value(self, value: Any) -> Any:
|
620 |
-
import re
|
621 |
-
|
622 |
-
words = re.split(r"(\s+)", value)
|
623 |
-
new_value = ""
|
624 |
-
|
625 |
-
random_generator = new_random_generator(sub_seed=value)
|
626 |
-
for word in words:
|
627 |
-
if word.isspace():
|
628 |
-
new_value += random_generator.choice(
|
629 |
-
["\n", "\t", " "]
|
630 |
-
) * random_generator.randint(1, 3)
|
631 |
-
else:
|
632 |
-
new_value += word
|
633 |
-
return new_value
|
634 |
-
|
635 |
-
|
636 |
-
class AugmentPrefixSuffix(Augmentor):
|
637 |
-
r"""Augments the input by prepending and appending to it a randomly selected (typically, whitespace) patterns.
|
638 |
-
|
639 |
-
Args:
|
640 |
-
prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
|
641 |
-
The dictionary version allows to specify relative weights of the different patterns.
|
642 |
-
prefix_len, suffix_len (positive int) : The added prefix or suffix will be of length
|
643 |
-
prefix_len of suffix_len, respectively, repetitions of the randomly selected patterns.
|
644 |
-
remove_existing_whitespaces : allows to first clean any existing leading and trailing whitespaces.
|
645 |
-
The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
|
646 |
-
trimmed input.
|
647 |
-
If only one of prefixes/suffixes is needed, set the other to None.
|
648 |
-
|
649 |
-
Examples:
|
650 |
-
To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
|
651 |
-
AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
|
652 |
-
To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
|
653 |
-
being preferred over triple '\t', at 2:1 ratio, employ
|
654 |
-
AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
|
655 |
-
which will append '\n'-s twice as often as '\t'-s.
|
656 |
-
|
657 |
-
"""
|
658 |
-
|
659 |
-
prefixes: Optional[Union[List[str], Dict[str, int]]] = {
|
660 |
-
" ": 20,
|
661 |
-
"\\t": 10,
|
662 |
-
"\\n": 40,
|
663 |
-
"": 30,
|
664 |
-
}
|
665 |
-
prefix_len: Optional[int] = 3
|
666 |
-
suffixes: Optional[Union[List[str], Dict[str, int]]] = {
|
667 |
-
" ": 20,
|
668 |
-
"\\t": 10,
|
669 |
-
"\\n": 40,
|
670 |
-
"": 30,
|
671 |
-
}
|
672 |
-
suffix_len: Optional[int] = 3
|
673 |
-
remove_existing_whitespaces: Optional[bool] = False
|
674 |
-
|
675 |
-
def verify(self):
|
676 |
-
assert (
|
677 |
-
self.prefixes or self.suffixes
|
678 |
-
), "At least one of prefixes/suffixes should be not None."
|
679 |
-
for arg, arg_name in zip(
|
680 |
-
[self.prefixes, self.suffixes], ["prefixes", "suffixes"]
|
681 |
-
):
|
682 |
-
assert (
|
683 |
-
arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
|
684 |
-
), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
|
685 |
-
assert (
|
686 |
-
self.prefix_len > 0
|
687 |
-
), f"prefix_len must be positive, got {self.prefix_len}"
|
688 |
-
assert (
|
689 |
-
self.suffix_len > 0
|
690 |
-
), f"suffix_len must be positive, got {self.suffix_len}"
|
691 |
-
super().verify()
|
692 |
-
|
693 |
-
def _calculate_distributions(self, prefs_or_suffs):
|
694 |
-
if prefs_or_suffs is None:
|
695 |
-
return None, None
|
696 |
-
patterns = (
|
697 |
-
prefs_or_suffs
|
698 |
-
if isinstance(prefs_or_suffs, list)
|
699 |
-
else [k for k, v in prefs_or_suffs.items()]
|
700 |
-
)
|
701 |
-
total_weight = (
|
702 |
-
len(patterns)
|
703 |
-
if isinstance(prefs_or_suffs, list)
|
704 |
-
else sum([v for k, v in prefs_or_suffs.items()])
|
705 |
-
)
|
706 |
-
weights = (
|
707 |
-
[1.0 / total_weight] * len(patterns)
|
708 |
-
if isinstance(prefs_or_suffs, list)
|
709 |
-
else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
|
710 |
-
)
|
711 |
-
return patterns, weights
|
712 |
-
|
713 |
-
def prepare(self):
|
714 |
-
# Being an artifact, prepare is invoked before verify. Here we need verify before the actions
|
715 |
-
self.verify()
|
716 |
-
self._prefix_pattern_distribution = {"length": self.prefix_len}
|
717 |
-
self._suffix_pattern_distribution = {"length": self.suffix_len}
|
718 |
-
|
719 |
-
(
|
720 |
-
self._prefix_pattern_distribution["patterns"],
|
721 |
-
self._prefix_pattern_distribution["weights"],
|
722 |
-
) = self._calculate_distributions(self.prefixes)
|
723 |
-
(
|
724 |
-
self._suffix_pattern_distribution["patterns"],
|
725 |
-
self._suffix_pattern_distribution["weights"],
|
726 |
-
) = self._calculate_distributions(self.suffixes)
|
727 |
-
super().prepare()
|
728 |
-
|
729 |
-
def _get_random_pattern(
|
730 |
-
self, pattern_distribution, random_generator: Random
|
731 |
-
) -> str:
|
732 |
-
string_to_add = ""
|
733 |
-
if pattern_distribution["patterns"]:
|
734 |
-
string_to_add = "".join(
|
735 |
-
random_generator.choices(
|
736 |
-
pattern_distribution["patterns"],
|
737 |
-
pattern_distribution["weights"],
|
738 |
-
k=pattern_distribution["length"],
|
739 |
-
)
|
740 |
-
)
|
741 |
-
return string_to_add
|
742 |
-
|
743 |
-
def process_value(self, value: Any) -> Any:
|
744 |
-
assert value is not None, "input value should not be None"
|
745 |
-
new_value = str(value)
|
746 |
-
if self.remove_existing_whitespaces:
|
747 |
-
new_value = new_value.strip()
|
748 |
-
random_generator = new_random_generator(sub_seed=value)
|
749 |
-
prefix = self._get_random_pattern(
|
750 |
-
self._prefix_pattern_distribution, random_generator
|
751 |
-
)
|
752 |
-
suffix = self._get_random_pattern(
|
753 |
-
self._suffix_pattern_distribution, random_generator
|
754 |
-
)
|
755 |
-
return prefix + new_value + suffix
|
756 |
-
|
757 |
-
|
758 |
class ShuffleFieldValues(FieldOperator):
|
759 |
"""Shuffles a list of values found in a field."""
|
760 |
|
@@ -1445,7 +1221,7 @@ class ComputeExpressionMixin(Artifact):
|
|
1445 |
|
1446 |
def compute_expression(self, instance: dict) -> Any:
|
1447 |
if settings.allow_unverified_code:
|
1448 |
-
return eval(self.expression, self.globals, instance)
|
1449 |
|
1450 |
raise ValueError(
|
1451 |
f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
|
|
|
531 |
return self.add + value
|
532 |
|
533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
class ShuffleFieldValues(FieldOperator):
|
535 |
"""Shuffles a list of values found in a field."""
|
536 |
|
|
|
1221 |
|
1222 |
def compute_expression(self, instance: dict) -> Any:
|
1223 |
if settings.allow_unverified_code:
|
1224 |
+
return eval(self.expression, {**self.globals, **instance})
|
1225 |
|
1226 |
raise ValueError(
|
1227 |
f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
|
schema.py
CHANGED
@@ -69,23 +69,36 @@ class Finalize(InstanceOperatorValidator):
|
|
69 |
|
70 |
return instance
|
71 |
|
72 |
-
def
|
73 |
-
self, instance: Dict[str, Any],
|
74 |
) -> Dict[str, Any]:
|
75 |
-
metadata = {
|
76 |
-
"data_classification_policy": instance["data_classification_policy"],
|
77 |
-
"template": self.artifact_to_jsonable(
|
78 |
-
instance["recipe_metadata"]["template"]
|
79 |
-
),
|
80 |
-
"num_demos": instance["recipe_metadata"]["num_demos"],
|
81 |
-
}
|
82 |
task_data = {
|
83 |
**instance["input_fields"],
|
84 |
-
"metadata":
|
|
|
|
|
85 |
}
|
86 |
-
|
87 |
-
if stream_name != constants.inference_stream:
|
88 |
task_data = {**task_data, **instance["reference_fields"]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
instance["task_data"] = json.dumps(task_data)
|
91 |
|
@@ -99,7 +112,7 @@ class Finalize(InstanceOperatorValidator):
|
|
99 |
for key in keys_to_delete:
|
100 |
del instance[key]
|
101 |
|
102 |
-
data = {**task_data, **metadata}
|
103 |
groups = []
|
104 |
for group_attributes in self.group_by:
|
105 |
group = {}
|
|
|
69 |
|
70 |
return instance
|
71 |
|
72 |
+
def _get_instance_task_data(
|
73 |
+
self, instance: Dict[str, Any], use_reference_fields=True
|
74 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
task_data = {
|
76 |
**instance["input_fields"],
|
77 |
+
"metadata": {
|
78 |
+
"data_classification_policy": instance["data_classification_policy"],
|
79 |
+
},
|
80 |
}
|
81 |
+
if use_reference_fields:
|
|
|
82 |
task_data = {**task_data, **instance["reference_fields"]}
|
83 |
+
return task_data
|
84 |
+
|
85 |
+
def process(
|
86 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
87 |
+
) -> Dict[str, Any]:
|
88 |
+
task_data = self._get_instance_task_data(
|
89 |
+
instance,
|
90 |
+
use_reference_fields=stream_name != constants.inference_stream,
|
91 |
+
)
|
92 |
+
|
93 |
+
task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
|
94 |
+
task_data["metadata"]["template"] = self.artifact_to_jsonable(
|
95 |
+
instance["recipe_metadata"]["template"]
|
96 |
+
)
|
97 |
+
if "demos" in instance:
|
98 |
+
task_data["demos"] = [
|
99 |
+
self._get_instance_task_data(instance)
|
100 |
+
for instance in instance.pop("demos")
|
101 |
+
]
|
102 |
|
103 |
instance["task_data"] = json.dumps(task_data)
|
104 |
|
|
|
112 |
for key in keys_to_delete:
|
113 |
del instance[key]
|
114 |
|
115 |
+
data = {**task_data, **task_data["metadata"]}
|
116 |
groups = []
|
117 |
for group_attributes in self.group_by:
|
118 |
group = {}
|
serializers.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import io
|
3 |
+
from abc import abstractmethod
|
4 |
+
from typing import Any, Dict, List, Union
|
5 |
+
|
6 |
+
from .dataclass import AbstractField, Field
|
7 |
+
from .operators import InstanceFieldOperator
|
8 |
+
from .type_utils import isoftype, to_type_string
|
9 |
+
from .types import Dialog, Image, Number, Table
|
10 |
+
|
11 |
+
|
12 |
+
class Serializer(InstanceFieldOperator):
|
13 |
+
def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
|
14 |
+
return self.serialize(value, instance)
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
class DefaultSerializer(Serializer):
|
22 |
+
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
|
23 |
+
return str(value)
|
24 |
+
|
25 |
+
|
26 |
+
class SingleTypeSerializer(InstanceFieldOperator):
|
27 |
+
serialized_type: object = AbstractField()
|
28 |
+
|
29 |
+
def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
|
30 |
+
if not isoftype(value, self.serialized_type):
|
31 |
+
raise ValueError(
|
32 |
+
f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {to_type_string(value)}"
|
33 |
+
)
|
34 |
+
return self.serialize(value, instance)
|
35 |
+
|
36 |
+
|
37 |
+
class DefaultListSerializer(Serializer):
|
38 |
+
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
|
39 |
+
if isinstance(value, list):
|
40 |
+
return ", ".join(str(item) for item in value)
|
41 |
+
return str(value)
|
42 |
+
|
43 |
+
|
44 |
+
class ListSerializer(SingleTypeSerializer):
|
45 |
+
serialized_type = list
|
46 |
+
|
47 |
+
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
|
48 |
+
return ", ".join(str(item) for item in value)
|
49 |
+
|
50 |
+
|
51 |
+
class DialogSerializer(SingleTypeSerializer):
|
52 |
+
serialized_type = Dialog
|
53 |
+
|
54 |
+
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
|
55 |
+
# Convert the Dialog into a string representation, typically combining roles and content
|
56 |
+
return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)
|
57 |
+
|
58 |
+
|
59 |
+
class NumberSerializer(SingleTypeSerializer):
|
60 |
+
serialized_type = Number
|
61 |
+
|
62 |
+
def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
|
63 |
+
# Check if the value is an integer or a float
|
64 |
+
if isinstance(value, int):
|
65 |
+
return str(value)
|
66 |
+
# For floats, format to one decimal place
|
67 |
+
if isinstance(value, float):
|
68 |
+
return f"{value:.1f}"
|
69 |
+
raise ValueError("Unsupported type for NumberSerializer")
|
70 |
+
|
71 |
+
|
72 |
+
class NumberQuantizingSerializer(NumberSerializer):
|
73 |
+
serialized_type = Number
|
74 |
+
quantum: Union[float, int] = 0.1
|
75 |
+
|
76 |
+
def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
|
77 |
+
if isoftype(value, Number):
|
78 |
+
quantized_value = round(value / self.quantum) / (1 / self.quantum)
|
79 |
+
if isinstance(self.quantum, int):
|
80 |
+
quantized_value = int(quantized_value)
|
81 |
+
return str(quantized_value)
|
82 |
+
raise ValueError("Unsupported type for NumberSerializer")
|
83 |
+
|
84 |
+
|
85 |
+
class TableSerializer(SingleTypeSerializer):
|
86 |
+
serialized_type = Table
|
87 |
+
|
88 |
+
def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
|
89 |
+
output = io.StringIO()
|
90 |
+
writer = csv.writer(output, lineterminator="\n")
|
91 |
+
|
92 |
+
# Write the header and rows to the CSV writer
|
93 |
+
writer.writerow(value["header"])
|
94 |
+
writer.writerows(value["rows"])
|
95 |
+
|
96 |
+
# Retrieve the CSV string
|
97 |
+
return output.getvalue().strip()
|
98 |
+
|
99 |
+
|
100 |
+
class ImageSerializer(SingleTypeSerializer):
|
101 |
+
serialized_type = Image
|
102 |
+
|
103 |
+
def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
|
104 |
+
if "media" not in instance:
|
105 |
+
instance["media"] = {}
|
106 |
+
if "images" not in instance["media"]:
|
107 |
+
instance["media"]["images"] = []
|
108 |
+
idx = len(instance["media"]["images"])
|
109 |
+
instance["media"]["images"].append(value["image"])
|
110 |
+
value["image"] = f'<img src="media/images/{idx}">'
|
111 |
+
return value["image"]
|
112 |
+
|
113 |
+
|
114 |
+
class MultiTypeSerializer(Serializer):
|
115 |
+
serializers: List[SingleTypeSerializer] = Field(
|
116 |
+
default_factory=lambda: [
|
117 |
+
ImageSerializer(),
|
118 |
+
TableSerializer(),
|
119 |
+
DialogSerializer(),
|
120 |
+
]
|
121 |
+
)
|
122 |
+
|
123 |
+
def verify(self):
|
124 |
+
super().verify()
|
125 |
+
self._verify_serializers(self.serializers)
|
126 |
+
|
127 |
+
def _verify_serializers(self, serializers):
|
128 |
+
if not isoftype(serializers, List[SingleTypeSerializer]):
|
129 |
+
raise ValueError(
|
130 |
+
"MultiTypeSerializer requires the list of serializers to be List[SingleTypeSerializer]."
|
131 |
+
)
|
132 |
+
|
133 |
+
def add_serializers(self, serializers: List[SingleTypeSerializer]):
|
134 |
+
self._verify_serializers(serializers)
|
135 |
+
self.serializers = serializers + self.serializers
|
136 |
+
|
137 |
+
def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
|
138 |
+
for serializer in self.serializers:
|
139 |
+
if isoftype(value, serializer.serialized_type):
|
140 |
+
return serializer.serialize(value, instance)
|
141 |
+
|
142 |
+
return str(value)
|
settings_utils.py
CHANGED
@@ -146,6 +146,7 @@ if Settings.is_uninitilized():
|
|
146 |
settings.seed = (int, 42)
|
147 |
settings.skip_artifacts_prepare_and_verify = (bool, False)
|
148 |
settings.data_classification_policy = None
|
|
|
149 |
|
150 |
if Constants.is_uninitilized():
|
151 |
constants = Constants()
|
|
|
146 |
settings.seed = (int, 42)
|
147 |
settings.skip_artifacts_prepare_and_verify = (bool, False)
|
148 |
settings.data_classification_policy = None
|
149 |
+
settings.mock_inference_mode = (bool, False)
|
150 |
|
151 |
if Constants.is_uninitilized():
|
152 |
constants = Constants()
|
standard.py
CHANGED
@@ -1,20 +1,27 @@
|
|
1 |
from typing import List, Optional, Union
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from .card import TaskCard
|
4 |
from .collections_operators import GetLength
|
5 |
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
|
6 |
from .formats import Format, SystemFormat
|
7 |
from .logging_utils import get_logger
|
8 |
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
|
9 |
-
from .operators import
|
10 |
from .recipe import Recipe
|
11 |
from .schema import Finalize
|
|
|
12 |
from .settings_utils import get_constants
|
13 |
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
|
14 |
from .stream import MultiStream
|
15 |
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
16 |
from .task import Task
|
17 |
-
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
|
18 |
|
19 |
constants = get_constants()
|
20 |
logger = get_logger()
|
@@ -29,9 +36,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
29 |
# Base parameters
|
30 |
card: TaskCard = None
|
31 |
task: Task = None
|
32 |
-
template: Union[Template, List[Template]] = None
|
33 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
34 |
format: Format = Field(default_factory=SystemFormat)
|
|
|
35 |
|
36 |
# Additional parameters
|
37 |
template_card_index: int = NonPositionalField(default=None)
|
@@ -140,6 +148,11 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
140 |
else:
|
141 |
self.verify_template(self.template)
|
142 |
|
|
|
|
|
|
|
|
|
|
|
143 |
def prepare_refiners(self):
|
144 |
self.train_refiner.max_instances = self.max_train_instances
|
145 |
self.train_refiner.apply_to_streams = ["train"]
|
@@ -281,8 +294,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
281 |
|
282 |
self.processing.steps.append(self.task)
|
283 |
|
284 |
-
if self.augmentor
|
285 |
-
self.augmentor.
|
286 |
self.processing.steps.append(self.augmentor)
|
287 |
|
288 |
if self.has_custom_demos_pool:
|
@@ -362,7 +375,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
362 |
|
363 |
self.verbalization.steps.append(self.system_prompt)
|
364 |
self.verbalization.steps.append(self.format)
|
365 |
-
if self.augmentor
|
366 |
self.verbalization.steps.append(self.augmentor)
|
367 |
|
368 |
if self.postprocessors is not None:
|
@@ -376,6 +389,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
376 |
self.finalize.steps.append(Finalize(group_by=self.group_by))
|
377 |
|
378 |
def prepare(self):
|
|
|
|
|
379 |
self.reset_pipeline()
|
380 |
|
381 |
|
|
|
1 |
from typing import List, Optional, Union
|
2 |
|
3 |
+
from .augmentors import (
|
4 |
+
Augmentor,
|
5 |
+
FinalStateInputsAugmentor,
|
6 |
+
NullAugmentor,
|
7 |
+
TaskInputsAugmentor,
|
8 |
+
)
|
9 |
from .card import TaskCard
|
10 |
from .collections_operators import GetLength
|
11 |
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
|
12 |
from .formats import Format, SystemFormat
|
13 |
from .logging_utils import get_logger
|
14 |
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
|
15 |
+
from .operators import Set, StreamRefiner
|
16 |
from .recipe import Recipe
|
17 |
from .schema import Finalize
|
18 |
+
from .serializers import SingleTypeSerializer
|
19 |
from .settings_utils import get_constants
|
20 |
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
|
21 |
from .stream import MultiStream
|
22 |
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
23 |
from .task import Task
|
24 |
+
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
|
25 |
|
26 |
constants = get_constants()
|
27 |
logger = get_logger()
|
|
|
36 |
# Base parameters
|
37 |
card: TaskCard = None
|
38 |
task: Task = None
|
39 |
+
template: Union[Template, List[Template], TemplatesList] = None
|
40 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
41 |
format: Format = Field(default_factory=SystemFormat)
|
42 |
+
serializer: Union[SingleTypeSerializer, List[SingleTypeSerializer]] = None
|
43 |
|
44 |
# Additional parameters
|
45 |
template_card_index: int = NonPositionalField(default=None)
|
|
|
148 |
else:
|
149 |
self.verify_template(self.template)
|
150 |
|
151 |
+
if self.serializer is not None:
|
152 |
+
if not isinstance(self.serializer, list):
|
153 |
+
self.serializer = [self.serializer]
|
154 |
+
self.template.serializer.add_serializers(self.serializer)
|
155 |
+
|
156 |
def prepare_refiners(self):
|
157 |
self.train_refiner.max_instances = self.max_train_instances
|
158 |
self.train_refiner.apply_to_streams = ["train"]
|
|
|
294 |
|
295 |
self.processing.steps.append(self.task)
|
296 |
|
297 |
+
if isinstance(self.augmentor, TaskInputsAugmentor):
|
298 |
+
self.augmentor.set_fields(self.card.task.augmentable_inputs)
|
299 |
self.processing.steps.append(self.augmentor)
|
300 |
|
301 |
if self.has_custom_demos_pool:
|
|
|
375 |
|
376 |
self.verbalization.steps.append(self.system_prompt)
|
377 |
self.verbalization.steps.append(self.format)
|
378 |
+
if isinstance(self.augmentor, FinalStateInputsAugmentor):
|
379 |
self.verbalization.steps.append(self.augmentor)
|
380 |
|
381 |
if self.postprocessors is not None:
|
|
|
389 |
self.finalize.steps.append(Finalize(group_by=self.group_by))
|
390 |
|
391 |
def prepare(self):
|
392 |
+
if isinstance(self.template, TemplatesList):
|
393 |
+
self.template = self.template.items
|
394 |
self.reset_pipeline()
|
395 |
|
396 |
|
struct_data_operators.py
CHANGED
@@ -29,15 +29,62 @@ import pandas as pd
|
|
29 |
|
30 |
from .dict_utils import dict_get
|
31 |
from .operators import FieldOperator, InstanceOperator
|
|
|
|
|
|
|
32 |
from .utils import deepcopy
|
33 |
|
34 |
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
"""TableSerializer converts a given table into a flat sequence with special symbols.
|
37 |
|
38 |
Output format varies depending on the chosen serializer. This abstract class defines structure of a typical table serializer that any concrete implementation should follow.
|
39 |
"""
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# main method to serialize a table
|
42 |
@abstractmethod
|
43 |
def serialize_table(self, table_content: Dict) -> str:
|
@@ -60,10 +107,6 @@ class SerializeTableAsIndexedRowMajor(SerializeTable):
|
|
60 |
Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
|
61 |
"""
|
62 |
|
63 |
-
def process_value(self, table: Any) -> Any:
|
64 |
-
table_input = deepcopy(table)
|
65 |
-
return self.serialize_table(table_content=table_input)
|
66 |
-
|
67 |
# main method that processes a table
|
68 |
# table_content must be in the presribed input format
|
69 |
def serialize_table(self, table_content: Dict) -> str:
|
@@ -111,10 +154,6 @@ class SerializeTableAsMarkdown(SerializeTable):
|
|
111 |
...
|
112 |
"""
|
113 |
|
114 |
-
def process_value(self, table: Any) -> Any:
|
115 |
-
table_input = deepcopy(table)
|
116 |
-
return self.serialize_table(table_content=table_input)
|
117 |
-
|
118 |
# main method that serializes a table.
|
119 |
# table_content must be in the presribed input format.
|
120 |
def serialize_table(self, table_content: Dict) -> str:
|
@@ -159,10 +198,6 @@ class SerializeTableAsDFLoader(SerializeTable):
|
|
159 |
index=[0,1,2])
|
160 |
"""
|
161 |
|
162 |
-
def process_value(self, table: Any) -> Any:
|
163 |
-
table_input = deepcopy(table)
|
164 |
-
return self.serialize_table(table_content=table_input)
|
165 |
-
|
166 |
# main method that serializes a table.
|
167 |
# table_content must be in the presribed input format.
|
168 |
def serialize_table(self, table_content: Dict) -> str:
|
@@ -199,10 +234,6 @@ class SerializeTableAsJson(SerializeTable):
|
|
199 |
}
|
200 |
"""
|
201 |
|
202 |
-
def process_value(self, table: Any) -> Any:
|
203 |
-
table_input = deepcopy(table)
|
204 |
-
return self.serialize_table(table_content=table_input)
|
205 |
-
|
206 |
# main method that serializes a table.
|
207 |
# table_content must be in the presribed input format.
|
208 |
def serialize_table(self, table_content: Dict) -> str:
|
@@ -493,20 +524,7 @@ class ShuffleTableRows(FieldOperator):
|
|
493 |
|
494 |
def process_value(self, table: Any) -> Any:
|
495 |
table_input = deepcopy(table)
|
496 |
-
return
|
497 |
-
|
498 |
-
# shuffles table rows randomly
|
499 |
-
def shuffle_rows(self, table_content: Dict) -> str:
|
500 |
-
# extract header & rows from the dictionary
|
501 |
-
header = table_content.get("header", [])
|
502 |
-
rows = table_content.get("rows", [])
|
503 |
-
assert header and rows, "Incorrect input table format"
|
504 |
-
|
505 |
-
# shuffle rows
|
506 |
-
random.shuffle(rows)
|
507 |
-
table_content["rows"] = rows
|
508 |
-
|
509 |
-
return table_content
|
510 |
|
511 |
|
512 |
class ShuffleTableColumns(FieldOperator):
|
@@ -527,27 +545,7 @@ class ShuffleTableColumns(FieldOperator):
|
|
527 |
|
528 |
def process_value(self, table: Any) -> Any:
|
529 |
table_input = deepcopy(table)
|
530 |
-
return
|
531 |
-
|
532 |
-
# shuffles table columns randomly
|
533 |
-
def shuffle_columns(self, table_content: Dict) -> str:
|
534 |
-
# extract header & rows from the dictionary
|
535 |
-
header = table_content.get("header", [])
|
536 |
-
rows = table_content.get("rows", [])
|
537 |
-
assert header and rows, "Incorrect input table format"
|
538 |
-
|
539 |
-
# shuffle the indices first
|
540 |
-
indices = list(range(len(header)))
|
541 |
-
random.shuffle(indices) #
|
542 |
-
|
543 |
-
# shuffle the header & rows based on that indices
|
544 |
-
shuffled_header = [header[i] for i in indices]
|
545 |
-
shuffled_rows = [[row[i] for i in indices] for row in rows]
|
546 |
-
|
547 |
-
table_content["header"] = shuffled_header
|
548 |
-
table_content["rows"] = shuffled_rows
|
549 |
-
|
550 |
-
return table_content
|
551 |
|
552 |
|
553 |
class LoadJson(FieldOperator):
|
|
|
29 |
|
30 |
from .dict_utils import dict_get
|
31 |
from .operators import FieldOperator, InstanceOperator
|
32 |
+
from .random_utils import new_random_generator
|
33 |
+
from .serializers import TableSerializer
|
34 |
+
from .types import Table
|
35 |
from .utils import deepcopy
|
36 |
|
37 |
|
38 |
+
def shuffle_columns(table: Table, seed=0) -> Table:
|
39 |
+
# extract header & rows from the dictionary
|
40 |
+
header = table.get("header", [])
|
41 |
+
rows = table.get("rows", [])
|
42 |
+
# shuffle the indices first
|
43 |
+
indices = list(range(len(header)))
|
44 |
+
random_generator = new_random_generator({"table": table, "seed": seed})
|
45 |
+
random_generator.shuffle(indices)
|
46 |
+
|
47 |
+
# shuffle the header & rows based on that indices
|
48 |
+
shuffled_header = [header[i] for i in indices]
|
49 |
+
shuffled_rows = [[row[i] for i in indices] for row in rows]
|
50 |
+
|
51 |
+
table["header"] = shuffled_header
|
52 |
+
table["rows"] = shuffled_rows
|
53 |
+
|
54 |
+
return table
|
55 |
+
|
56 |
+
|
57 |
+
def shuffle_rows(table: Table, seed=0) -> Table:
|
58 |
+
# extract header & rows from the dictionary
|
59 |
+
rows = table.get("rows", [])
|
60 |
+
# shuffle rows
|
61 |
+
random_generator = new_random_generator({"table": table, "seed": seed})
|
62 |
+
random_generator.shuffle(rows)
|
63 |
+
table["rows"] = rows
|
64 |
+
|
65 |
+
return table
|
66 |
+
|
67 |
+
|
68 |
+
class SerializeTable(ABC, TableSerializer):
|
69 |
"""TableSerializer converts a given table into a flat sequence with special symbols.
|
70 |
|
71 |
Output format varies depending on the chosen serializer. This abstract class defines structure of a typical table serializer that any concrete implementation should follow.
|
72 |
"""
|
73 |
|
74 |
+
seed: int = 0
|
75 |
+
shuffle_rows: bool = False
|
76 |
+
shuffle_columns: bool = False
|
77 |
+
|
78 |
+
def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
|
79 |
+
value = deepcopy(value)
|
80 |
+
if self.shuffle_columns:
|
81 |
+
value = shuffle_columns(table=value, seed=self.seed)
|
82 |
+
|
83 |
+
if self.shuffle_rows:
|
84 |
+
value = shuffle_rows(table=value, seed=self.seed)
|
85 |
+
|
86 |
+
return self.serialize_table(value)
|
87 |
+
|
88 |
# main method to serialize a table
|
89 |
@abstractmethod
|
90 |
def serialize_table(self, table_content: Dict) -> str:
|
|
|
107 |
Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
|
108 |
"""
|
109 |
|
|
|
|
|
|
|
|
|
110 |
# main method that processes a table
|
111 |
# table_content must be in the presribed input format
|
112 |
def serialize_table(self, table_content: Dict) -> str:
|
|
|
154 |
...
|
155 |
"""
|
156 |
|
|
|
|
|
|
|
|
|
157 |
# main method that serializes a table.
|
158 |
# table_content must be in the presribed input format.
|
159 |
def serialize_table(self, table_content: Dict) -> str:
|
|
|
198 |
index=[0,1,2])
|
199 |
"""
|
200 |
|
|
|
|
|
|
|
|
|
201 |
# main method that serializes a table.
|
202 |
# table_content must be in the presribed input format.
|
203 |
def serialize_table(self, table_content: Dict) -> str:
|
|
|
234 |
}
|
235 |
"""
|
236 |
|
|
|
|
|
|
|
|
|
237 |
# main method that serializes a table.
|
238 |
# table_content must be in the presribed input format.
|
239 |
def serialize_table(self, table_content: Dict) -> str:
|
|
|
524 |
|
525 |
def process_value(self, table: Any) -> Any:
|
526 |
table_input = deepcopy(table)
|
527 |
+
return shuffle_rows(table_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
|
529 |
|
530 |
class ShuffleTableColumns(FieldOperator):
|
|
|
545 |
|
546 |
def process_value(self, table: Any) -> Any:
|
547 |
table_input = deepcopy(table)
|
548 |
+
return shuffle_columns(table_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
549 |
|
550 |
|
551 |
class LoadJson(FieldOperator):
|
templates.py
CHANGED
@@ -10,6 +10,15 @@ from .dict_utils import dict_set
|
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator
|
12 |
from .random_utils import new_random_generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from .settings_utils import get_constants
|
14 |
from .type_utils import isoftype
|
15 |
|
@@ -46,17 +55,26 @@ class Template(InstanceOperator):
|
|
46 |
instruction: str = NonPositionalField(default="")
|
47 |
target_prefix: str = NonPositionalField(default="")
|
48 |
title_fields: List[str] = NonPositionalField(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
def input_fields_to_instruction_and_target_prefix(self, input_fields):
|
51 |
instruction = self.apply_formatting(
|
52 |
-
input_fields, "input field", self.instruction, "instruction"
|
53 |
)
|
54 |
target_prefix = self.apply_formatting(
|
55 |
input_fields,
|
56 |
"input field",
|
57 |
self.target_prefix,
|
58 |
"target_prefix",
|
59 |
-
serialize=True,
|
60 |
)
|
61 |
return instruction, target_prefix
|
62 |
|
@@ -65,6 +83,12 @@ class Template(InstanceOperator):
|
|
65 |
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
66 |
return input_fields, reference_fields
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def process(
|
69 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
70 |
) -> Dict[str, Any]:
|
@@ -78,14 +102,21 @@ class Template(InstanceOperator):
|
|
78 |
|
79 |
input_fields = instance.get("input_fields")
|
80 |
reference_fields = instance.get("reference_fields")
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
|
85 |
self.set_titles(input_fields)
|
86 |
-
|
|
|
|
|
|
|
87 |
instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
|
88 |
-
|
89 |
)
|
90 |
|
91 |
result = {
|
@@ -97,19 +128,33 @@ class Template(InstanceOperator):
|
|
97 |
}
|
98 |
|
99 |
if stream_name == constants.inference_stream:
|
100 |
-
return result
|
101 |
|
102 |
if reference_fields is None:
|
103 |
raise ValueError("Should have reference_fields")
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
target, references = self.reference_fields_to_target_and_references(
|
106 |
-
|
107 |
)
|
108 |
|
109 |
result["target"] = target
|
110 |
result["references"] = references
|
111 |
|
112 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
@abstractmethod
|
115 |
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
@@ -125,21 +170,13 @@ class Template(InstanceOperator):
|
|
125 |
) -> Tuple[str, List[str]]:
|
126 |
pass
|
127 |
|
128 |
-
def serialize_data(self, data):
|
129 |
-
return {
|
130 |
-
k: ", ".join(str(t) for t in v) if isinstance(v, list) else v
|
131 |
-
for k, v in data.items()
|
132 |
-
}
|
133 |
-
|
134 |
def apply_formatting(
|
135 |
-
self, data, data_type, format_str, format_name
|
136 |
) -> str:
|
137 |
-
if serialize:
|
138 |
-
data = self.serialize_data(data)
|
139 |
try:
|
140 |
if format_str is None:
|
141 |
raise UnitxtError(
|
142 |
-
f"Required field '
|
143 |
Documentation.ADDING_TEMPLATE,
|
144 |
)
|
145 |
return format_str.format(**data)
|
@@ -197,26 +234,21 @@ class ApplyRandomTemplate(ApplyTemplate):
|
|
197 |
return random_generator.choice(self.templates)
|
198 |
|
199 |
|
200 |
-
class
|
201 |
-
"""Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
|
202 |
-
|
203 |
-
Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references').
|
204 |
-
"""
|
205 |
-
|
206 |
input_format: str
|
207 |
-
output_format: str = None
|
208 |
|
209 |
-
def input_fields_to_source(
|
210 |
-
self, input_fields: Dict[str, object]
|
211 |
-
) -> Tuple[str, str]:
|
212 |
return self.apply_formatting(
|
213 |
input_fields,
|
214 |
"input field",
|
215 |
self.input_format,
|
216 |
"input_format",
|
217 |
-
serialize=True,
|
218 |
)
|
219 |
|
|
|
|
|
|
|
|
|
220 |
def reference_fields_to_target_and_references(
|
221 |
self, reference_fields: Dict[str, object]
|
222 |
) -> str:
|
@@ -225,12 +257,20 @@ class InputOutputTemplate(Template):
|
|
225 |
"reference field",
|
226 |
self.output_format,
|
227 |
"output_format",
|
228 |
-
serialize=True,
|
229 |
)
|
230 |
references = [target]
|
231 |
return target, references
|
232 |
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
|
235 |
reference: str
|
236 |
|
@@ -242,14 +282,12 @@ class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
|
|
242 |
"reference field",
|
243 |
self.output_format,
|
244 |
"output_format",
|
245 |
-
serialize=True,
|
246 |
)
|
247 |
reference = self.apply_formatting(
|
248 |
reference_fields,
|
249 |
"reference field",
|
250 |
self.reference,
|
251 |
"reference",
|
252 |
-
serialize=True,
|
253 |
)
|
254 |
return target, [reference]
|
255 |
|
@@ -374,22 +412,12 @@ class DialogTemplate(InputOutputTemplate):
|
|
374 |
input_fields[dialog_fields.dialog_field] = dialog_str
|
375 |
return input_fields
|
376 |
|
377 |
-
def
|
378 |
-
self
|
379 |
-
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
380 |
-
return self.process_dialog(input_fields), reference_fields
|
381 |
|
382 |
|
383 |
class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
|
384 |
-
|
385 |
-
self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
|
386 |
-
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
387 |
-
inputs, reference_fields = DialogTemplate.preprocess_input_and_reference_fields(
|
388 |
-
self, input_fields, reference_fields
|
389 |
-
)
|
390 |
-
return PairwiseChoiceTemplate.preprocess_input_and_reference_fields(
|
391 |
-
self, input_fields, reference_fields
|
392 |
-
)
|
393 |
|
394 |
|
395 |
class PairwiseComparativeRatingTemplate(InputOutputTemplate):
|
@@ -448,10 +476,9 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
|
|
448 |
return input_fields, reference_fields
|
449 |
|
450 |
|
451 |
-
class MultipleChoiceTemplate(
|
452 |
"""Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
|
453 |
|
454 |
-
input_format: str
|
455 |
target_prefix: str = ""
|
456 |
choices_field: str = "choices"
|
457 |
target_field: str = "label"
|
@@ -493,7 +520,7 @@ class MultipleChoiceTemplate(Template):
|
|
493 |
"XX",
|
494 |
]
|
495 |
|
496 |
-
def inputs_to_choices(self, data: Dict[str,
|
497 |
choices = data[self.choices_field]
|
498 |
enumrated_choices = []
|
499 |
for i, choice in enumerate(choices):
|
@@ -505,12 +532,12 @@ class MultipleChoiceTemplate(Template):
|
|
505 |
)
|
506 |
return enumrated_choices
|
507 |
|
508 |
-
def inputs_to_numerals(self, input_fields: Dict[str,
|
509 |
return self.inputs_to_choices(input_fields, "{choice_numeral}")
|
510 |
|
511 |
def prepare_multiple_choice_inputs(
|
512 |
-
self, input_fields: Dict[str,
|
513 |
-
) -> Dict[str,
|
514 |
choices = self.inputs_to_choices(input_fields, self.source_choice_format)
|
515 |
return {
|
516 |
"numerals": self.inputs_to_numerals(input_fields),
|
@@ -518,23 +545,10 @@ class MultipleChoiceTemplate(Template):
|
|
518 |
self.choices_field: self.choices_separator.join(choices),
|
519 |
}
|
520 |
|
521 |
-
def
|
522 |
-
self
|
523 |
-
) -> Tuple[str, str]:
|
524 |
-
input_fields = self.prepare_multiple_choice_inputs(input_fields)
|
525 |
-
return self.apply_formatting(
|
526 |
-
input_fields,
|
527 |
-
"input field",
|
528 |
-
self.input_format,
|
529 |
-
"input_format",
|
530 |
-
serialize=True,
|
531 |
-
)
|
532 |
-
|
533 |
-
def input_fields_to_instruction_and_target_prefix(self, input_fields):
|
534 |
-
input_fields = self.prepare_multiple_choice_inputs(input_fields)
|
535 |
-
return super().input_fields_to_instruction_and_target_prefix(input_fields)
|
536 |
|
537 |
-
def outputs_to_target_index(self, reference_fields: Dict[str, object]) ->
|
538 |
target = reference_fields[self.target_field]
|
539 |
|
540 |
if not isinstance(target, int):
|
@@ -547,9 +561,7 @@ class MultipleChoiceTemplate(Template):
|
|
547 |
) from e
|
548 |
return target
|
549 |
|
550 |
-
def
|
551 |
-
self, reference_fields: Dict[str, object]
|
552 |
-
) -> str:
|
553 |
target = reference_fields[self.target_field]
|
554 |
|
555 |
if not isinstance(target, int):
|
@@ -571,51 +583,40 @@ class MultipleChoiceTemplate(Template):
|
|
571 |
Documentation.ADDING_TEMPLATE,
|
572 |
) from e
|
573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
return target, [target]
|
575 |
|
576 |
-
def
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
]
|
582 |
-
|
|
|
583 |
|
584 |
-
|
|
|
|
|
585 |
|
586 |
-
|
587 |
-
|
588 |
-
instance["input_fields"][self.choices_field] = choices
|
589 |
|
590 |
-
|
591 |
-
return instance
|
592 |
|
593 |
-
|
594 |
-
instance["
|
595 |
-
|
596 |
)
|
597 |
-
|
598 |
return instance
|
599 |
|
600 |
-
def process(
|
601 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
602 |
-
) -> Dict[str, Any]:
|
603 |
-
if self.shuffle_choices:
|
604 |
-
instance = self._shuffle_choices(instance, stream_name)
|
605 |
-
result = super().process(instance, stream_name)
|
606 |
-
if stream_name == constants.inference_stream:
|
607 |
-
result["input_fields"]["options"] = self.inputs_to_choices(
|
608 |
-
instance["input_fields"], self.target_choice_format
|
609 |
-
)
|
610 |
-
else:
|
611 |
-
if "options" not in result["reference_fields"]:
|
612 |
-
result["reference_fields"]["options"] = self.inputs_to_choices(
|
613 |
-
instance["reference_fields"], self.target_choice_format
|
614 |
-
)
|
615 |
-
return result
|
616 |
-
|
617 |
|
618 |
-
class YesNoTemplate(
|
619 |
"""A template for generating binary Yes/No questions asking whether an input text is of a specific class.
|
620 |
|
621 |
input_format:
|
@@ -641,17 +642,6 @@ class YesNoTemplate(Template):
|
|
641 |
yes_answer: str = "Yes"
|
642 |
no_answer: str = "No"
|
643 |
|
644 |
-
def input_fields_to_source(
|
645 |
-
self, input_fields: Dict[str, object]
|
646 |
-
) -> Tuple[str, str]:
|
647 |
-
return self.apply_formatting(
|
648 |
-
input_fields,
|
649 |
-
"input field",
|
650 |
-
self.input_format,
|
651 |
-
"input_format",
|
652 |
-
serialize=True,
|
653 |
-
)
|
654 |
-
|
655 |
def reference_fields_to_target_and_references(
|
656 |
self, reference_fields: Dict[str, object]
|
657 |
) -> str:
|
@@ -695,16 +685,13 @@ class KeyValTemplate(Template):
|
|
695 |
def process_dict(
|
696 |
self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
|
697 |
) -> str:
|
698 |
-
data = self.serialize_data(data)
|
699 |
pairs = []
|
700 |
for key, val in data.items():
|
701 |
key_val = [key, str(val)] if use_keys else [str(val)]
|
702 |
pairs.append(key_val_sep.join(key_val))
|
703 |
return pairs_sep.join(pairs)
|
704 |
|
705 |
-
def input_fields_to_source(
|
706 |
-
self, input_fields: Dict[str, object]
|
707 |
-
) -> Tuple[str, str]:
|
708 |
return self.process_dict(
|
709 |
input_fields,
|
710 |
key_val_sep=self.key_val_separator,
|
@@ -725,25 +712,16 @@ class KeyValTemplate(Template):
|
|
725 |
|
726 |
|
727 |
class OutputQuantizingTemplate(InputOutputTemplate):
|
728 |
-
|
|
|
|
|
|
|
729 |
|
730 |
-
def
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
quantized_outputs = {
|
736 |
-
key: f"{int(round(value / self.quantum) * self.quantum)}"
|
737 |
-
for key, value in reference_fields.items()
|
738 |
-
}
|
739 |
-
else:
|
740 |
-
# When quantum is a float, format quantized values with precision based on quantum
|
741 |
-
quantum_str = f"{self.quantum:.10f}".rstrip("0").rstrip(".")
|
742 |
-
quantized_outputs = {
|
743 |
-
key: f"{round(value / self.quantum) * self.quantum:{quantum_str}}"
|
744 |
-
for key, value in reference_fields.items()
|
745 |
-
}
|
746 |
-
return super().reference_fields_to_target_and_references(quantized_outputs)
|
747 |
|
748 |
|
749 |
class MultiLabelTemplate(InputOutputTemplate):
|
@@ -753,9 +731,9 @@ class MultiLabelTemplate(InputOutputTemplate):
|
|
753 |
output_format: str = "{labels}"
|
754 |
empty_label: str = "None"
|
755 |
|
756 |
-
def
|
757 |
-
self, reference_fields: Dict[str,
|
758 |
-
) -> str:
|
759 |
labels = reference_fields[self.labels_field]
|
760 |
if not isinstance(labels, list):
|
761 |
raise UnitxtError(
|
@@ -765,18 +743,29 @@ class MultiLabelTemplate(InputOutputTemplate):
|
|
765 |
if len(labels) == 0:
|
766 |
labels = [self.empty_label]
|
767 |
labels_str = self.labels_separator.join(labels)
|
768 |
-
return
|
769 |
-
{self.labels_field: labels_str}
|
770 |
-
)
|
771 |
|
772 |
|
773 |
class MultiReferenceTemplate(InputOutputTemplate):
|
774 |
references_field: str = "references"
|
775 |
random_reference: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
776 |
|
777 |
def reference_fields_to_target_and_references(
|
778 |
self, reference_fields: Dict[str, object]
|
779 |
-
) -> List[str]:
|
780 |
references = reference_fields[self.references_field]
|
781 |
if not isoftype(references, List[str]):
|
782 |
raise UnitxtError(
|
@@ -825,12 +814,12 @@ class SpanLabelingBaseTemplate(MultiLabelTemplate):
|
|
825 |
if self.labels_support is None or span[3] in self.labels_support:
|
826 |
yield span[2], span[3]
|
827 |
|
828 |
-
def
|
829 |
-
self, reference_fields: Dict[str,
|
830 |
-
) -> Dict[str,
|
831 |
span_labels_pairs = self.extract_span_label_pairs(reference_fields)
|
832 |
targets = self.span_label_pairs_to_targets(span_labels_pairs)
|
833 |
-
return super().
|
834 |
|
835 |
@abstractmethod
|
836 |
def span_label_pairs_to_targets(self, pairs):
|
|
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator
|
12 |
from .random_utils import new_random_generator
|
13 |
+
from .serializers import (
|
14 |
+
DialogSerializer,
|
15 |
+
ImageSerializer,
|
16 |
+
ListSerializer,
|
17 |
+
MultiTypeSerializer,
|
18 |
+
NumberQuantizingSerializer,
|
19 |
+
Serializer,
|
20 |
+
TableSerializer,
|
21 |
+
)
|
22 |
from .settings_utils import get_constants
|
23 |
from .type_utils import isoftype
|
24 |
|
|
|
55 |
instruction: str = NonPositionalField(default="")
|
56 |
target_prefix: str = NonPositionalField(default="")
|
57 |
title_fields: List[str] = NonPositionalField(default_factory=list)
|
58 |
+
serializer: Serializer = NonPositionalField(
|
59 |
+
default_factory=lambda: MultiTypeSerializer(
|
60 |
+
serializers=[
|
61 |
+
ImageSerializer(),
|
62 |
+
TableSerializer(),
|
63 |
+
DialogSerializer(),
|
64 |
+
ListSerializer(),
|
65 |
+
]
|
66 |
+
)
|
67 |
+
)
|
68 |
|
69 |
def input_fields_to_instruction_and_target_prefix(self, input_fields):
|
70 |
instruction = self.apply_formatting(
|
71 |
+
input_fields, "input field", self.instruction, "instruction"
|
72 |
)
|
73 |
target_prefix = self.apply_formatting(
|
74 |
input_fields,
|
75 |
"input field",
|
76 |
self.target_prefix,
|
77 |
"target_prefix",
|
|
|
78 |
)
|
79 |
return instruction, target_prefix
|
80 |
|
|
|
83 |
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
84 |
return input_fields, reference_fields
|
85 |
|
86 |
+
def preprocess_input_fields(self, input_fields: Dict[str, Any]):
|
87 |
+
return input_fields
|
88 |
+
|
89 |
+
def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
|
90 |
+
return reference_fields
|
91 |
+
|
92 |
def process(
|
93 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
94 |
) -> Dict[str, Any]:
|
|
|
102 |
|
103 |
input_fields = instance.get("input_fields")
|
104 |
reference_fields = instance.get("reference_fields")
|
105 |
+
|
106 |
+
if stream_name != constants.inference_stream:
|
107 |
+
input_fields, reference_fields = self.preprocess_input_and_reference_fields(
|
108 |
+
input_fields, reference_fields
|
109 |
+
)
|
110 |
+
|
111 |
+
input_fields = self.preprocess_input_fields(input_fields)
|
112 |
|
113 |
self.set_titles(input_fields)
|
114 |
+
|
115 |
+
serialized_inputs = self.serialize(input_fields, instance)
|
116 |
+
|
117 |
+
source = self.input_fields_to_source(serialized_inputs)
|
118 |
instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
|
119 |
+
serialized_inputs
|
120 |
)
|
121 |
|
122 |
result = {
|
|
|
128 |
}
|
129 |
|
130 |
if stream_name == constants.inference_stream:
|
131 |
+
return self.post_process_instance(result)
|
132 |
|
133 |
if reference_fields is None:
|
134 |
raise ValueError("Should have reference_fields")
|
135 |
|
136 |
+
reference_fields = self.preprocess_reference_fields(reference_fields)
|
137 |
+
|
138 |
+
serialized_references = self.serialize(
|
139 |
+
reference_fields, instance
|
140 |
+
) # Dict[str, str]
|
141 |
+
|
142 |
target, references = self.reference_fields_to_target_and_references(
|
143 |
+
serialized_references
|
144 |
)
|
145 |
|
146 |
result["target"] = target
|
147 |
result["references"] = references
|
148 |
|
149 |
+
return self.post_process_instance(result)
|
150 |
+
|
151 |
+
def post_process_instance(self, instance):
|
152 |
+
return instance
|
153 |
+
|
154 |
+
def serialize(
|
155 |
+
self, data: Dict[str, Any], instance: Dict[str, Any]
|
156 |
+
) -> Dict[str, str]:
|
157 |
+
return {k: self.serializer.serialize(v, instance) for k, v in data.items()}
|
158 |
|
159 |
@abstractmethod
|
160 |
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
|
|
170 |
) -> Tuple[str, List[str]]:
|
171 |
pass
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
def apply_formatting(
|
174 |
+
self, data: Dict[str, Any], data_type: str, format_str: str, format_name: str
|
175 |
) -> str:
|
|
|
|
|
176 |
try:
|
177 |
if format_str is None:
|
178 |
raise UnitxtError(
|
179 |
+
f"Required field '{format_name}' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
|
180 |
Documentation.ADDING_TEMPLATE,
|
181 |
)
|
182 |
return format_str.format(**data)
|
|
|
234 |
return random_generator.choice(self.templates)
|
235 |
|
236 |
|
237 |
+
class InputFormatTemplate(Template):
|
|
|
|
|
|
|
|
|
|
|
238 |
input_format: str
|
|
|
239 |
|
240 |
+
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
|
|
|
|
241 |
return self.apply_formatting(
|
242 |
input_fields,
|
243 |
"input field",
|
244 |
self.input_format,
|
245 |
"input_format",
|
|
|
246 |
)
|
247 |
|
248 |
+
|
249 |
+
class OutputFormatTemplate(Template):
|
250 |
+
output_format: str = None
|
251 |
+
|
252 |
def reference_fields_to_target_and_references(
|
253 |
self, reference_fields: Dict[str, object]
|
254 |
) -> str:
|
|
|
257 |
"reference field",
|
258 |
self.output_format,
|
259 |
"output_format",
|
|
|
260 |
)
|
261 |
references = [target]
|
262 |
return target, references
|
263 |
|
264 |
|
265 |
+
class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate):
|
266 |
+
"""Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
|
267 |
+
|
268 |
+
Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references').
|
269 |
+
"""
|
270 |
+
|
271 |
+
pass
|
272 |
+
|
273 |
+
|
274 |
class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
|
275 |
reference: str
|
276 |
|
|
|
282 |
"reference field",
|
283 |
self.output_format,
|
284 |
"output_format",
|
|
|
285 |
)
|
286 |
reference = self.apply_formatting(
|
287 |
reference_fields,
|
288 |
"reference field",
|
289 |
self.reference,
|
290 |
"reference",
|
|
|
291 |
)
|
292 |
return target, [reference]
|
293 |
|
|
|
412 |
input_fields[dialog_fields.dialog_field] = dialog_str
|
413 |
return input_fields
|
414 |
|
415 |
+
def preprocess_input_fields(self, input_fields: Dict[str, Any]):
|
416 |
+
return self.process_dialog(input_fields)
|
|
|
|
|
417 |
|
418 |
|
419 |
class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
|
420 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
|
422 |
|
423 |
class PairwiseComparativeRatingTemplate(InputOutputTemplate):
|
|
|
476 |
return input_fields, reference_fields
|
477 |
|
478 |
|
479 |
+
class MultipleChoiceTemplate(InputFormatTemplate):
|
480 |
"""Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
|
481 |
|
|
|
482 |
target_prefix: str = ""
|
483 |
choices_field: str = "choices"
|
484 |
target_field: str = "label"
|
|
|
520 |
"XX",
|
521 |
]
|
522 |
|
523 |
+
def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str:
|
524 |
choices = data[self.choices_field]
|
525 |
enumrated_choices = []
|
526 |
for i, choice in enumerate(choices):
|
|
|
532 |
)
|
533 |
return enumrated_choices
|
534 |
|
535 |
+
def inputs_to_numerals(self, input_fields: Dict[str, Any]) -> Tuple[str, str]:
|
536 |
return self.inputs_to_choices(input_fields, "{choice_numeral}")
|
537 |
|
538 |
def prepare_multiple_choice_inputs(
|
539 |
+
self, input_fields: Dict[str, Any]
|
540 |
+
) -> Dict[str, Any]:
|
541 |
choices = self.inputs_to_choices(input_fields, self.source_choice_format)
|
542 |
return {
|
543 |
"numerals": self.inputs_to_numerals(input_fields),
|
|
|
545 |
self.choices_field: self.choices_separator.join(choices),
|
546 |
}
|
547 |
|
548 |
+
def preprocess_input_fields(self, input_fields: Dict[str, Any]) -> Dict[str, Any]:
|
549 |
+
return self.prepare_multiple_choice_inputs(input_fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
+
def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> int:
|
552 |
target = reference_fields[self.target_field]
|
553 |
|
554 |
if not isinstance(target, int):
|
|
|
561 |
) from e
|
562 |
return target
|
563 |
|
564 |
+
def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
|
|
|
|
|
565 |
target = reference_fields[self.target_field]
|
566 |
|
567 |
if not isinstance(target, int):
|
|
|
583 |
Documentation.ADDING_TEMPLATE,
|
584 |
) from e
|
585 |
|
586 |
+
return {self.target_field: target}
|
587 |
+
|
588 |
+
def reference_fields_to_target_and_references(
|
589 |
+
self, reference_fields: Dict[str, object]
|
590 |
+
) -> str:
|
591 |
+
target = reference_fields[self.target_field]
|
592 |
return target, [target]
|
593 |
|
594 |
+
def preprocess_input_and_reference_fields(
|
595 |
+
self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
|
596 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
597 |
+
if self.shuffle_choices:
|
598 |
+
target_index = self.outputs_to_target_index(reference_fields)
|
599 |
+
original_label_choice = reference_fields[self.choices_field][target_index]
|
600 |
+
choices = input_fields[self.choices_field]
|
601 |
+
random_seed = {**input_fields}
|
602 |
|
603 |
+
random_generator = new_random_generator(random_seed)
|
604 |
+
random_generator.shuffle(choices)
|
605 |
+
input_fields[self.choices_field] = choices
|
606 |
|
607 |
+
reference_fields[self.choices_field] = choices
|
608 |
+
reference_fields[self.target_field] = choices.index(original_label_choice)
|
|
|
609 |
|
610 |
+
return input_fields, reference_fields
|
|
|
611 |
|
612 |
+
def post_process_instance(self, instance):
|
613 |
+
instance["input_fields"]["options"] = self.inputs_to_choices(
|
614 |
+
instance["input_fields"], self.target_choice_format
|
615 |
)
|
|
|
616 |
return instance
|
617 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
|
619 |
+
class YesNoTemplate(InputFormatTemplate):
|
620 |
"""A template for generating binary Yes/No questions asking whether an input text is of a specific class.
|
621 |
|
622 |
input_format:
|
|
|
642 |
yes_answer: str = "Yes"
|
643 |
no_answer: str = "No"
|
644 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
645 |
def reference_fields_to_target_and_references(
|
646 |
self, reference_fields: Dict[str, object]
|
647 |
) -> str:
|
|
|
685 |
def process_dict(
|
686 |
self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
|
687 |
) -> str:
|
|
|
688 |
pairs = []
|
689 |
for key, val in data.items():
|
690 |
key_val = [key, str(val)] if use_keys else [str(val)]
|
691 |
pairs.append(key_val_sep.join(key_val))
|
692 |
return pairs_sep.join(pairs)
|
693 |
|
694 |
+
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
|
|
|
|
695 |
return self.process_dict(
|
696 |
input_fields,
|
697 |
key_val_sep=self.key_val_separator,
|
|
|
712 |
|
713 |
|
714 |
class OutputQuantizingTemplate(InputOutputTemplate):
|
715 |
+
serializer: MultiTypeSerializer = NonPositionalField(
|
716 |
+
default_factory=MultiTypeSerializer
|
717 |
+
)
|
718 |
+
quantum: Union[float, int] = 0.1
|
719 |
|
720 |
+
def prepare(self):
|
721 |
+
super().prepare()
|
722 |
+
self.serializer.add_serializers(
|
723 |
+
[NumberQuantizingSerializer(quantum=self.quantum)]
|
724 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
725 |
|
726 |
|
727 |
class MultiLabelTemplate(InputOutputTemplate):
|
|
|
731 |
output_format: str = "{labels}"
|
732 |
empty_label: str = "None"
|
733 |
|
734 |
+
def preprocess_reference_fields(
|
735 |
+
self, reference_fields: Dict[str, Any]
|
736 |
+
) -> Dict[str, Any]:
|
737 |
labels = reference_fields[self.labels_field]
|
738 |
if not isinstance(labels, list):
|
739 |
raise UnitxtError(
|
|
|
743 |
if len(labels) == 0:
|
744 |
labels = [self.empty_label]
|
745 |
labels_str = self.labels_separator.join(labels)
|
746 |
+
return {self.labels_field: labels_str}
|
|
|
|
|
747 |
|
748 |
|
749 |
class MultiReferenceTemplate(InputOutputTemplate):
|
750 |
references_field: str = "references"
|
751 |
random_reference: bool = False
|
752 |
+
serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)
|
753 |
+
|
754 |
+
def serialize(
|
755 |
+
self, data: Dict[str, Any], instance: Dict[str, Any]
|
756 |
+
) -> Dict[str, str]:
|
757 |
+
result = {}
|
758 |
+
for k, v in data.items():
|
759 |
+
if k == self.references_field:
|
760 |
+
v = [self.serializer.serialize(item, instance) for item in v]
|
761 |
+
else:
|
762 |
+
v = self.serializer.serialize(v, instance)
|
763 |
+
result[k] = v
|
764 |
+
return result
|
765 |
|
766 |
def reference_fields_to_target_and_references(
|
767 |
self, reference_fields: Dict[str, object]
|
768 |
+
) -> Tuple[str, List[str]]:
|
769 |
references = reference_fields[self.references_field]
|
770 |
if not isoftype(references, List[str]):
|
771 |
raise UnitxtError(
|
|
|
814 |
if self.labels_support is None or span[3] in self.labels_support:
|
815 |
yield span[2], span[3]
|
816 |
|
817 |
+
def preprocess_reference_fields(
|
818 |
+
self, reference_fields: Dict[str, Any]
|
819 |
+
) -> Dict[str, Any]:
|
820 |
span_labels_pairs = self.extract_span_label_pairs(reference_fields)
|
821 |
targets = self.span_label_pairs_to_targets(span_labels_pairs)
|
822 |
+
return super().preprocess_reference_fields({"labels": targets})
|
823 |
|
824 |
@abstractmethod
|
825 |
def span_label_pairs_to_targets(self, pairs):
|
type_utils.py
CHANGED
@@ -4,48 +4,75 @@ import io
|
|
4 |
import itertools
|
5 |
import re
|
6 |
import typing
|
|
|
7 |
|
8 |
from .utils import safe_eval
|
9 |
|
10 |
-
|
11 |
-
"Any",
|
12 |
-
"List
|
13 |
-
"Dict
|
14 |
-
"Tuple
|
15 |
-
"Union
|
16 |
-
"Optional
|
17 |
-
"
|
18 |
-
"
|
19 |
-
"
|
20 |
-
"
|
21 |
-
"
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
Type = typing.Any
|
25 |
|
26 |
|
27 |
class UnsupportedTypeError(ValueError):
|
28 |
def __init__(self, type_object):
|
29 |
-
supported_types = ", ".join(
|
30 |
super().__init__(
|
31 |
f"Type: '{type_object!s}' is not supported type. Use one of {supported_types}"
|
32 |
)
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
35 |
_generics = [
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
42 |
]
|
43 |
|
44 |
_generics_types = [type(t) for t in _generics]
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def is_type(object):
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
def is_type_dict(object):
|
@@ -215,34 +242,31 @@ def parse_type_string(type_string: str) -> typing.Any:
|
|
215 |
and basic Python data types. It also defines a list of safe tokens that are allowed
|
216 |
in the type string.
|
217 |
"""
|
218 |
-
safe_context = {
|
219 |
-
"Any": typing.Any,
|
220 |
-
"List": typing.List,
|
221 |
-
"Dict": typing.Dict,
|
222 |
-
"Tuple": typing.Tuple,
|
223 |
-
"Union": typing.Union,
|
224 |
-
"int": int,
|
225 |
-
"str": str,
|
226 |
-
"float": float,
|
227 |
-
"bool": bool,
|
228 |
-
"Optional": typing.Optional,
|
229 |
-
}
|
230 |
-
|
231 |
type_string = format_type_string(type_string)
|
232 |
|
233 |
-
|
234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
|
237 |
def to_type_string(typing_type):
|
238 |
-
|
239 |
-
raise UnsupportedTypeError(typing_type)
|
240 |
-
type_string = (
|
241 |
-
str(typing_type)
|
242 |
-
.replace("typing.", "")
|
243 |
-
.replace("<class '", "")
|
244 |
-
.replace("'>", "")
|
245 |
-
)
|
246 |
assert parse_type_string(type_string), "Is not parsed well"
|
247 |
return type_string
|
248 |
|
@@ -447,9 +471,9 @@ def infer_type_string(obj: typing.Any) -> str:
|
|
447 |
def isoftype(object, typing_type):
|
448 |
"""Checks if an object is of a certain typing type, including nested types.
|
449 |
|
450 |
-
This function supports simple types (
|
451 |
-
(
|
452 |
-
|
453 |
|
454 |
Args:
|
455 |
object: The object to check.
|
@@ -457,19 +481,21 @@ def isoftype(object, typing_type):
|
|
457 |
|
458 |
Returns:
|
459 |
bool: True if the object is of the specified type, False otherwise.
|
460 |
-
|
461 |
-
Examples:
|
462 |
-
.. highlight:: python
|
463 |
-
.. code-block:: python
|
464 |
-
|
465 |
-
isoftype(1, int) # True
|
466 |
-
isoftype([1, 2, 3], typing.List[int]) # True
|
467 |
-
isoftype([1, 2, 3], typing.List[str]) # False
|
468 |
-
isoftype([[1, 2], [3, 4]], typing.List[typing.List[int]]) # True
|
469 |
"""
|
470 |
if not is_type(typing_type):
|
471 |
raise UnsupportedTypeError(typing_type)
|
472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
if typing_type == typing.Any:
|
474 |
return True
|
475 |
|
@@ -477,15 +503,16 @@ def isoftype(object, typing_type):
|
|
477 |
origin = typing_type.__origin__
|
478 |
type_args = typing.get_args(typing_type)
|
479 |
|
|
|
|
|
|
|
480 |
if origin is typing.Union:
|
481 |
return any(isoftype(object, sub_type) for sub_type in type_args)
|
482 |
|
483 |
if not isinstance(object, origin):
|
484 |
return False
|
485 |
-
|
486 |
if origin is list or origin is set:
|
487 |
return all(isoftype(element, type_args[0]) for element in object)
|
488 |
-
|
489 |
if origin is dict:
|
490 |
return all(
|
491 |
isoftype(key, type_args[0]) and isoftype(value, type_args[1])
|
@@ -496,11 +523,77 @@ def isoftype(object, typing_type):
|
|
496 |
isoftype(element, type_arg)
|
497 |
for element, type_arg in zip(object, type_args)
|
498 |
)
|
499 |
-
return None
|
500 |
|
501 |
return isinstance(object, typing_type)
|
502 |
|
503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
# copied from: https://github.com/bojiang/typing_utils/blob/main/typing_utils/__init__.py
|
505 |
# liscened under Apache License 2.0
|
506 |
|
|
|
4 |
import itertools
|
5 |
import re
|
6 |
import typing
|
7 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
8 |
|
9 |
from .utils import safe_eval
|
10 |
|
11 |
+
_registered_types = {
|
12 |
+
"Any": typing.Any,
|
13 |
+
"List": typing.List,
|
14 |
+
"Dict": typing.Dict,
|
15 |
+
"Tuple": typing.Tuple,
|
16 |
+
"Union": typing.Union,
|
17 |
+
"Optional": typing.Optional,
|
18 |
+
"Literal": typing.Literal,
|
19 |
+
"int": int,
|
20 |
+
"str": str,
|
21 |
+
"float": float,
|
22 |
+
"bool": bool,
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def register_type(new_type):
|
27 |
+
assert is_new_type(new_type) or is_typed_dict(
|
28 |
+
new_type
|
29 |
+
), "Can register only typing.NewType or typing.TypedDict"
|
30 |
+
_registered_types[new_type.__name__] = new_type
|
31 |
+
|
32 |
|
33 |
Type = typing.Any
|
34 |
|
35 |
|
36 |
class UnsupportedTypeError(ValueError):
|
37 |
def __init__(self, type_object):
|
38 |
+
supported_types = ", ".join(_registered_types.keys())
|
39 |
super().__init__(
|
40 |
f"Type: '{type_object!s}' is not supported type. Use one of {supported_types}"
|
41 |
)
|
42 |
|
43 |
|
44 |
+
class GenericTypedDict(TypedDict):
|
45 |
+
pass
|
46 |
+
|
47 |
+
|
48 |
_generics = [
|
49 |
+
List[Any],
|
50 |
+
Dict[Any, Any],
|
51 |
+
Tuple[Any],
|
52 |
+
Union[Any, Any],
|
53 |
+
Optional[Any],
|
54 |
+
Any,
|
55 |
+
Literal,
|
56 |
]
|
57 |
|
58 |
_generics_types = [type(t) for t in _generics]
|
59 |
|
60 |
|
61 |
+
def is_new_type(object):
|
62 |
+
return callable(object) and hasattr(object, "__supertype__")
|
63 |
+
|
64 |
+
|
65 |
+
def is_typed_dict(object):
|
66 |
+
return isinstance(object, type(GenericTypedDict))
|
67 |
+
|
68 |
+
|
69 |
def is_type(object):
|
70 |
+
"""Checks if the provided object is a type, including generics, Literal, TypedDict, and NewType."""
|
71 |
+
return (
|
72 |
+
isinstance(object, (type, *_generics_types))
|
73 |
+
or is_new_type(object)
|
74 |
+
or is_typed_dict(object)
|
75 |
+
)
|
76 |
|
77 |
|
78 |
def is_type_dict(object):
|
|
|
242 |
and basic Python data types. It also defines a list of safe tokens that are allowed
|
243 |
in the type string.
|
244 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
type_string = format_type_string(type_string)
|
246 |
|
247 |
+
return safe_eval(
|
248 |
+
type_string, context=_registered_types, allowed_tokens=["[", "]", ",", " "]
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
def replace_class_names(full_string: str) -> str:
|
253 |
+
# Regular expression to match any fully qualified class name and extract the class name
|
254 |
+
pattern = r"(?:\w+\.)*<locals>\.(\w+)|(?:\w+\.)*(\w+)"
|
255 |
+
|
256 |
+
# Function to replace the matched pattern with just the class name
|
257 |
+
def replacement(match):
|
258 |
+
# If the match has a group for <locals>
|
259 |
+
if match.group(1):
|
260 |
+
return match.group(1)
|
261 |
+
# Otherwise, return the last group (class name)
|
262 |
+
return match.group(2)
|
263 |
+
|
264 |
+
# Use re.sub to replace all occurrences in the string
|
265 |
+
return re.sub(pattern, replacement, full_string)
|
266 |
|
267 |
|
268 |
def to_type_string(typing_type):
|
269 |
+
type_string = strtype(typing_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
assert parse_type_string(type_string), "Is not parsed well"
|
271 |
return type_string
|
272 |
|
|
|
471 |
def isoftype(object, typing_type):
|
472 |
"""Checks if an object is of a certain typing type, including nested types.
|
473 |
|
474 |
+
This function supports simple types, typing types (List[int], Tuple[str, int]),
|
475 |
+
nested typing types (List[List[int]], Tuple[List[str], int]), Literal, TypedDict,
|
476 |
+
and NewType.
|
477 |
|
478 |
Args:
|
479 |
object: The object to check.
|
|
|
481 |
|
482 |
Returns:
|
483 |
bool: True if the object is of the specified type, False otherwise.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
484 |
"""
|
485 |
if not is_type(typing_type):
|
486 |
raise UnsupportedTypeError(typing_type)
|
487 |
|
488 |
+
if is_new_type(typing_type):
|
489 |
+
typing_type = typing_type.__supertype__
|
490 |
+
|
491 |
+
if is_typed_dict(typing_type):
|
492 |
+
if not isinstance(object, dict):
|
493 |
+
return False
|
494 |
+
for key, expected_type in typing_type.__annotations__.items():
|
495 |
+
if key not in object or not isoftype(object[key], expected_type):
|
496 |
+
return False
|
497 |
+
return True
|
498 |
+
|
499 |
if typing_type == typing.Any:
|
500 |
return True
|
501 |
|
|
|
503 |
origin = typing_type.__origin__
|
504 |
type_args = typing.get_args(typing_type)
|
505 |
|
506 |
+
if origin is Literal:
|
507 |
+
return object in type_args
|
508 |
+
|
509 |
if origin is typing.Union:
|
510 |
return any(isoftype(object, sub_type) for sub_type in type_args)
|
511 |
|
512 |
if not isinstance(object, origin):
|
513 |
return False
|
|
|
514 |
if origin is list or origin is set:
|
515 |
return all(isoftype(element, type_args[0]) for element in object)
|
|
|
516 |
if origin is dict:
|
517 |
return all(
|
518 |
isoftype(key, type_args[0]) and isoftype(value, type_args[1])
|
|
|
523 |
isoftype(element, type_arg)
|
524 |
for element, type_arg in zip(object, type_args)
|
525 |
)
|
|
|
526 |
|
527 |
return isinstance(object, typing_type)
|
528 |
|
529 |
|
530 |
+
def strtype(typing_type) -> str:
|
531 |
+
"""Converts a typing type to its string representation.
|
532 |
+
|
533 |
+
Args:
|
534 |
+
typing_type (Any): The typing type to be converted. This can include standard types,
|
535 |
+
custom types, or types from the `typing` module, such as `Literal`, `Union`,
|
536 |
+
`List`, `Dict`, `Tuple`, `TypedDict`, and `NewType`.
|
537 |
+
|
538 |
+
Returns:
|
539 |
+
str: The string representation of the provided typing type.
|
540 |
+
|
541 |
+
Raises:
|
542 |
+
UnsupportedTypeError: If the provided `typing_type` is not a recognized type.
|
543 |
+
|
544 |
+
Notes:
|
545 |
+
- If `typing_type` is `Literal`, `NewType`, or `TypedDict`, the function returns
|
546 |
+
the name of the type.
|
547 |
+
- If `typing_type` is `Any`, it returns the string `"Any"`.
|
548 |
+
- For other typing constructs like `Union`, `List`, `Dict`, and `Tuple`, the function
|
549 |
+
recursively converts each part of the type to its string representation.
|
550 |
+
- The function checks the `__origin__` attribute to determine the base type and formats
|
551 |
+
the type arguments accordingly.
|
552 |
+
"""
|
553 |
+
if not is_type(typing_type):
|
554 |
+
raise UnsupportedTypeError(typing_type)
|
555 |
+
|
556 |
+
if is_new_type(typing_type) or is_typed_dict(typing_type):
|
557 |
+
return typing_type.__name__
|
558 |
+
|
559 |
+
if typing_type == typing.Any:
|
560 |
+
return "Any"
|
561 |
+
|
562 |
+
if hasattr(typing_type, "__origin__"):
|
563 |
+
origin = typing_type.__origin__
|
564 |
+
type_args = typing.get_args(typing_type)
|
565 |
+
|
566 |
+
if type_args[-1] is type(None):
|
567 |
+
return (
|
568 |
+
"Optional["
|
569 |
+
+ ", ".join([strtype(sub_type) for sub_type in type_args[:-1]])
|
570 |
+
+ "]"
|
571 |
+
)
|
572 |
+
|
573 |
+
if origin is Literal:
|
574 |
+
return str(typing_type).replace("typing.", "")
|
575 |
+
if origin is typing.Union:
|
576 |
+
return (
|
577 |
+
"Union["
|
578 |
+
+ ", ".join([strtype(sub_type) for sub_type in type_args])
|
579 |
+
+ "]"
|
580 |
+
)
|
581 |
+
if origin is list or origin is set:
|
582 |
+
return "List[" + strtype(type_args[0]) + "]"
|
583 |
+
if origin is set:
|
584 |
+
return "Set[" + strtype(type_args[0]) + "]"
|
585 |
+
if origin is dict:
|
586 |
+
return "Dict[" + strtype(type_args[0]) + ", " + strtype(type_args[1]) + "]"
|
587 |
+
if origin is tuple:
|
588 |
+
return (
|
589 |
+
"Tuple["
|
590 |
+
+ ", ".join([strtype(sub_type) for sub_type in type_args])
|
591 |
+
+ "]"
|
592 |
+
)
|
593 |
+
|
594 |
+
return typing_type.__name__
|
595 |
+
|
596 |
+
|
597 |
# copied from: https://github.com/bojiang/typing_utils/blob/main/typing_utils/__init__.py
|
598 |
# liscened under Apache License 2.0
|
599 |
|
types.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Literal, NewType, TypedDict, Union
|
2 |
+
|
3 |
+
from .type_utils import register_type
|
4 |
+
|
5 |
+
Text = NewType("Text", str)
|
6 |
+
Number = NewType("Number", Union[float, int])
|
7 |
+
|
8 |
+
|
9 |
+
class Turn(TypedDict):
|
10 |
+
role: Literal["system", "user", "agent"]
|
11 |
+
content: Text
|
12 |
+
|
13 |
+
|
14 |
+
Dialog = NewType("Dialog", List[Turn])
|
15 |
+
|
16 |
+
|
17 |
+
class Image(TypedDict):
|
18 |
+
image: Any
|
19 |
+
|
20 |
+
|
21 |
+
class Audio(TypedDict):
|
22 |
+
audio: Any
|
23 |
+
|
24 |
+
|
25 |
+
class Table(TypedDict):
|
26 |
+
header: List[str]
|
27 |
+
rows: List[List[Any]]
|
28 |
+
|
29 |
+
|
30 |
+
register_type(Text)
|
31 |
+
register_type(Number)
|
32 |
+
register_type(Turn)
|
33 |
+
register_type(Dialog)
|
34 |
+
register_type(Table)
|
35 |
+
register_type(Audio)
|
36 |
+
register_type(Image)
|
utils.py
CHANGED
@@ -2,6 +2,7 @@ import copy
|
|
2 |
import importlib.util
|
3 |
import json
|
4 |
import os
|
|
|
5 |
from functools import lru_cache
|
6 |
from typing import Any, Dict
|
7 |
|
@@ -87,6 +88,23 @@ def is_module_available(module_name):
|
|
87 |
return False
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
|
91 |
"""Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.
|
92 |
|
@@ -109,7 +127,9 @@ def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
|
|
109 |
by restricting the available tokens and not exposing built-in functions.
|
110 |
"""
|
111 |
allowed_sub_strings = list(context.keys()) + allowed_tokens
|
112 |
-
if is_made_of_sub_strings(
|
|
|
|
|
113 |
return eval(expression, {"__builtins__": {}}, context)
|
114 |
raise ValueError(
|
115 |
f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
|
|
|
2 |
import importlib.util
|
3 |
import json
|
4 |
import os
|
5 |
+
import re
|
6 |
from functools import lru_cache
|
7 |
from typing import Any, Dict
|
8 |
|
|
|
88 |
return False
|
89 |
|
90 |
|
91 |
+
def remove_numerics_and_quoted_texts(input_str):
|
92 |
+
# Remove floats first to avoid leaving stray periods
|
93 |
+
input_str = re.sub(r"\d+\.\d+", "", input_str)
|
94 |
+
|
95 |
+
# Remove integers
|
96 |
+
input_str = re.sub(r"\d+", "", input_str)
|
97 |
+
|
98 |
+
# Remove strings in single quotes
|
99 |
+
input_str = re.sub(r"'.*?'", "", input_str)
|
100 |
+
|
101 |
+
# Remove strings in double quotes
|
102 |
+
input_str = re.sub(r'".*?"', "", input_str)
|
103 |
+
|
104 |
+
# Remove strings in triple quotes
|
105 |
+
return re.sub(r'""".*?"""', "", input_str, flags=re.DOTALL)
|
106 |
+
|
107 |
+
|
108 |
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
|
109 |
"""Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.
|
110 |
|
|
|
127 |
by restricting the available tokens and not exposing built-in functions.
|
128 |
"""
|
129 |
allowed_sub_strings = list(context.keys()) + allowed_tokens
|
130 |
+
if is_made_of_sub_strings(
|
131 |
+
remove_numerics_and_quoted_texts(expression), allowed_sub_strings
|
132 |
+
):
|
133 |
return eval(expression, {"__builtins__": {}}, context)
|
134 |
raise ValueError(
|
135 |
f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.
|
|
|
1 |
+
version = "1.13.0"
|