Elron committed on
Commit
7cdc7d0
·
verified ·
1 Parent(s): d08fbc6

Upload folder using huggingface_hub

Browse files
Files changed (23) hide show
  1. README.md +5 -5
  2. augmentors.py +195 -0
  3. dataset.py +3 -0
  4. deprecation_utils.py +37 -0
  5. dialog_operators.py +151 -4
  6. formats.py +9 -7
  7. image_operators.py +58 -9
  8. inference.py +195 -18
  9. llm_as_judge.py +12 -4
  10. loaders.py +10 -2
  11. metric.py +3 -0
  12. metrics.py +15 -7
  13. operators.py +1 -225
  14. schema.py +26 -13
  15. serializers.py +142 -0
  16. settings_utils.py +1 -0
  17. standard.py +21 -6
  18. struct_data_operators.py +50 -52
  19. templates.py +140 -151
  20. type_utils.py +152 -59
  21. types.py +36 -0
  22. utils.py +21 -1
  23. version.py +1 -1
README.md CHANGED
@@ -40,11 +40,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f
40
 
41
  ### 🦄 Currently on Unitxt Catalog
42
 
43
- ![NLP Tasks](https://img.shields.io/badge/NLP_tasks-40-blue)
44
- ![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-457-blue)
45
- ![Templates](https://img.shields.io/badge/Templates-229-blue)
46
- ![Formats](https://img.shields.io/badge/Formats-18-blue)
47
- ![Metrics](https://img.shields.io/badge/Metrics-98-blue)
48
 
49
  ### 🦄 Run Unitxt Exploration Dashboard
50
 
 
40
 
41
  ### 🦄 Currently on Unitxt Catalog
42
 
43
+ ![NLP Tasks](https://img.shields.io/badge/NLP_tasks-48-blue)
44
+ ![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-537-blue)
45
+ ![Templates](https://img.shields.io/badge/Templates-265-blue)
46
+ ![Formats](https://img.shields.io/badge/Formats-23-blue)
47
+ ![Metrics](https://img.shields.io/badge/Metrics-136-blue)
48
 
49
  ### 🦄 Run Unitxt Exploration Dashboard
50
 
augmentors.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import Random
2
+ from typing import (
3
+ Any,
4
+ Dict,
5
+ List,
6
+ Optional,
7
+ Union,
8
+ )
9
+
10
+ from .operators import FieldOperator
11
+ from .random_utils import new_random_generator
12
+ from .type_utils import isoftype
13
+
14
+
15
+ class Augmentor(FieldOperator):
16
+ """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template."""
17
+
18
+ operator: FieldOperator
19
+
20
+ def process_value(self, value: Any) -> Any:
21
+ return self.operator.process_value(value)
22
+
23
+
24
+ class TaskInputsAugmentor(Augmentor):
25
+ def set_fields(self, fields: List[str]):
26
+ fields = ["input_fields/" + field for field in fields]
27
+ self.field_to_field = {field: field for field in fields}
28
+
29
+
30
+ class FinalStateInputsAugmentor(Augmentor):
31
+ pass
32
+
33
+
34
+ class ModelInputAugmentor(FinalStateInputsAugmentor):
35
+ field = "source"
36
+
37
+
38
+ class ImagesAugmentor(FinalStateInputsAugmentor):
39
+ field = "media/images"
40
+ process_every_value = True
41
+
42
+
43
+ class Identity(FieldOperator):
44
+ def process_value(self, value: Any) -> Any:
45
+ return value
46
+
47
+
48
+ class NullAugmentor(Augmentor):
49
+ """Does not change the input string."""
50
+
51
+ operator = Identity()
52
+
53
+
54
+ class AugmentWhitespace(FieldOperator):
55
+ """Augments the inputs by replacing existing whitespaces with other whitespaces.
56
+
57
+ Currently, each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
58
+ """
59
+
60
+ def process_value(self, value: str) -> str:
61
+ import re
62
+
63
+ words = re.split(r"(\s+)", value)
64
+ new_value = ""
65
+
66
+ random_generator = new_random_generator(sub_seed=value)
67
+ for word in words:
68
+ if word.isspace():
69
+ new_value += random_generator.choice(
70
+ ["\n", "\t", " "]
71
+ ) * random_generator.randint(1, 3)
72
+ else:
73
+ new_value += word
74
+ return new_value
75
+
76
+
77
+ class AugmentPrefixSuffix(FieldOperator):
78
+ r"""Augments the input by prepending and appending randomly selected (typically, whitespace) patterns.
79
+
80
+ Args:
81
+ prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
82
+ The dictionary version allows the specification of relative weights for the different patterns.
83
+ prefix_len, suffix_len (positive int) : The number of patterns randomly drawn and concatenated to form the added prefix or suffix.
84
+ remove_existing_whitespaces : Clean any existing leading and trailing whitespaces.
85
+ The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
86
+ trimmed input.
87
+ If only prefixes or only suffixes are needed, set the other to None.
88
+
89
+ Examples:
90
+ To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
91
+ AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
92
+ To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
93
+ being preferred over triple '\t', at 2:1 ratio, employ
94
+ AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
95
+ which will append '\n'-s twice as often as '\t'-s.
96
+
97
+ """
98
+
99
+ prefixes: Optional[Union[List[str], Dict[str, int]]] = {
100
+ " ": 20,
101
+ "\\t": 10,
102
+ "\\n": 40,
103
+ "": 30,
104
+ }
105
+ prefix_len: Optional[int] = 3
106
+ suffixes: Optional[Union[List[str], Dict[str, int]]] = {
107
+ " ": 20,
108
+ "\\t": 10,
109
+ "\\n": 40,
110
+ "": 30,
111
+ }
112
+ suffix_len: Optional[int] = 3
113
+ remove_existing_whitespaces: Optional[bool] = False
114
+
115
+ def verify(self):
116
+ assert (
117
+ self.prefixes or self.suffixes
118
+ ), "At least one of prefixes/suffixes should be not None."
119
+ for arg, arg_name in zip(
120
+ [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
121
+ ):
122
+ assert (
123
+ arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
124
+ ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
125
+ assert (
126
+ self.prefix_len > 0
127
+ ), f"prefix_len must be positive, got {self.prefix_len}"
128
+ assert (
129
+ self.suffix_len > 0
130
+ ), f"suffix_len must be positive, got {self.suffix_len}"
131
+ super().verify()
132
+
133
+ def _calculate_distributions(self, prefs_or_suffs):
134
+ if prefs_or_suffs is None:
135
+ return None, None
136
+ patterns = (
137
+ prefs_or_suffs
138
+ if isinstance(prefs_or_suffs, list)
139
+ else [k for k, v in prefs_or_suffs.items()]
140
+ )
141
+ total_weight = (
142
+ len(patterns)
143
+ if isinstance(prefs_or_suffs, list)
144
+ else sum([v for k, v in prefs_or_suffs.items()])
145
+ )
146
+ weights = (
147
+ [1.0 / total_weight] * len(patterns)
148
+ if isinstance(prefs_or_suffs, list)
149
+ else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
150
+ )
151
+ return patterns, weights
152
+
153
+ def prepare(self):
154
+ # Being an artifact, prepare is invoked before verify. Here we need verify before the actions
155
+ self.verify()
156
+ self._prefix_pattern_distribution = {"length": self.prefix_len}
157
+ self._suffix_pattern_distribution = {"length": self.suffix_len}
158
+
159
+ (
160
+ self._prefix_pattern_distribution["patterns"],
161
+ self._prefix_pattern_distribution["weights"],
162
+ ) = self._calculate_distributions(self.prefixes)
163
+ (
164
+ self._suffix_pattern_distribution["patterns"],
165
+ self._suffix_pattern_distribution["weights"],
166
+ ) = self._calculate_distributions(self.suffixes)
167
+ super().prepare()
168
+
169
+ def _get_random_pattern(
170
+ self, pattern_distribution, random_generator: Random
171
+ ) -> str:
172
+ string_to_add = ""
173
+ if pattern_distribution["patterns"]:
174
+ string_to_add = "".join(
175
+ random_generator.choices(
176
+ pattern_distribution["patterns"],
177
+ pattern_distribution["weights"],
178
+ k=pattern_distribution["length"],
179
+ )
180
+ )
181
+ return string_to_add
182
+
183
+ def process_value(self, value: Any) -> Any:
184
+ assert value is not None, "input value should not be None"
185
+ new_value = str(value)
186
+ if self.remove_existing_whitespaces:
187
+ new_value = new_value.strip()
188
+ random_generator = new_random_generator(sub_seed=value)
189
+ prefix = self._get_random_pattern(
190
+ self._prefix_pattern_distribution, random_generator
191
+ )
192
+ suffix = self._get_random_pattern(
193
+ self._suffix_pattern_distribution, random_generator
194
+ )
195
+ return prefix + new_value + suffix
dataset.py CHANGED
@@ -4,6 +4,7 @@ import datasets
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
 
7
  from .benchmark import __file__ as _
8
  from .blocks import __file__ as _
9
  from .card import __file__ as _
@@ -43,6 +44,7 @@ from .random_utils import __file__ as _
43
  from .recipe import __file__ as _
44
  from .register import __file__ as _
45
  from .schema import __file__ as _
 
46
  from .settings_utils import __file__ as _
47
  from .settings_utils import get_constants
48
  from .span_lableing_operators import __file__ as _
@@ -58,6 +60,7 @@ from .task import __file__ as _
58
  from .templates import __file__ as _
59
  from .text_utils import __file__ as _
60
  from .type_utils import __file__ as _
 
61
  from .utils import __file__ as _
62
  from .utils import is_package_installed
63
  from .validate import __file__ as _
 
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
7
+ from .augmentors import __file__ as _
8
  from .benchmark import __file__ as _
9
  from .blocks import __file__ as _
10
  from .card import __file__ as _
 
44
  from .recipe import __file__ as _
45
  from .register import __file__ as _
46
  from .schema import __file__ as _
47
+ from .serializers import __file__ as _
48
  from .settings_utils import __file__ as _
49
  from .settings_utils import get_constants
50
  from .span_lableing_operators import __file__ as _
 
60
  from .templates import __file__ as _
61
  from .text_utils import __file__ as _
62
  from .type_utils import __file__ as _
63
+ from .types import __file__ as _
64
  from .utils import __file__ as _
65
  from .utils import is_package_installed
66
  from .validate import __file__ as _
deprecation_utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import functools
2
  import warnings
3
 
 
4
  from .settings_utils import get_constants, get_settings
5
 
6
  constants = get_constants()
@@ -98,3 +99,39 @@ def deprecation(version, alternative=None, msg=None):
98
  return depraction_wrapper(func, version, alt_text)
99
 
100
  return decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import functools
2
  import warnings
3
 
4
+ from .error_utils import UnitxtWarning
5
  from .settings_utils import get_constants, get_settings
6
 
7
  constants = get_constants()
 
99
  return depraction_wrapper(func, version, alt_text)
100
 
101
  return decorator
102
+
103
+
104
+ def init_warning(msg=""):
105
+ # Decorator that creates a UnitxtWarning(msg) when it is applied to a class (i.e. at class definition time, not on every instantiation)
106
+ def decorator(initiated_class):
107
+ UnitxtWarning(msg)
108
+ return initiated_class
109
+
110
+ return decorator
111
+
112
+
113
+ def warn_on_call(warning_type=UserWarning, msg=""):
114
+ def decorator(obj):
115
+ if isinstance(obj, type):
116
+ original_init = obj.__init__
117
+
118
+ @functools.wraps(original_init)
119
+ def new_init(self, *args, **kwargs):
120
+ warnings.warn(msg, warning_type, stacklevel=2)
121
+ original_init(self, *args, **kwargs)
122
+
123
+ obj.__init__ = new_init
124
+ return obj
125
+
126
+ if callable(obj):
127
+
128
+ @functools.wraps(obj)
129
+ def wrapper(*args, **kwargs):
130
+ warnings.warn(msg, warning_type, stacklevel=2)
131
+ return obj(*args, **kwargs)
132
+
133
+ return wrapper
134
+
135
+ raise TypeError("This decorator can only be applied to classes or functions.")
136
+
137
+ return decorator
dialog_operators.py CHANGED
@@ -11,7 +11,6 @@ dialog = [
11
  {"user": "kkk", "system": ""},
12
  ]
13
  """
14
-
15
  from typing import Any, Dict, List, Optional
16
 
17
  from .formats import SystemFormat
@@ -34,10 +33,11 @@ class SerializeDialog(InstanceFieldOperator):
34
  context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
35
  """
36
 
37
- format: Optional[SystemFormat] = None
38
  last_response_to_field: Optional[str] = None
39
  context_field: Optional[str] = None
40
  context_separator: str = " "
 
41
 
42
  def standardize_format(self, demo_format):
43
  turn_format = demo_format.replace("{source}", "{user}")
@@ -54,10 +54,11 @@ class SerializeDialog(InstanceFieldOperator):
54
  return turn_format[: turn_format.index("{user}") + len("{user}")]
55
 
56
  def get_turn_format(self, turn_format, step, length):
57
- if step == 0:
58
  turn_format = self.slice_first_turn(turn_format)
59
  if step == length - 1:
60
- turn_format = self.slice_last_turn(turn_format)
 
61
  if self.last_response_to_field is not None:
62
  turn_format = self.slice_last_response(turn_format)
63
  return turn_format
@@ -87,3 +88,149 @@ class SerializeDialog(InstanceFieldOperator):
87
  if self.last_response_to_field is not None:
88
  instance[self.last_response_to_field] = turn["system"]
89
  return dialog
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  {"user": "kkk", "system": ""},
12
  ]
13
  """
 
14
  from typing import Any, Dict, List, Optional
15
 
16
  from .formats import SystemFormat
 
33
  context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
34
  """
35
 
36
+ format: SystemFormat = None
37
  last_response_to_field: Optional[str] = None
38
  context_field: Optional[str] = None
39
  context_separator: str = " "
40
+ slice_first_and_last_turns_format: bool = True
41
 
42
  def standardize_format(self, demo_format):
43
  turn_format = demo_format.replace("{source}", "{user}")
 
54
  return turn_format[: turn_format.index("{user}") + len("{user}")]
55
 
56
  def get_turn_format(self, turn_format, step, length):
57
+ if step == 0 and self.slice_first_and_last_turns_format:
58
  turn_format = self.slice_first_turn(turn_format)
59
  if step == length - 1:
60
+ if self.slice_first_and_last_turns_format:
61
+ turn_format = self.slice_last_turn(turn_format)
62
  if self.last_response_to_field is not None:
63
  turn_format = self.slice_last_response(turn_format)
64
  return turn_format
 
88
  if self.last_response_to_field is not None:
89
  instance[self.last_response_to_field] = turn["system"]
90
  return dialog
91
+
92
+
93
+ class SerializeOpenAiFormatDialog(SerializeDialog):
94
+ """Serializes dialog data for feeding into a model.
95
+
96
+ This class takes structured dialog data in the OpenAi format, and converts it into a text format
97
+ according to a specified template. It allows for the inclusion or exclusion
98
+ of system responses and can operate on a per-turn basis or aggregate the entire
99
+ dialog.
100
+
101
+ Attributes:
102
+ field (str): The field in the input data that contains the dialog.
103
+ to_field (Optional[str]): The field in the output data where the serialized dialog will be stored.
104
+ last_user_turn_to_field (Optional[str]): Field to store the last user turn.
105
+ last_system_turn_to_field (Optional[str]): Field to store the last system turn.
106
+ context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
107
+ """
108
+
109
+ is_last_turn_user_only: bool = True
110
+
111
+ @staticmethod
112
+ def validate_openai_dialog_format(dialog: List[Dict[str, str]]) -> None:
113
+ """Validates that the given dialog follows the correct OpenAI format.
114
+
115
+ The function checks that:
116
+ 1. The dialog is a list of dictionaries.
117
+ 2. Each dictionary contains the keys 'role' and 'content'.
118
+ 3. The 'role' value is either 'user' or 'assistant'.
119
+ 4. Both 'role' and 'content' values are strings.
120
+ 5. The first 'role' is 'user'
121
+
122
+ If the dialog does not conform to the expected format, a descriptive
123
+ ValueError is raised indicating the issue.
124
+
125
+ Args:
126
+ dialog (List[Dict[str, str]]): The dialog to validate.
127
+
128
+ Raises:
129
+ ValueError: If the dialog does not meet the format requirements.
130
+ """
131
+ if not isinstance(dialog, list):
132
+ raise ValueError("Dialog must be a list of dictionaries.")
133
+
134
+ for i, entry in enumerate(dialog):
135
+ if not isinstance(entry, dict):
136
+ raise ValueError(
137
+ f"Entry {i} is not a dictionary: {entry}. Each entry in the dialog must be a dictionary."
138
+ )
139
+
140
+ if "role" not in entry:
141
+ raise ValueError(
142
+ f"Entry {i} is missing the 'role' key: {entry}. Each dictionary must have a 'role' key."
143
+ )
144
+
145
+ if "content" not in entry:
146
+ raise ValueError(
147
+ f"Entry {i} is missing the 'content' key: {entry}. Each dictionary must have a 'content' key."
148
+ )
149
+
150
+ if not isinstance(entry["role"], str):
151
+ raise ValueError(
152
+ f"Entry {i} has a non-string 'role': {entry['role']}. The 'role' value must be a string."
153
+ )
154
+
155
+ if not isinstance(entry["content"], str):
156
+ raise ValueError(
157
+ f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
158
+ )
159
+
160
+ if entry["role"] not in {"user", "assistant"}:
161
+ raise ValueError(
162
+ f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
163
+ )
164
+
165
+ first_entry = dialog[0]
166
+ if first_entry["role"] != "user":
167
+ raise ValueError(
168
+ f"First entry role is expected to be 'user' It is {first_entry['role']}."
169
+ )
170
+
171
+ @staticmethod
172
+ def merge_dialog_entries(dialog: List[Dict[str, str]]) -> List[Dict[str, str]]:
173
+ """Merges consecutive dialog entries with the same role.
174
+
175
+ Args:
176
+ dialog (List[Dict[str, str]]): The input dialog list where each dictionary has a 'role' and 'content'.
177
+
178
+ Returns:
179
+ List[Dict[str, str]]: A new list where consecutive entries with the same role are merged.
180
+ """
181
+ if len(dialog) == 0:
182
+ return []
183
+
184
+ merged_dialog = [dialog[0]]
185
+
186
+ for entry in dialog[1:]:
187
+ if entry["role"] == merged_dialog[-1]["role"]:
188
+ merged_dialog[-1]["content"] += " " + entry["content"]
189
+ else:
190
+ merged_dialog.append(entry)
191
+
192
+ return merged_dialog
193
+
194
+ def transform_dialog_to_standard_format(
195
+ self, dialog: List[Dict[str, str]]
196
+ ) -> List[Dict[str, str]]:
197
+ """Transforms a dialog from OpenAI format to a simplified format.
198
+
199
+ Each dictionary
200
+ contains 'user' and 'system' keys with their respective contents. Consecutive entries
201
+ with the same role are merged. Entries with invalid roles raise an error.
202
+
203
+ Args:
204
+ dialog (List[Dict[str, str]]): The input dialog in OpenAI format.
205
+
206
+ Returns:
207
+ List[Dict[str, str]]: The transformed dialog.
208
+
209
+ Raises:
210
+ ValueError: If an invalid role is detected.
211
+ """
212
+ SerializeOpenAiFormatDialog.validate_openai_dialog_format(dialog)
213
+ merged_dialog = SerializeOpenAiFormatDialog.merge_dialog_entries(dialog)
214
+ # self.validate_dialog_have_complete_pairs(merged_dialog)
215
+
216
+ result = []
217
+ for i in range(0, len(merged_dialog) - 1, 2):
218
+ user_entry = merged_dialog[i]
219
+ system_entry = merged_dialog[i + 1]
220
+
221
+ result.append(
222
+ {"user": user_entry["content"], "system": system_entry["content"]}
223
+ )
224
+ if len(merged_dialog) % 2 != 0:
225
+ user_entry = merged_dialog[-1]
226
+ result.append({"user": user_entry["content"], "system": ""})
227
+
228
+ return result
229
+
230
+ def process_instance_value(
231
+ self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
232
+ ):
233
+ standard_format_dialog = self.transform_dialog_to_standard_format(
234
+ structured_dialog
235
+ )
236
+ return super().process_instance_value(standard_format_dialog, instance)
formats.py CHANGED
@@ -164,7 +164,7 @@ class SystemFormat(BaseFormat):
164
  demos is not None and isoftype(demos, List[Dict[str, Any]])
165
  ), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
166
  demo_instances = demos
167
- instance.pop(self.demos_field)
168
 
169
  demos_string = ""
170
  for demo_instance in demo_instances:
@@ -226,14 +226,16 @@ class HFSystemFormat(BaseFormat):
226
  """
227
 
228
  model_name: str
 
229
 
230
- def process(
231
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
232
- ) -> Dict[str, Any]:
233
  from transformers import AutoTokenizer
234
 
235
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
236
 
 
 
 
237
  assert (
238
  "source" in instance
239
  ), f"field 'source' is expected to be in the input instance. Received instance: {instance}"
@@ -267,7 +269,7 @@ class HFSystemFormat(BaseFormat):
267
  demos is not None and isoftype(demos, List[Dict[str, Any]])
268
  ), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
269
  demo_instances = demos
270
- instance.pop(self.demos_field)
271
 
272
  for demo_instance in demo_instances:
273
  messages.extend(
@@ -280,7 +282,7 @@ class HFSystemFormat(BaseFormat):
280
  ]
281
  )
282
  messages.extend([{"role": "user", "content": source}])
283
- tokenized_chat = tokenizer.apply_chat_template(
284
  messages, tokenize=False, add_generation_prompt=True
285
  )
286
 
 
164
  demos is not None and isoftype(demos, List[Dict[str, Any]])
165
  ), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
166
  demo_instances = demos
167
+ # instance.pop(self.demos_field)
168
 
169
  demos_string = ""
170
  for demo_instance in demo_instances:
 
226
  """
227
 
228
  model_name: str
229
+ _requirements_list = ["transformers"]
230
 
231
+ def prepare(self):
 
 
232
  from transformers import AutoTokenizer
233
 
234
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
235
 
236
+ def process(
237
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
238
+ ) -> Dict[str, Any]:
239
  assert (
240
  "source" in instance
241
  ), f"field 'source' is expected to be in the input instance. Received instance: {instance}"
 
269
  demos is not None and isoftype(demos, List[Dict[str, Any]])
270
  ), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}"
271
  demo_instances = demos
272
+ # instance.pop(self.demos_field)
273
 
274
  for demo_instance in demo_instances:
275
  messages.extend(
 
282
  ]
283
  )
284
  messages.extend([{"role": "user", "content": source}])
285
+ tokenized_chat = self.tokenizer.apply_chat_template(
286
  messages, tokenize=False, add_generation_prompt=True
287
  )
288
 
image_operators.py CHANGED
@@ -1,8 +1,25 @@
 
 
1
  import re
 
2
  from typing import Any, Dict
3
 
 
 
4
  from .dict_utils import dict_get
5
- from .operators import InstanceFieldOperator
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def extract_images(text, instance):
@@ -15,12 +32,44 @@ def extract_images(text, instance):
15
  return images
16
 
17
 
18
- class ImageToText(InstanceFieldOperator):
 
 
 
 
 
 
 
 
 
19
  def process_instance_value(self, value: Any, instance: Dict[str, Any]):
20
- if "media" not in instance:
21
- instance["media"] = {}
22
- if "images" not in instance["media"]:
23
- instance["media"]["images"] = []
24
- idx = len(instance["media"]["images"])
25
- instance["media"]["images"].append(value)
26
- return f'<img src="media/images/{idx}">'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
  import re
4
+ from abc import abstractmethod
5
  from typing import Any, Dict
6
 
7
+ import numpy as np
8
+
9
  from .dict_utils import dict_get
10
+ from .operators import FieldOperator, InstanceFieldOperator, PackageRequirementsMixin
11
+
12
+
13
+ class PillowMixin(PackageRequirementsMixin):
14
+ _requirements_list = {"PIL": "pip install pillow"}
15
+
16
+ def prepare(self):
17
+ super().prepare()
18
+ import PIL
19
+ from PIL import Image
20
+
21
+ self.pil = PIL
22
+ self.image = Image
23
 
24
 
25
  def extract_images(text, instance):
 
32
  return images
33
 
34
 
35
+ class DecodeImage(FieldOperator, PillowMixin):
36
+ def decode_base64_to_image(self, base64_string):
37
+ image_data = base64.b64decode(base64_string)
38
+ return self.image.open(io.BytesIO(image_data))
39
+
40
+ def process_value(self, value: Any) -> Any:
41
+ return {"image": self.decode_base64_to_image(value)}
42
+
43
+
44
+ class ToImage(InstanceFieldOperator):
45
  def process_instance_value(self, value: Any, instance: Dict[str, Any]):
46
+ return {"image": value}
47
+
48
+
49
+ class ImageFieldOperator(FieldOperator, PillowMixin):
50
+ @abstractmethod
51
+ def process_image(self, image):
52
+ pass
53
+
54
+ def process_value(self, value: Any) -> Any:
55
+ if not isinstance(value, self.image.Image):
56
+ raise ValueError(f"ImageFieldOperator requires image, got {type(value)}.")
57
+ return self.process_image(value)
58
+
59
+
60
+ class GrayScale(ImageFieldOperator):
61
+ def process_image(self, image):
62
+ # Convert the image to grayscale
63
+ grayscale_image = image.convert("L")
64
+
65
+ # Convert the grayscale image to a NumPy array
66
+ grayscale_array = np.array(grayscale_image)
67
+
68
+ # Add a dummy channel dimension to make it (height, width, 1)
69
+ grayscale_array = np.expand_dims(grayscale_array, axis=-1)
70
+
71
+ # Repeat the channel to have (height, width, 3) if needed for compatibility
72
+ grayscale_array = np.repeat(grayscale_array, 3, axis=-1)
73
+
74
+ # Convert back to a PIL image with 3 channels
75
+ return self.image.fromarray(grayscale_array)
inference.py CHANGED
@@ -5,12 +5,15 @@ from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
  from tqdm import tqdm
7
 
8
- from .artifact import Artifact
9
  from .dataclass import InternalField, NonPositionalField
10
  from .deprecation_utils import deprecation
11
  from .image_operators import extract_images
12
  from .logging_utils import get_logger
13
  from .operator import PackageRequirementsMixin
 
 
 
14
 
15
 
16
  class InferenceEngine(abc.ABC, Artifact):
@@ -21,9 +24,20 @@ class InferenceEngine(abc.ABC, Artifact):
21
  """Perform inference on the input dataset."""
22
  pass
23
 
 
 
 
 
 
 
 
 
 
24
  def infer(self, dataset) -> str:
25
  """Verifies instances of a dataset and performs inference."""
26
  [self.verify_instance(instance) for instance in dataset]
 
 
27
  return self._infer(dataset)
28
 
29
  @deprecation(version="2.0.0")
@@ -122,7 +136,7 @@ class HFPipelineBasedInferenceEngine(
122
  model=self.model_name, trust_remote_code=True, **model_args
123
  )
124
 
125
- def prepare(self):
126
  if not self.lazy_load:
127
  self._prepare_pipeline()
128
 
@@ -144,13 +158,17 @@ class HFPipelineBasedInferenceEngine(
144
  class MockInferenceEngine(InferenceEngine):
145
  model_name: str
146
 
147
- def prepare(self):
148
  return
149
 
150
  def _infer(self, dataset):
151
  return ["[[10]]" for instance in dataset]
152
 
153
 
 
 
 
 
154
  class IbmGenAiInferenceEngineParamsMixin(Artifact):
155
  beam_width: Optional[int] = None
156
  decoding_method: Optional[Literal["greedy", "sample"]] = None
@@ -190,6 +208,57 @@ class IbmGenAiInferenceEngineParams(Artifact):
190
  typical_p: Optional[float] = None
191
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  class IbmGenAiInferenceEngine(
194
  InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
195
  ):
@@ -201,11 +270,12 @@ class IbmGenAiInferenceEngine(
201
  data_classification_policy = ["public", "proprietary"]
202
  parameters: Optional[IbmGenAiInferenceEngineParams] = None
203
 
204
- def prepare(self):
205
  from genai import Client, Credentials
206
 
207
  api_key_env_var_name = "GENAI_KEY"
208
  api_key = os.environ.get(api_key_env_var_name)
 
209
  assert api_key is not None, (
210
  f"Error while trying to run IbmGenAiInferenceEngine."
211
  f" Please set the environment param '{api_key_env_var_name}'."
@@ -242,9 +312,9 @@ class OpenAiInferenceEngineParamsMixin(Artifact):
242
  top_p: Optional[float] = None
243
  top_logprobs: Optional[int] = 20
244
  logit_bias: Optional[Dict[str, int]] = None
245
- logprobs: Optional[bool] = None
246
  n: Optional[int] = None
247
- parallel_tool_calls: bool = None
248
  service_tier: Optional[Literal["auto", "default"]] = None
249
 
250
 
@@ -259,9 +329,9 @@ class OpenAiInferenceEngineParams(Artifact):
259
  top_p: Optional[float] = None
260
  top_logprobs: Optional[int] = 20
261
  logit_bias: Optional[Dict[str, int]] = None
262
- logprobs: Optional[bool] = None
263
  n: Optional[int] = None
264
- parallel_tool_calls: bool = None
265
  service_tier: Optional[Literal["auto", "default"]] = None
266
 
267
 
@@ -279,7 +349,7 @@ class OpenAiInferenceEngine(
279
  data_classification_policy = ["public"]
280
  parameters: Optional[OpenAiInferenceEngineParams] = None
281
 
282
- def prepare(self):
283
  from openai import OpenAI
284
 
285
  api_key_env_var_name = "OPENAI_API_KEY"
@@ -293,6 +363,13 @@ class OpenAiInferenceEngine(
293
 
294
  self._set_inference_parameters()
295
 
 
 
 
 
 
 
 
296
  def _infer(self, dataset):
297
  outputs = []
298
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
@@ -308,7 +385,7 @@ class OpenAiInferenceEngine(
308
  }
309
  ],
310
  model=self.model_name,
311
- **self.to_dict([OpenAiInferenceEngineParamsMixin]),
312
  )
313
  output = response.choices[0].message.content
314
 
@@ -331,7 +408,7 @@ class OpenAiInferenceEngine(
331
  }
332
  ],
333
  model=self.model_name,
334
- **self.to_dict([OpenAiInferenceEngineParamsMixin]),
335
  )
336
  top_logprobs_response = response.choices[0].logprobs.content
337
  output = [
@@ -347,6 +424,96 @@ class OpenAiInferenceEngine(
347
  return outputs
348
 
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  class WMLInferenceEngineParamsMixin(Artifact):
351
  decoding_method: Optional[Literal["greedy", "sample"]] = None
352
  length_penalty: Optional[Dict[str, Union[int, float]]] = None
@@ -400,6 +567,7 @@ class WMLInferenceEngine(
400
  parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
401
  which defines inference parameters and their values. Deprecated attribute, please
402
  pass respective parameters directly to the WMLInferenceEngine class instead.
 
403
 
404
  Examples:
405
  from .api import load_dataset
@@ -433,7 +601,7 @@ class WMLInferenceEngine(
433
  }
434
  data_classification_policy = ["public", "proprietary"]
435
  parameters: Optional[WMLInferenceEngineParams] = None
436
-
437
  _client: Any = InternalField(default=None, name="WML client")
438
 
439
  def verify(self):
@@ -490,7 +658,7 @@ class WMLInferenceEngine(
490
  client.set.default_project(self.credentials["project_id"])
491
  return client
492
 
493
- def prepare(self):
494
  self._client = self._initialize_wml_client()
495
 
496
  self._set_inference_parameters()
@@ -504,10 +672,19 @@ class WMLInferenceEngine(
504
  api_client=self._client,
505
  )
506
 
507
- return model.generate_text(
508
- prompt=dataset["source"],
509
- params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
510
- )
 
 
 
 
 
 
 
 
 
511
 
512
 
513
  class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
@@ -541,7 +718,7 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
541
 
542
  self.processor = AutoProcessor.from_pretrained(self.model_name)
543
 
544
- def prepare(self):
545
  if not self.lazy_load:
546
  self._prepare_engine()
547
 
 
5
 
6
  from tqdm import tqdm
7
 
8
+ from .artifact import Artifact, fetch_artifact
9
  from .dataclass import InternalField, NonPositionalField
10
  from .deprecation_utils import deprecation
11
  from .image_operators import extract_images
12
  from .logging_utils import get_logger
13
  from .operator import PackageRequirementsMixin
14
+ from .settings_utils import get_settings
15
+
16
+ settings = get_settings()
17
 
18
 
19
  class InferenceEngine(abc.ABC, Artifact):
 
24
  """Perform inference on the input dataset."""
25
  pass
26
 
27
@abc.abstractmethod
def prepare_engine(self):
    """Set up the engine before inference; implemented by each concrete engine.

    Called by ``prepare`` only when mock inference mode is disabled.
    """
    pass
31
+
32
def prepare(self):
    """Run ``prepare_engine`` unless mock inference mode is enabled."""
    if settings.mock_inference_mode:
        # Mock mode never reaches the real engine, so its setup is skipped.
        return
    self.prepare_engine()
35
+
36
  def infer(self, dataset) -> str:
37
  """Verifies instances of a dataset and performs inference."""
38
  [self.verify_instance(instance) for instance in dataset]
39
+ if settings.mock_inference_mode:
40
+ return [instance["source"] for instance in dataset]
41
  return self._infer(dataset)
42
 
43
  @deprecation(version="2.0.0")
 
136
  model=self.model_name, trust_remote_code=True, **model_args
137
  )
138
 
139
+ def prepare_engine(self):
140
  if not self.lazy_load:
141
  self._prepare_pipeline()
142
 
 
158
  class MockInferenceEngine(InferenceEngine):
159
  model_name: str
160
 
161
+ def prepare_engine(self):
162
  return
163
 
164
  def _infer(self, dataset):
165
  return ["[[10]]" for instance in dataset]
166
 
167
 
168
class MockModeMixin(Artifact):
    """Mixin that adds a boolean ``mock_mode`` field, disabled by default."""

    mock_mode: bool = False
170
+
171
+
172
  class IbmGenAiInferenceEngineParamsMixin(Artifact):
173
  beam_width: Optional[int] = None
174
  decoding_method: Optional[Literal["greedy", "sample"]] = None
 
208
  typical_p: Optional[float] = None
209
 
210
 
211
class GenericInferenceEngine(InferenceEngine):
    """Inference engine that delegates to an engine resolved at preparation time.

    The engine reference is taken from the ``UNITXT_INFERENCE_ENGINE``
    environment variable when it is set; otherwise the ``default`` attribute is
    used. The reference is resolved with ``fetch_artifact`` and inference is
    forwarded to the resolved engine.

    Attributes:
        default: Engine reference used when the environment variable is not set.
    """

    default: Optional[str] = None

    def prepare_engine(self):
        """Resolve and store the underlying engine.

        Raises:
            AssertionError: If the environment variable is not set and no
                default engine reference was given.
        """
        # The environment variable takes precedence over the configured default.
        engine_reference = os.environ.get("UNITXT_INFERENCE_ENGINE")
        if engine_reference is None:
            assert self.default is not None, (
                "GenericInferenceEngine could not be initialized"
                '\nThis is because the "UNITXT_INFERENCE_ENGINE" environment variable is not set and no default engine was provided.'
                "\nFor example, you can fix it by setting"
                "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
                "\nto your ~/.bashrc"
                "\nor passing a similar required engine in the default argument"
            )
            engine_reference = self.default
        self.engine, _ = fetch_artifact(engine_reference)

    def _infer(self, dataset):
        """Forward the dataset to the resolved engine and return its predictions."""
        return self.engine._infer(dataset)
231
+
232
+
233
class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
    """Inference engine that queries a model through the ``ollama`` package.

    Each instance's ``source`` is sent as a single user message to the model
    named by ``model_name``.

    Attributes:
        label: Identifier of this engine type.
        model_name: Name of the model passed to ``ollama.chat``.
    """

    label: str = "ollama"
    model_name: str
    _requirements_list = {
        "ollama": "Install ollama package using 'pip install --upgrade ollama'"
    }
    data_classification_policy = ["public", "proprietary"]

    def prepare_engine(self):
        """No setup is performed; the ``ollama`` package is imported in ``_infer``."""
        pass

    def _infer(self, dataset):
        """Return the chat response content for every instance in ``dataset``."""
        import ollama

        outputs = []
        for instance in dataset:
            response = ollama.chat(
                # Use the configured model instead of a hard-coded one.
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": instance["source"],
                    },
                ],
            )
            outputs.append(response["message"]["content"])
        return outputs
260
+
261
+
262
  class IbmGenAiInferenceEngine(
263
  InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
264
  ):
 
270
  data_classification_policy = ["public", "proprietary"]
271
  parameters: Optional[IbmGenAiInferenceEngineParams] = None
272
 
273
+ def prepare_engine(self):
274
  from genai import Client, Credentials
275
 
276
  api_key_env_var_name = "GENAI_KEY"
277
  api_key = os.environ.get(api_key_env_var_name)
278
+
279
  assert api_key is not None, (
280
  f"Error while trying to run IbmGenAiInferenceEngine."
281
  f" Please set the environment param '{api_key_env_var_name}'."
 
312
  top_p: Optional[float] = None
313
  top_logprobs: Optional[int] = 20
314
  logit_bias: Optional[Dict[str, int]] = None
315
+ logprobs: Optional[bool] = True
316
  n: Optional[int] = None
317
+ parallel_tool_calls: Optional[bool] = None
318
  service_tier: Optional[Literal["auto", "default"]] = None
319
 
320
 
 
329
  top_p: Optional[float] = None
330
  top_logprobs: Optional[int] = 20
331
  logit_bias: Optional[Dict[str, int]] = None
332
+ logprobs: Optional[bool] = True
333
  n: Optional[int] = None
334
+ parallel_tool_calls: Optional[bool] = None
335
  service_tier: Optional[Literal["auto", "default"]] = None
336
 
337
 
 
349
  data_classification_policy = ["public"]
350
  parameters: Optional[OpenAiInferenceEngineParams] = None
351
 
352
+ def prepare_engine(self):
353
  from openai import OpenAI
354
 
355
  api_key_env_var_name = "OPENAI_API_KEY"
 
363
 
364
  self._set_inference_parameters()
365
 
366
def _get_completion_kwargs(self):
    """Return the OpenAI parameter fields of this engine that are not None."""
    all_params = self.to_dict([OpenAiInferenceEngineParamsMixin])
    completion_kwargs = {}
    for name, value in all_params.items():
        # Parameters left as None are not passed to the completion call.
        if value is None:
            continue
        completion_kwargs[name] = value
    return completion_kwargs
372
+
373
  def _infer(self, dataset):
374
  outputs = []
375
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
 
385
  }
386
  ],
387
  model=self.model_name,
388
+ **self._get_completion_kwargs(),
389
  )
390
  output = response.choices[0].message.content
391
 
 
408
  }
409
  ],
410
  model=self.model_name,
411
+ **self._get_completion_kwargs(),
412
  )
413
  top_logprobs_response = response.choices[0].logprobs.content
414
  output = [
 
424
  return outputs
425
 
426
 
427
class TogetherAiInferenceEngineParamsMixin(Artifact):
    """Optional request parameters for the Together AI inference engine.

    Every field defaults to None. TogetherAiInferenceEngine forwards only the
    fields that are not None to its completion calls.
    """

    max_tokens: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    logprobs: Optional[int] = None
    echo: Optional[bool] = None
    n: Optional[int] = None
    min_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
440
+
441
+
442
class TogetherAiInferenceEngine(
    InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin
):
    """Inference engine backed by the ``together`` client.

    During preparation the model's type is looked up in the Together model
    list. Chat models are queried through the chat completions endpoint; the
    other supported types (language, code) through the text completions
    endpoint.

    Attributes:
        label: Identifier of this engine type.
        model_name: Id of the Together model to query.
        parameters: Optional parameters object (see the params mixin).
    """

    label: str = "together"
    model_name: str
    _requirements_list = {
        "together": "Install together package using 'pip install --upgrade together'"
    }
    data_classification_policy = ["public"]
    parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None

    def prepare_engine(self):
        """Create the Together client and determine the type of ``model_name``.

        Raises:
            AssertionError: If the ``TOGETHER_API_KEY`` environment variable is
                not set, the model is not in the Together model list, or its
                type is not chat, language or code.
        """
        from together import Together
        from together.types.models import ModelType

        api_key_env_var_name = "TOGETHER_API_KEY"
        api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run TogetherAiInferenceEngine."
            f" Please set the environment param '{api_key_env_var_name}'."
        )
        self.client = Together(api_key=api_key)
        self._set_inference_parameters()

        # Get model type from Together List Models API
        together_models = self.client.models.list()
        together_model_id_to_type = {
            together_model.id: together_model.type for together_model in together_models
        }
        model_type = together_model_id_to_type.get(self.model_name)
        assert model_type is not None, (
            f"Could not find model {self.model_name} " "in Together AI model list"
        )
        assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], (
            f"Together AI model type {model_type} is not supported; "
            "supported types are 'chat', 'language' and 'code'."
        )
        self.model_type = model_type

    def _get_infer_kwargs(self):
        """Return the Together parameter fields of this engine that are not None."""
        return {
            k: v
            for k, v in self.to_dict([TogetherAiInferenceEngineParamsMixin]).items()
            if v is not None
        }

    def _infer_chat(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply content."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            **self._get_infer_kwargs(),
        )
        return response.choices[0].message.content

    def _infer_text(self, prompt: str) -> str:
        """Send ``prompt`` to the text completions endpoint and return the text."""
        response = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            **self._get_infer_kwargs(),
        )
        return response.choices[0].text

    def _infer(self, dataset):
        """Return one prediction per instance, using the endpoint matching the model type."""
        from together.types.models import ModelType

        # Pick the per-prompt function once instead of duplicating the loop.
        if self.model_type == ModelType.CHAT:
            infer_one = self._infer_chat
            description = "Inferring with Together AI Chat API"
        else:
            infer_one = self._infer_text
            description = "Inferring with Together AI Text API"

        return [
            infer_one(instance["source"]) for instance in tqdm(dataset, desc=description)
        ]
515
+
516
+
517
  class WMLInferenceEngineParamsMixin(Artifact):
518
  decoding_method: Optional[Literal["greedy", "sample"]] = None
519
  length_penalty: Optional[Dict[str, Union[int, float]]] = None
 
567
  parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
568
  which defines inference parameters and their values. Deprecated attribute, please
569
  pass respective parameters directly to the WMLInferenceEngine class instead.
570
+ concurrency_limit (int): number of requests that will be sent in parallel, max is 10.
571
 
572
  Examples:
573
  from .api import load_dataset
 
601
  }
602
  data_classification_policy = ["public", "proprietary"]
603
  parameters: Optional[WMLInferenceEngineParams] = None
604
+ concurrency_limit: int = 10
605
  _client: Any = InternalField(default=None, name="WML client")
606
 
607
  def verify(self):
 
658
  client.set.default_project(self.credentials["project_id"])
659
  return client
660
 
661
+ def prepare_engine(self):
662
  self._client = self._initialize_wml_client()
663
 
664
  self._set_inference_parameters()
 
672
  api_client=self._client,
673
  )
674
 
675
+ # the class was previously used with a dataset that is a single instance
676
+ dataset = dataset if isinstance(dataset, list) else [dataset]
677
+
678
+ result = [
679
+ model.generate_text(
680
+ prompt=instance["source"],
681
+ params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
682
+ )
683
+ for instance in dataset
684
+ ]
685
+
686
+ # the class was previously used with a dataset that is a single instance
687
+ return result[0] if not isinstance(dataset, list) else result
688
 
689
 
690
  class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
 
718
 
719
  self.processor = AutoProcessor.from_pretrained(self.model_name)
720
 
721
+ def prepare_engine(self):
722
  if not self.lazy_load:
723
  self._prepare_engine()
724
 
llm_as_judge.py CHANGED
@@ -144,13 +144,13 @@ class LLMAsJudge(BulkInstanceMetric):
144
  )
145
 
146
  if isinstance(self.inference_model, OpenAiInferenceEngine):
147
- if self.format:
148
  raise ValueError(
149
  "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
150
  "not support formatting. Please remove the format definition from the recipe"
151
  " (OpenAi Chat API take care of the formatting automatically)."
152
  )
153
- if self.system_prompt:
154
  raise ValueError(
155
  "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
156
  "not support system prompt. Please remove the system_prompt definition from the recipe"
@@ -181,9 +181,17 @@ class LLMAsJudge(BulkInstanceMetric):
181
  results = []
182
  for instance in outputs:
183
  if self.task == "pairwise_comparative_rating.single_turn":
184
- is_model_b_the_baseline = (
185
- instance["task_data"]["model_b"] == "baseline_model"
 
 
 
 
 
 
186
  )
 
 
187
  if is_model_b_the_baseline:
188
  model_a_preference_score = instance["prediction"]
189
  else:
 
144
  )
145
 
146
  if isinstance(self.inference_model, OpenAiInferenceEngine):
147
+ if self.format and type(self.format) is not SystemFormat:
148
  raise ValueError(
149
  "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
150
  "not support formatting. Please remove the format definition from the recipe"
151
  " (OpenAi Chat API take care of the formatting automatically)."
152
  )
153
+ if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
154
  raise ValueError(
155
  "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
156
  "not support system prompt. Please remove the system_prompt definition from the recipe"
 
181
  results = []
182
  for instance in outputs:
183
  if self.task == "pairwise_comparative_rating.single_turn":
184
+ import json
185
+
186
+ # seems like the task data sometimes comes as a string, not a dict
187
+ # this fixes it
188
+ task_data = (
189
+ json.loads(instance["task_data"])
190
+ if isinstance(instance["task_data"], str)
191
+ else instance["task_data"]
192
  )
193
+
194
+ is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
195
  if is_model_b_the_baseline:
196
  model_a_preference_score = instance["prediction"]
197
  else:
loaders.py CHANGED
@@ -151,6 +151,7 @@ class LoadHF(Loader):
151
  data_dir: Optional directory to store downloaded data.
152
  split: Optional specification of which split to load.
153
  data_files: Optional specification of particular data files to load.
 
154
  streaming: Bool indicating if streaming should be used.
155
  filtering_lambda: A lambda function for filtering the data after loading.
156
  num_proc: Optional integer to specify the number of processes to use for parallel dataset loading.
@@ -170,6 +171,7 @@ class LoadHF(Loader):
170
  data_files: Optional[
171
  Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
172
  ] = None
 
173
  streaming: bool = True
174
  filtering_lambda: Optional[str] = None
175
  num_proc: Optional[int] = None
@@ -199,6 +201,7 @@ class LoadHF(Loader):
199
  name=self.name,
200
  data_dir=self.data_dir,
201
  data_files=self.data_files,
 
202
  streaming=self.streaming,
203
  cache_dir=None if self.streaming else dir_to_be_deleted,
204
  split=self.split,
@@ -488,6 +491,7 @@ class LoadFromIBMCloud(Loader):
488
  bucket_name: Name of the S3 bucket from which to load data.
489
  data_dir: Optional directory path within the bucket.
490
  data_files: Union type allowing either a list of file names or a mapping of splits to file names.
 
491
  caching: Bool indicating if caching is enabled to avoid re-downloading data.
492
 
493
  Example:
@@ -511,6 +515,7 @@ class LoadFromIBMCloud(Loader):
511
  data_dir: str = None
512
 
513
  data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
 
514
  caching: bool = True
515
  data_classification_policy = ["proprietary"]
516
 
@@ -636,10 +641,13 @@ class LoadFromIBMCloud(Loader):
636
  )
637
 
638
  if isinstance(self.data_files, list):
639
- dataset = hf_load_dataset(local_dir, streaming=False)
640
  else:
641
  dataset = hf_load_dataset(
642
- local_dir, streaming=False, data_files=self.data_files
 
 
 
643
  )
644
 
645
  return MultiStream.from_iterables(dataset)
 
151
  data_dir: Optional directory to store downloaded data.
152
  split: Optional specification of which split to load.
153
  data_files: Optional specification of particular data files to load.
154
+ revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
155
  streaming: Bool indicating if streaming should be used.
156
  filtering_lambda: A lambda function for filtering the data after loading.
157
  num_proc: Optional integer to specify the number of processes to use for parallel dataset loading.
 
171
  data_files: Optional[
172
  Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
173
  ] = None
174
+ revision: Optional[str] = None
175
  streaming: bool = True
176
  filtering_lambda: Optional[str] = None
177
  num_proc: Optional[int] = None
 
201
  name=self.name,
202
  data_dir=self.data_dir,
203
  data_files=self.data_files,
204
+ revision=self.revision,
205
  streaming=self.streaming,
206
  cache_dir=None if self.streaming else dir_to_be_deleted,
207
  split=self.split,
 
491
  bucket_name: Name of the S3 bucket from which to load data.
492
  data_dir: Optional directory path within the bucket.
493
  data_files: Union type allowing either a list of file names or a mapping of splits to file names.
494
+ data_field: The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
495
  caching: Bool indicating if caching is enabled to avoid re-downloading data.
496
 
497
  Example:
 
515
  data_dir: str = None
516
 
517
  data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
518
+ data_field: str = None
519
  caching: bool = True
520
  data_classification_policy = ["proprietary"]
521
 
 
641
  )
642
 
643
  if isinstance(self.data_files, list):
644
+ dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
645
  else:
646
  dataset = hf_load_dataset(
647
+ local_dir,
648
+ streaming=False,
649
+ data_files=self.data_files,
650
+ field=self.data_field,
651
  )
652
 
653
  return MultiStream.from_iterables(dataset)
metric.py CHANGED
@@ -4,6 +4,7 @@ import evaluate
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
 
7
  from .benchmark import __file__ as _
8
  from .blocks import __file__ as _
9
  from .card import __file__ as _
@@ -42,6 +43,7 @@ from .random_utils import __file__ as _
42
  from .recipe import __file__ as _
43
  from .register import __file__ as _
44
  from .schema import __file__ as _
 
45
  from .settings_utils import __file__ as _
46
  from .settings_utils import get_constants
47
  from .span_lableing_operators import __file__ as _
@@ -57,6 +59,7 @@ from .task import __file__ as _
57
  from .templates import __file__ as _
58
  from .text_utils import __file__ as _
59
  from .type_utils import __file__ as _
 
60
  from .utils import __file__ as _
61
  from .utils import is_package_installed
62
  from .validate import __file__ as _
 
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
7
+ from .augmentors import __file__ as _
8
  from .benchmark import __file__ as _
9
  from .blocks import __file__ as _
10
  from .card import __file__ as _
 
43
  from .recipe import __file__ as _
44
  from .register import __file__ as _
45
  from .schema import __file__ as _
46
+ from .serializers import __file__ as _
47
  from .settings_utils import __file__ as _
48
  from .settings_utils import get_constants
49
  from .span_lableing_operators import __file__ as _
 
59
  from .templates import __file__ as _
60
  from .text_utils import __file__ as _
61
  from .type_utils import __file__ as _
62
+ from .types import __file__ as _
63
  from .utils import __file__ as _
64
  from .utils import is_package_installed
65
  from .validate import __file__ as _
metrics.py CHANGED
@@ -421,7 +421,7 @@ class MetricWithConfidenceInterval(Metric):
421
  full_score_name = ci_score_prefix + score_name
422
  result[f"{full_score_name}_ci_low"] = ci.low
423
  result[f"{full_score_name}_ci_high"] = ci.high
424
- if score_name == self.main_score:
425
  result["score_ci_low"] = ci.low
426
  result["score_ci_high"] = ci.high
427
  return result
@@ -1183,7 +1183,11 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1183
  return instances
1184
 
1185
  def get_group_scores(
1186
- self, instances: List[dict], score_names: List[str], group_aggregation_func
 
 
 
 
1187
  ):
1188
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1189
 
@@ -1193,6 +1197,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1193
  group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
1194
  or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
1195
  callable function returns a single score for the group
 
 
1196
 
1197
  Returns:
1198
  List of dicts, each corresponding to a group of instances (defined by 'group_id'),
@@ -1222,7 +1228,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1222
  )
1223
  for score_name in score_names:
1224
  group_to_instance_scores[group_key][score_name][subgroup_type].append(
1225
- instance["score"]["instance"][score_name]
 
 
1226
  )
1227
 
1228
  # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
@@ -1230,7 +1238,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1230
  {
1231
  "score": {
1232
  "instance": {
1233
- score_name: group_aggregation_func(
 
1234
  score_dict
1235
  if uses_subgroups
1236
  else score_dict[default_subgroup_name]
@@ -1268,7 +1277,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1268
  group_aggregation_func=group_aggregation_func,
1269
  ):
1270
  group_scores = self.get_group_scores(
1271
- instances, [field_name], group_aggregation_func
1272
  )
1273
  return nan_mean(
1274
  [group["score"]["instance"][field_name] for group in group_scores]
@@ -4565,8 +4574,7 @@ class NormalizedSacrebleu(HuggingfaceMetric):
4565
  scaled_fields = ["sacrebleu", "precisions"]
4566
  hf_additional_input_fields_pass_one_value = ["tokenize"]
4567
  _requirements_list = {
4568
- "mecab_ko": KO_ERROR_MESSAGE,
4569
- "mecab_ko_dic": KO_ERROR_MESSAGE,
4570
  }
4571
 
4572
 
 
421
  full_score_name = ci_score_prefix + score_name
422
  result[f"{full_score_name}_ci_low"] = ci.low
423
  result[f"{full_score_name}_ci_high"] = ci.high
424
+ if score_name == self.score_prefix + self.main_score:
425
  result["score_ci_low"] = ci.low
426
  result["score_ci_high"] = ci.high
427
  return result
 
1183
  return instances
1184
 
1185
  def get_group_scores(
1186
+ self,
1187
+ instances: List[dict],
1188
+ score_names: List[str],
1189
+ group_aggregation_func,
1190
+ prepend_score_prefix: bool = True,
1191
  ):
1192
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1193
 
 
1197
  group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
1198
  or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
1199
  callable function returns a single score for the group
1200
+ prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
1201
+ if down the stream such a prepending is expected.
1202
 
1203
  Returns:
1204
  List of dicts, each corresponding to a group of instances (defined by 'group_id'),
 
1228
  )
1229
  for score_name in score_names:
1230
  group_to_instance_scores[group_key][score_name][subgroup_type].append(
1231
+ instance["score"]["instance"][
1232
+ (self.score_prefix if prepend_score_prefix else "") + score_name
1233
+ ]
1234
  )
1235
 
1236
  # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
 
1238
  {
1239
  "score": {
1240
  "instance": {
1241
+ (self.score_prefix if prepend_score_prefix else "")
1242
+ + score_name: group_aggregation_func(
1243
  score_dict
1244
  if uses_subgroups
1245
  else score_dict[default_subgroup_name]
 
1277
  group_aggregation_func=group_aggregation_func,
1278
  ):
1279
  group_scores = self.get_group_scores(
1280
+ instances, [field_name], group_aggregation_func, False
1281
  )
1282
  return nan_mean(
1283
  [group["score"]["instance"][field_name] for group in group_scores]
 
4574
  scaled_fields = ["sacrebleu", "precisions"]
4575
  hf_additional_input_fields_pass_one_value = ["tokenize"]
4576
  _requirements_list = {
4577
+ "sacrebleu": "Additional dependencies required. To install them, run: `pip install sacrebleu`."
 
4578
  }
4579
 
4580
 
operators.py CHANGED
@@ -531,230 +531,6 @@ class AddConstant(FieldOperator):
531
  return self.add + value
532
 
533
 
534
- class Augmentor(InstanceOperator):
535
- """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
536
-
537
- Args:
538
- augment_model_input: Whether to augment the input to the model.
539
- augment_task_input: Whether to augment the task input fields. The specific fields are defined in the Task operator.
540
-
541
- """
542
-
543
- augment_task_input: bool = False
544
- augment_model_input: bool = False
545
-
546
- def verify(self):
547
- assert not (
548
- self.augment_task_input and self.augment_model_input
549
- ), "Augmentor must set either 'augment_task_input' and 'augment_model_input' but not both"
550
- assert (
551
- self.augment_task_input or self.augment_model_input
552
- ), "Augmentor must set either 'augment_task_input' or 'augment_model_input'"
553
-
554
- super().verify()
555
-
556
- @abstractmethod
557
- def process_value(self, value: Any) -> Any:
558
- pass
559
-
560
- def prepare(self):
561
- pass
562
-
563
- def set_task_input_fields(self, task_input_fields: List[str]):
564
- self._task_input_fields = [
565
- "input_fields/" + task_input_field for task_input_field in task_input_fields
566
- ]
567
-
568
- def process(
569
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
570
- ) -> Dict[str, Any]:
571
- if self.augment_task_input:
572
- assert (
573
- len(self._task_input_fields) > 0
574
- ), "No augmentable input fields were defined in Task, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the Task."
575
- fields = self._task_input_fields
576
- assert not self.augment_model_input
577
-
578
- if self.augment_model_input:
579
- fields = ["source"]
580
- assert not self.augment_task_input
581
-
582
- for field_name in fields:
583
- try:
584
- old_value = dict_get(
585
- instance,
586
- field_name,
587
- default="",
588
- not_exist_ok=False,
589
- )
590
- except ValueError as e:
591
- raise TypeError(f"Failed to get {field_name} from {instance}") from e
592
-
593
- try:
594
- new_value = self.process_value(old_value)
595
- except Exception as e:
596
- raise RuntimeError(
597
- f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
598
- ) from e
599
- dict_set(instance, field_name, new_value, not_exist_ok=True)
600
- return instance
601
-
602
-
603
- class NullAugmentor(Augmentor):
604
- """Does not change the input string."""
605
-
606
- def verify(self):
607
- pass
608
-
609
- def process_value(self, value: Any) -> Any:
610
- return value
611
-
612
-
613
- class AugmentWhitespace(Augmentor):
614
- """Augments the inputs by replacing existing whitespaces with other whitespaces.
615
-
616
- Currently, each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
617
- """
618
-
619
- def process_value(self, value: Any) -> Any:
620
- import re
621
-
622
- words = re.split(r"(\s+)", value)
623
- new_value = ""
624
-
625
- random_generator = new_random_generator(sub_seed=value)
626
- for word in words:
627
- if word.isspace():
628
- new_value += random_generator.choice(
629
- ["\n", "\t", " "]
630
- ) * random_generator.randint(1, 3)
631
- else:
632
- new_value += word
633
- return new_value
634
-
635
-
636
- class AugmentPrefixSuffix(Augmentor):
637
- r"""Augments the input by prepending and appending to it a randomly selected (typically, whitespace) patterns.
638
-
639
- Args:
640
- prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
641
- The dictionary version allows to specify relative weights of the different patterns.
642
- prefix_len, suffix_len (positive int) : The added prefix or suffix will be of length
643
- prefix_len of suffix_len, respectively, repetitions of the randomly selected patterns.
644
- remove_existing_whitespaces : allows to first clean any existing leading and trailing whitespaces.
645
- The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
646
- trimmed input.
647
- If only one of prefixes/suffixes is needed, set the other to None.
648
-
649
- Examples:
650
- To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
651
- AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
652
- To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
653
- being preferred over triple '\t', at 2:1 ratio, employ
654
- AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
655
- which will append '\n'-s twice as often as '\t'-s.
656
-
657
- """
658
-
659
- prefixes: Optional[Union[List[str], Dict[str, int]]] = {
660
- " ": 20,
661
- "\\t": 10,
662
- "\\n": 40,
663
- "": 30,
664
- }
665
- prefix_len: Optional[int] = 3
666
- suffixes: Optional[Union[List[str], Dict[str, int]]] = {
667
- " ": 20,
668
- "\\t": 10,
669
- "\\n": 40,
670
- "": 30,
671
- }
672
- suffix_len: Optional[int] = 3
673
- remove_existing_whitespaces: Optional[bool] = False
674
-
675
- def verify(self):
676
- assert (
677
- self.prefixes or self.suffixes
678
- ), "At least one of prefixes/suffixes should be not None."
679
- for arg, arg_name in zip(
680
- [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
681
- ):
682
- assert (
683
- arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
684
- ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
685
- assert (
686
- self.prefix_len > 0
687
- ), f"prefix_len must be positive, got {self.prefix_len}"
688
- assert (
689
- self.suffix_len > 0
690
- ), f"suffix_len must be positive, got {self.suffix_len}"
691
- super().verify()
692
-
693
- def _calculate_distributions(self, prefs_or_suffs):
694
- if prefs_or_suffs is None:
695
- return None, None
696
- patterns = (
697
- prefs_or_suffs
698
- if isinstance(prefs_or_suffs, list)
699
- else [k for k, v in prefs_or_suffs.items()]
700
- )
701
- total_weight = (
702
- len(patterns)
703
- if isinstance(prefs_or_suffs, list)
704
- else sum([v for k, v in prefs_or_suffs.items()])
705
- )
706
- weights = (
707
- [1.0 / total_weight] * len(patterns)
708
- if isinstance(prefs_or_suffs, list)
709
- else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
710
- )
711
- return patterns, weights
712
-
713
- def prepare(self):
714
- # Being an artifact, prepare is invoked before verify. Here we need verify before the actions
715
- self.verify()
716
- self._prefix_pattern_distribution = {"length": self.prefix_len}
717
- self._suffix_pattern_distribution = {"length": self.suffix_len}
718
-
719
- (
720
- self._prefix_pattern_distribution["patterns"],
721
- self._prefix_pattern_distribution["weights"],
722
- ) = self._calculate_distributions(self.prefixes)
723
- (
724
- self._suffix_pattern_distribution["patterns"],
725
- self._suffix_pattern_distribution["weights"],
726
- ) = self._calculate_distributions(self.suffixes)
727
- super().prepare()
728
-
729
- def _get_random_pattern(
730
- self, pattern_distribution, random_generator: Random
731
- ) -> str:
732
- string_to_add = ""
733
- if pattern_distribution["patterns"]:
734
- string_to_add = "".join(
735
- random_generator.choices(
736
- pattern_distribution["patterns"],
737
- pattern_distribution["weights"],
738
- k=pattern_distribution["length"],
739
- )
740
- )
741
- return string_to_add
742
-
743
- def process_value(self, value: Any) -> Any:
744
- assert value is not None, "input value should not be None"
745
- new_value = str(value)
746
- if self.remove_existing_whitespaces:
747
- new_value = new_value.strip()
748
- random_generator = new_random_generator(sub_seed=value)
749
- prefix = self._get_random_pattern(
750
- self._prefix_pattern_distribution, random_generator
751
- )
752
- suffix = self._get_random_pattern(
753
- self._suffix_pattern_distribution, random_generator
754
- )
755
- return prefix + new_value + suffix
756
-
757
-
758
  class ShuffleFieldValues(FieldOperator):
759
  """Shuffles a list of values found in a field."""
760
 
@@ -1445,7 +1221,7 @@ class ComputeExpressionMixin(Artifact):
1445
 
1446
  def compute_expression(self, instance: dict) -> Any:
1447
  if settings.allow_unverified_code:
1448
- return eval(self.expression, self.globals, instance)
1449
 
1450
  raise ValueError(
1451
  f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
 
531
  return self.add + value
532
 
533
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  class ShuffleFieldValues(FieldOperator):
535
  """Shuffles a list of values found in a field."""
536
 
 
1221
 
1222
  def compute_expression(self, instance: dict) -> Any:
1223
  if settings.allow_unverified_code:
1224
+ return eval(self.expression, {**self.globals, **instance})
1225
 
1226
  raise ValueError(
1227
  f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
schema.py CHANGED
@@ -69,23 +69,36 @@ class Finalize(InstanceOperatorValidator):
69
 
70
  return instance
71
 
72
- def process(
73
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
74
  ) -> Dict[str, Any]:
75
- metadata = {
76
- "data_classification_policy": instance["data_classification_policy"],
77
- "template": self.artifact_to_jsonable(
78
- instance["recipe_metadata"]["template"]
79
- ),
80
- "num_demos": instance["recipe_metadata"]["num_demos"],
81
- }
82
  task_data = {
83
  **instance["input_fields"],
84
- "metadata": metadata,
 
 
85
  }
86
-
87
- if stream_name != constants.inference_stream:
88
  task_data = {**task_data, **instance["reference_fields"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  instance["task_data"] = json.dumps(task_data)
91
 
@@ -99,7 +112,7 @@ class Finalize(InstanceOperatorValidator):
99
  for key in keys_to_delete:
100
  del instance[key]
101
 
102
- data = {**task_data, **metadata}
103
  groups = []
104
  for group_attributes in self.group_by:
105
  group = {}
 
69
 
70
  return instance
71
 
72
+ def _get_instance_task_data(
73
+ self, instance: Dict[str, Any], use_reference_fields=True
74
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
75
  task_data = {
76
  **instance["input_fields"],
77
+ "metadata": {
78
+ "data_classification_policy": instance["data_classification_policy"],
79
+ },
80
  }
81
+ if use_reference_fields:
 
82
  task_data = {**task_data, **instance["reference_fields"]}
83
+ return task_data
84
+
85
+ def process(
86
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
87
+ ) -> Dict[str, Any]:
88
+ task_data = self._get_instance_task_data(
89
+ instance,
90
+ use_reference_fields=stream_name != constants.inference_stream,
91
+ )
92
+
93
+ task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
94
+ task_data["metadata"]["template"] = self.artifact_to_jsonable(
95
+ instance["recipe_metadata"]["template"]
96
+ )
97
+ if "demos" in instance:
98
+ task_data["demos"] = [
99
+ self._get_instance_task_data(instance)
100
+ for instance in instance.pop("demos")
101
+ ]
102
 
103
  instance["task_data"] = json.dumps(task_data)
104
 
 
112
  for key in keys_to_delete:
113
  del instance[key]
114
 
115
+ data = {**task_data, **task_data["metadata"]}
116
  groups = []
117
  for group_attributes in self.group_by:
118
  group = {}
serializers.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ from abc import abstractmethod
4
+ from typing import Any, Dict, List, Union
5
+
6
+ from .dataclass import AbstractField, Field
7
+ from .operators import InstanceFieldOperator
8
+ from .type_utils import isoftype, to_type_string
9
+ from .types import Dialog, Image, Number, Table
10
+
11
+
12
+ class Serializer(InstanceFieldOperator):
13
+ def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
14
+ return self.serialize(value, instance)
15
+
16
+ @abstractmethod
17
+ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
18
+ pass
19
+
20
+
21
+ class DefaultSerializer(Serializer):
22
+ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
23
+ return str(value)
24
+
25
+
26
+ class SingleTypeSerializer(InstanceFieldOperator):
27
+ serialized_type: object = AbstractField()
28
+
29
+ def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
30
+ if not isoftype(value, self.serialized_type):
31
+ raise ValueError(
32
+ f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {to_type_string(value)}"
33
+ )
34
+ return self.serialize(value, instance)
35
+
36
+
37
+ class DefaultListSerializer(Serializer):
38
+ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
39
+ if isinstance(value, list):
40
+ return ", ".join(str(item) for item in value)
41
+ return str(value)
42
+
43
+
44
+ class ListSerializer(SingleTypeSerializer):
45
+ serialized_type = list
46
+
47
+ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
48
+ return ", ".join(str(item) for item in value)
49
+
50
+
51
+ class DialogSerializer(SingleTypeSerializer):
52
+ serialized_type = Dialog
53
+
54
+ def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
55
+ # Convert the Dialog into a string representation, typically combining roles and content
56
+ return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)
57
+
58
+
59
+ class NumberSerializer(SingleTypeSerializer):
60
+ serialized_type = Number
61
+
62
+ def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
63
+ # Check if the value is an integer or a float
64
+ if isinstance(value, int):
65
+ return str(value)
66
+ # For floats, format to one decimal place
67
+ if isinstance(value, float):
68
+ return f"{value:.1f}"
69
+ raise ValueError("Unsupported type for NumberSerializer")
70
+
71
+
72
+ class NumberQuantizingSerializer(NumberSerializer):
73
+ serialized_type = Number
74
+ quantum: Union[float, int] = 0.1
75
+
76
+ def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
77
+ if isoftype(value, Number):
78
+ quantized_value = round(value / self.quantum) / (1 / self.quantum)
79
+ if isinstance(self.quantum, int):
80
+ quantized_value = int(quantized_value)
81
+ return str(quantized_value)
82
+ raise ValueError("Unsupported type for NumberSerializer")
83
+
84
+
85
+ class TableSerializer(SingleTypeSerializer):
86
+ serialized_type = Table
87
+
88
+ def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
89
+ output = io.StringIO()
90
+ writer = csv.writer(output, lineterminator="\n")
91
+
92
+ # Write the header and rows to the CSV writer
93
+ writer.writerow(value["header"])
94
+ writer.writerows(value["rows"])
95
+
96
+ # Retrieve the CSV string
97
+ return output.getvalue().strip()
98
+
99
+
100
+ class ImageSerializer(SingleTypeSerializer):
101
+ serialized_type = Image
102
+
103
+ def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
104
+ if "media" not in instance:
105
+ instance["media"] = {}
106
+ if "images" not in instance["media"]:
107
+ instance["media"]["images"] = []
108
+ idx = len(instance["media"]["images"])
109
+ instance["media"]["images"].append(value["image"])
110
+ value["image"] = f'<img src="media/images/{idx}">'
111
+ return value["image"]
112
+
113
+
114
+ class MultiTypeSerializer(Serializer):
115
+ serializers: List[SingleTypeSerializer] = Field(
116
+ default_factory=lambda: [
117
+ ImageSerializer(),
118
+ TableSerializer(),
119
+ DialogSerializer(),
120
+ ]
121
+ )
122
+
123
+ def verify(self):
124
+ super().verify()
125
+ self._verify_serializers(self.serializers)
126
+
127
+ def _verify_serializers(self, serializers):
128
+ if not isoftype(serializers, List[SingleTypeSerializer]):
129
+ raise ValueError(
130
+ "MultiTypeSerializer requires the list of serializers to be List[SingleTypeSerializer]."
131
+ )
132
+
133
+ def add_serializers(self, serializers: List[SingleTypeSerializer]):
134
+ self._verify_serializers(serializers)
135
+ self.serializers = serializers + self.serializers
136
+
137
+ def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
138
+ for serializer in self.serializers:
139
+ if isoftype(value, serializer.serialized_type):
140
+ return serializer.serialize(value, instance)
141
+
142
+ return str(value)
settings_utils.py CHANGED
@@ -146,6 +146,7 @@ if Settings.is_uninitilized():
146
  settings.seed = (int, 42)
147
  settings.skip_artifacts_prepare_and_verify = (bool, False)
148
  settings.data_classification_policy = None
 
149
 
150
  if Constants.is_uninitilized():
151
  constants = Constants()
 
146
  settings.seed = (int, 42)
147
  settings.skip_artifacts_prepare_and_verify = (bool, False)
148
  settings.data_classification_policy = None
149
+ settings.mock_inference_mode = (bool, False)
150
 
151
  if Constants.is_uninitilized():
152
  constants = Constants()
standard.py CHANGED
@@ -1,20 +1,27 @@
1
  from typing import List, Optional, Union
2
 
 
 
 
 
 
 
3
  from .card import TaskCard
4
  from .collections_operators import GetLength
5
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
6
  from .formats import Format, SystemFormat
7
  from .logging_utils import get_logger
8
  from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
9
- from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
10
  from .recipe import Recipe
11
  from .schema import Finalize
 
12
  from .settings_utils import get_constants
13
  from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
14
  from .stream import MultiStream
15
  from .system_prompts import EmptySystemPrompt, SystemPrompt
16
  from .task import Task
17
- from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
18
 
19
  constants = get_constants()
20
  logger = get_logger()
@@ -29,9 +36,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
29
  # Base parameters
30
  card: TaskCard = None
31
  task: Task = None
32
- template: Union[Template, List[Template]] = None
33
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
34
  format: Format = Field(default_factory=SystemFormat)
 
35
 
36
  # Additional parameters
37
  template_card_index: int = NonPositionalField(default=None)
@@ -140,6 +148,11 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
140
  else:
141
  self.verify_template(self.template)
142
 
 
 
 
 
 
143
  def prepare_refiners(self):
144
  self.train_refiner.max_instances = self.max_train_instances
145
  self.train_refiner.apply_to_streams = ["train"]
@@ -281,8 +294,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
281
 
282
  self.processing.steps.append(self.task)
283
 
284
- if self.augmentor.augment_task_input:
285
- self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
286
  self.processing.steps.append(self.augmentor)
287
 
288
  if self.has_custom_demos_pool:
@@ -362,7 +375,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
362
 
363
  self.verbalization.steps.append(self.system_prompt)
364
  self.verbalization.steps.append(self.format)
365
- if self.augmentor.augment_model_input:
366
  self.verbalization.steps.append(self.augmentor)
367
 
368
  if self.postprocessors is not None:
@@ -376,6 +389,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
376
  self.finalize.steps.append(Finalize(group_by=self.group_by))
377
 
378
  def prepare(self):
 
 
379
  self.reset_pipeline()
380
 
381
 
 
1
  from typing import List, Optional, Union
2
 
3
+ from .augmentors import (
4
+ Augmentor,
5
+ FinalStateInputsAugmentor,
6
+ NullAugmentor,
7
+ TaskInputsAugmentor,
8
+ )
9
  from .card import TaskCard
10
  from .collections_operators import GetLength
11
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
12
  from .formats import Format, SystemFormat
13
  from .logging_utils import get_logger
14
  from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
15
+ from .operators import Set, StreamRefiner
16
  from .recipe import Recipe
17
  from .schema import Finalize
18
+ from .serializers import SingleTypeSerializer
19
  from .settings_utils import get_constants
20
  from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
21
  from .stream import MultiStream
22
  from .system_prompts import EmptySystemPrompt, SystemPrompt
23
  from .task import Task
24
+ from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
25
 
26
  constants = get_constants()
27
  logger = get_logger()
 
36
  # Base parameters
37
  card: TaskCard = None
38
  task: Task = None
39
+ template: Union[Template, List[Template], TemplatesList] = None
40
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
41
  format: Format = Field(default_factory=SystemFormat)
42
+ serializer: Union[SingleTypeSerializer, List[SingleTypeSerializer]] = None
43
 
44
  # Additional parameters
45
  template_card_index: int = NonPositionalField(default=None)
 
148
  else:
149
  self.verify_template(self.template)
150
 
151
+ if self.serializer is not None:
152
+ if not isinstance(self.serializer, list):
153
+ self.serializer = [self.serializer]
154
+ self.template.serializer.add_serializers(self.serializer)
155
+
156
  def prepare_refiners(self):
157
  self.train_refiner.max_instances = self.max_train_instances
158
  self.train_refiner.apply_to_streams = ["train"]
 
294
 
295
  self.processing.steps.append(self.task)
296
 
297
+ if isinstance(self.augmentor, TaskInputsAugmentor):
298
+ self.augmentor.set_fields(self.card.task.augmentable_inputs)
299
  self.processing.steps.append(self.augmentor)
300
 
301
  if self.has_custom_demos_pool:
 
375
 
376
  self.verbalization.steps.append(self.system_prompt)
377
  self.verbalization.steps.append(self.format)
378
+ if isinstance(self.augmentor, FinalStateInputsAugmentor):
379
  self.verbalization.steps.append(self.augmentor)
380
 
381
  if self.postprocessors is not None:
 
389
  self.finalize.steps.append(Finalize(group_by=self.group_by))
390
 
391
  def prepare(self):
392
+ if isinstance(self.template, TemplatesList):
393
+ self.template = self.template.items
394
  self.reset_pipeline()
395
 
396
 
struct_data_operators.py CHANGED
@@ -29,15 +29,62 @@ import pandas as pd
29
 
30
  from .dict_utils import dict_get
31
  from .operators import FieldOperator, InstanceOperator
 
 
 
32
  from .utils import deepcopy
33
 
34
 
35
- class SerializeTable(ABC, FieldOperator):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  """TableSerializer converts a given table into a flat sequence with special symbols.
37
 
38
  Output format varies depending on the chosen serializer. This abstract class defines structure of a typical table serializer that any concrete implementation should follow.
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # main method to serialize a table
42
  @abstractmethod
43
  def serialize_table(self, table_content: Dict) -> str:
@@ -60,10 +107,6 @@ class SerializeTableAsIndexedRowMajor(SerializeTable):
60
  Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
61
  """
62
 
63
- def process_value(self, table: Any) -> Any:
64
- table_input = deepcopy(table)
65
- return self.serialize_table(table_content=table_input)
66
-
67
  # main method that processes a table
68
  # table_content must be in the presribed input format
69
  def serialize_table(self, table_content: Dict) -> str:
@@ -111,10 +154,6 @@ class SerializeTableAsMarkdown(SerializeTable):
111
  ...
112
  """
113
 
114
- def process_value(self, table: Any) -> Any:
115
- table_input = deepcopy(table)
116
- return self.serialize_table(table_content=table_input)
117
-
118
  # main method that serializes a table.
119
  # table_content must be in the presribed input format.
120
  def serialize_table(self, table_content: Dict) -> str:
@@ -159,10 +198,6 @@ class SerializeTableAsDFLoader(SerializeTable):
159
  index=[0,1,2])
160
  """
161
 
162
- def process_value(self, table: Any) -> Any:
163
- table_input = deepcopy(table)
164
- return self.serialize_table(table_content=table_input)
165
-
166
  # main method that serializes a table.
167
  # table_content must be in the presribed input format.
168
  def serialize_table(self, table_content: Dict) -> str:
@@ -199,10 +234,6 @@ class SerializeTableAsJson(SerializeTable):
199
  }
200
  """
201
 
202
- def process_value(self, table: Any) -> Any:
203
- table_input = deepcopy(table)
204
- return self.serialize_table(table_content=table_input)
205
-
206
  # main method that serializes a table.
207
  # table_content must be in the presribed input format.
208
  def serialize_table(self, table_content: Dict) -> str:
@@ -493,20 +524,7 @@ class ShuffleTableRows(FieldOperator):
493
 
494
  def process_value(self, table: Any) -> Any:
495
  table_input = deepcopy(table)
496
- return self.shuffle_rows(table_content=table_input)
497
-
498
- # shuffles table rows randomly
499
- def shuffle_rows(self, table_content: Dict) -> str:
500
- # extract header & rows from the dictionary
501
- header = table_content.get("header", [])
502
- rows = table_content.get("rows", [])
503
- assert header and rows, "Incorrect input table format"
504
-
505
- # shuffle rows
506
- random.shuffle(rows)
507
- table_content["rows"] = rows
508
-
509
- return table_content
510
 
511
 
512
  class ShuffleTableColumns(FieldOperator):
@@ -527,27 +545,7 @@ class ShuffleTableColumns(FieldOperator):
527
 
528
  def process_value(self, table: Any) -> Any:
529
  table_input = deepcopy(table)
530
- return self.shuffle_columns(table_content=table_input)
531
-
532
- # shuffles table columns randomly
533
- def shuffle_columns(self, table_content: Dict) -> str:
534
- # extract header & rows from the dictionary
535
- header = table_content.get("header", [])
536
- rows = table_content.get("rows", [])
537
- assert header and rows, "Incorrect input table format"
538
-
539
- # shuffle the indices first
540
- indices = list(range(len(header)))
541
- random.shuffle(indices) #
542
-
543
- # shuffle the header & rows based on that indices
544
- shuffled_header = [header[i] for i in indices]
545
- shuffled_rows = [[row[i] for i in indices] for row in rows]
546
-
547
- table_content["header"] = shuffled_header
548
- table_content["rows"] = shuffled_rows
549
-
550
- return table_content
551
 
552
 
553
  class LoadJson(FieldOperator):
 
29
 
30
  from .dict_utils import dict_get
31
  from .operators import FieldOperator, InstanceOperator
32
+ from .random_utils import new_random_generator
33
+ from .serializers import TableSerializer
34
+ from .types import Table
35
  from .utils import deepcopy
36
 
37
 
38
+ def shuffle_columns(table: Table, seed=0) -> Table:
39
+ # extract header & rows from the dictionary
40
+ header = table.get("header", [])
41
+ rows = table.get("rows", [])
42
+ # shuffle the indices first
43
+ indices = list(range(len(header)))
44
+ random_generator = new_random_generator({"table": table, "seed": seed})
45
+ random_generator.shuffle(indices)
46
+
47
+ # shuffle the header & rows based on that indices
48
+ shuffled_header = [header[i] for i in indices]
49
+ shuffled_rows = [[row[i] for i in indices] for row in rows]
50
+
51
+ table["header"] = shuffled_header
52
+ table["rows"] = shuffled_rows
53
+
54
+ return table
55
+
56
+
57
+ def shuffle_rows(table: Table, seed=0) -> Table:
58
+ # extract header & rows from the dictionary
59
+ rows = table.get("rows", [])
60
+ # shuffle rows
61
+ random_generator = new_random_generator({"table": table, "seed": seed})
62
+ random_generator.shuffle(rows)
63
+ table["rows"] = rows
64
+
65
+ return table
66
+
67
+
68
+ class SerializeTable(ABC, TableSerializer):
69
  """TableSerializer converts a given table into a flat sequence with special symbols.
70
 
71
  Output format varies depending on the chosen serializer. This abstract class defines structure of a typical table serializer that any concrete implementation should follow.
72
  """
73
 
74
+ seed: int = 0
75
+ shuffle_rows: bool = False
76
+ shuffle_columns: bool = False
77
+
78
+ def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
79
+ value = deepcopy(value)
80
+ if self.shuffle_columns:
81
+ value = shuffle_columns(table=value, seed=self.seed)
82
+
83
+ if self.shuffle_rows:
84
+ value = shuffle_rows(table=value, seed=self.seed)
85
+
86
+ return self.serialize_table(value)
87
+
88
  # main method to serialize a table
89
  @abstractmethod
90
  def serialize_table(self, table_content: Dict) -> str:
 
107
  Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
108
  """
109
 
 
 
 
 
110
  # main method that processes a table
111
  # table_content must be in the presribed input format
112
  def serialize_table(self, table_content: Dict) -> str:
 
154
  ...
155
  """
156
 
 
 
 
 
157
  # main method that serializes a table.
158
  # table_content must be in the presribed input format.
159
  def serialize_table(self, table_content: Dict) -> str:
 
198
  index=[0,1,2])
199
  """
200
 
 
 
 
 
201
  # main method that serializes a table.
202
  # table_content must be in the presribed input format.
203
  def serialize_table(self, table_content: Dict) -> str:
 
234
  }
235
  """
236
 
 
 
 
 
237
  # main method that serializes a table.
238
  # table_content must be in the presribed input format.
239
  def serialize_table(self, table_content: Dict) -> str:
 
524
 
525
  def process_value(self, table: Any) -> Any:
526
  table_input = deepcopy(table)
527
+ return shuffle_rows(table_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
 
530
  class ShuffleTableColumns(FieldOperator):
 
545
 
546
  def process_value(self, table: Any) -> Any:
547
  table_input = deepcopy(table)
548
+ return shuffle_columns(table_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
 
551
  class LoadJson(FieldOperator):
templates.py CHANGED
@@ -10,6 +10,15 @@ from .dict_utils import dict_set
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator
12
  from .random_utils import new_random_generator
 
 
 
 
 
 
 
 
 
13
  from .settings_utils import get_constants
14
  from .type_utils import isoftype
15
 
@@ -46,17 +55,26 @@ class Template(InstanceOperator):
46
  instruction: str = NonPositionalField(default="")
47
  target_prefix: str = NonPositionalField(default="")
48
  title_fields: List[str] = NonPositionalField(default_factory=list)
 
 
 
 
 
 
 
 
 
 
49
 
50
  def input_fields_to_instruction_and_target_prefix(self, input_fields):
51
  instruction = self.apply_formatting(
52
- input_fields, "input field", self.instruction, "instruction", serialize=True
53
  )
54
  target_prefix = self.apply_formatting(
55
  input_fields,
56
  "input field",
57
  self.target_prefix,
58
  "target_prefix",
59
- serialize=True,
60
  )
61
  return instruction, target_prefix
62
 
@@ -65,6 +83,12 @@ class Template(InstanceOperator):
65
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
66
  return input_fields, reference_fields
67
 
 
 
 
 
 
 
68
  def process(
69
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
70
  ) -> Dict[str, Any]:
@@ -78,14 +102,21 @@ class Template(InstanceOperator):
78
 
79
  input_fields = instance.get("input_fields")
80
  reference_fields = instance.get("reference_fields")
81
- input_fields, reference_fields = self.preprocess_input_and_reference_fields(
82
- input_fields, reference_fields
83
- )
 
 
 
 
84
 
85
  self.set_titles(input_fields)
86
- source = self.input_fields_to_source(input_fields)
 
 
 
87
  instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
88
- input_fields
89
  )
90
 
91
  result = {
@@ -97,19 +128,33 @@ class Template(InstanceOperator):
97
  }
98
 
99
  if stream_name == constants.inference_stream:
100
- return result
101
 
102
  if reference_fields is None:
103
  raise ValueError("Should have reference_fields")
104
 
 
 
 
 
 
 
105
  target, references = self.reference_fields_to_target_and_references(
106
- reference_fields
107
  )
108
 
109
  result["target"] = target
110
  result["references"] = references
111
 
112
- return result
 
 
 
 
 
 
 
 
113
 
114
  @abstractmethod
115
  def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
@@ -125,21 +170,13 @@ class Template(InstanceOperator):
125
  ) -> Tuple[str, List[str]]:
126
  pass
127
 
128
- def serialize_data(self, data):
129
- return {
130
- k: ", ".join(str(t) for t in v) if isinstance(v, list) else v
131
- for k, v in data.items()
132
- }
133
-
134
  def apply_formatting(
135
- self, data, data_type, format_str, format_name, serialize=False
136
  ) -> str:
137
- if serialize:
138
- data = self.serialize_data(data)
139
  try:
140
  if format_str is None:
141
  raise UnitxtError(
142
- f"Required field 'output_format' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
143
  Documentation.ADDING_TEMPLATE,
144
  )
145
  return format_str.format(**data)
@@ -197,26 +234,21 @@ class ApplyRandomTemplate(ApplyTemplate):
197
  return random_generator.choice(self.templates)
198
 
199
 
200
- class InputOutputTemplate(Template):
201
- """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
202
-
203
- Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references').
204
- """
205
-
206
  input_format: str
207
- output_format: str = None
208
 
209
- def input_fields_to_source(
210
- self, input_fields: Dict[str, object]
211
- ) -> Tuple[str, str]:
212
  return self.apply_formatting(
213
  input_fields,
214
  "input field",
215
  self.input_format,
216
  "input_format",
217
- serialize=True,
218
  )
219
 
 
 
 
 
220
  def reference_fields_to_target_and_references(
221
  self, reference_fields: Dict[str, object]
222
  ) -> str:
@@ -225,12 +257,20 @@ class InputOutputTemplate(Template):
225
  "reference field",
226
  self.output_format,
227
  "output_format",
228
- serialize=True,
229
  )
230
  references = [target]
231
  return target, references
232
 
233
 
 
 
 
 
 
 
 
 
 
234
  class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
235
  reference: str
236
 
@@ -242,14 +282,12 @@ class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
242
  "reference field",
243
  self.output_format,
244
  "output_format",
245
- serialize=True,
246
  )
247
  reference = self.apply_formatting(
248
  reference_fields,
249
  "reference field",
250
  self.reference,
251
  "reference",
252
- serialize=True,
253
  )
254
  return target, [reference]
255
 
@@ -374,22 +412,12 @@ class DialogTemplate(InputOutputTemplate):
374
  input_fields[dialog_fields.dialog_field] = dialog_str
375
  return input_fields
376
 
377
- def preprocess_input_and_reference_fields(
378
- self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
379
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
380
- return self.process_dialog(input_fields), reference_fields
381
 
382
 
383
  class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
384
- def preprocess_input_and_reference_fields(
385
- self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
386
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
387
- inputs, reference_fields = DialogTemplate.preprocess_input_and_reference_fields(
388
- self, input_fields, reference_fields
389
- )
390
- return PairwiseChoiceTemplate.preprocess_input_and_reference_fields(
391
- self, input_fields, reference_fields
392
- )
393
 
394
 
395
  class PairwiseComparativeRatingTemplate(InputOutputTemplate):
@@ -448,10 +476,9 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
448
  return input_fields, reference_fields
449
 
450
 
451
- class MultipleChoiceTemplate(Template):
452
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
453
 
454
- input_format: str
455
  target_prefix: str = ""
456
  choices_field: str = "choices"
457
  target_field: str = "label"
@@ -493,7 +520,7 @@ class MultipleChoiceTemplate(Template):
493
  "XX",
494
  ]
495
 
496
- def inputs_to_choices(self, data: Dict[str, object], choice_format: str) -> str:
497
  choices = data[self.choices_field]
498
  enumrated_choices = []
499
  for i, choice in enumerate(choices):
@@ -505,12 +532,12 @@ class MultipleChoiceTemplate(Template):
505
  )
506
  return enumrated_choices
507
 
508
- def inputs_to_numerals(self, input_fields: Dict[str, object]) -> Tuple[str, str]:
509
  return self.inputs_to_choices(input_fields, "{choice_numeral}")
510
 
511
  def prepare_multiple_choice_inputs(
512
- self, input_fields: Dict[str, object]
513
- ) -> Dict[str, object]:
514
  choices = self.inputs_to_choices(input_fields, self.source_choice_format)
515
  return {
516
  "numerals": self.inputs_to_numerals(input_fields),
@@ -518,23 +545,10 @@ class MultipleChoiceTemplate(Template):
518
  self.choices_field: self.choices_separator.join(choices),
519
  }
520
 
521
- def input_fields_to_source(
522
- self, input_fields: Dict[str, object]
523
- ) -> Tuple[str, str]:
524
- input_fields = self.prepare_multiple_choice_inputs(input_fields)
525
- return self.apply_formatting(
526
- input_fields,
527
- "input field",
528
- self.input_format,
529
- "input_format",
530
- serialize=True,
531
- )
532
-
533
- def input_fields_to_instruction_and_target_prefix(self, input_fields):
534
- input_fields = self.prepare_multiple_choice_inputs(input_fields)
535
- return super().input_fields_to_instruction_and_target_prefix(input_fields)
536
 
537
- def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> str:
538
  target = reference_fields[self.target_field]
539
 
540
  if not isinstance(target, int):
@@ -547,9 +561,7 @@ class MultipleChoiceTemplate(Template):
547
  ) from e
548
  return target
549
 
550
- def reference_fields_to_target_and_references(
551
- self, reference_fields: Dict[str, object]
552
- ) -> str:
553
  target = reference_fields[self.target_field]
554
 
555
  if not isinstance(target, int):
@@ -571,51 +583,40 @@ class MultipleChoiceTemplate(Template):
571
  Documentation.ADDING_TEMPLATE,
572
  ) from e
573
 
 
 
 
 
 
 
574
  return target, [target]
575
 
576
- def _shuffle_choices(self, instance, stream_name):
577
- if stream_name != constants.inference_stream:
578
- target_index = self.outputs_to_target_index(instance["reference_fields"])
579
- original_label_choice = instance["reference_fields"][self.choices_field][
580
- target_index
581
- ]
582
- choices = instance["input_fields"][self.choices_field]
 
583
 
584
- random_seed = {**instance["input_fields"]}
 
 
585
 
586
- random_generator = new_random_generator(random_seed)
587
- random_generator.shuffle(choices)
588
- instance["input_fields"][self.choices_field] = choices
589
 
590
- if stream_name == constants.inference_stream:
591
- return instance
592
 
593
- instance["reference_fields"][self.choices_field] = choices
594
- instance["reference_fields"][self.target_field] = choices.index(
595
- original_label_choice
596
  )
597
-
598
  return instance
599
 
600
- def process(
601
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
602
- ) -> Dict[str, Any]:
603
- if self.shuffle_choices:
604
- instance = self._shuffle_choices(instance, stream_name)
605
- result = super().process(instance, stream_name)
606
- if stream_name == constants.inference_stream:
607
- result["input_fields"]["options"] = self.inputs_to_choices(
608
- instance["input_fields"], self.target_choice_format
609
- )
610
- else:
611
- if "options" not in result["reference_fields"]:
612
- result["reference_fields"]["options"] = self.inputs_to_choices(
613
- instance["reference_fields"], self.target_choice_format
614
- )
615
- return result
616
-
617
 
618
- class YesNoTemplate(Template):
619
  """A template for generating binary Yes/No questions asking whether an input text is of a specific class.
620
 
621
  input_format:
@@ -641,17 +642,6 @@ class YesNoTemplate(Template):
641
  yes_answer: str = "Yes"
642
  no_answer: str = "No"
643
 
644
- def input_fields_to_source(
645
- self, input_fields: Dict[str, object]
646
- ) -> Tuple[str, str]:
647
- return self.apply_formatting(
648
- input_fields,
649
- "input field",
650
- self.input_format,
651
- "input_format",
652
- serialize=True,
653
- )
654
-
655
  def reference_fields_to_target_and_references(
656
  self, reference_fields: Dict[str, object]
657
  ) -> str:
@@ -695,16 +685,13 @@ class KeyValTemplate(Template):
695
  def process_dict(
696
  self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
697
  ) -> str:
698
- data = self.serialize_data(data)
699
  pairs = []
700
  for key, val in data.items():
701
  key_val = [key, str(val)] if use_keys else [str(val)]
702
  pairs.append(key_val_sep.join(key_val))
703
  return pairs_sep.join(pairs)
704
 
705
- def input_fields_to_source(
706
- self, input_fields: Dict[str, object]
707
- ) -> Tuple[str, str]:
708
  return self.process_dict(
709
  input_fields,
710
  key_val_sep=self.key_val_separator,
@@ -725,25 +712,16 @@ class KeyValTemplate(Template):
725
 
726
 
727
  class OutputQuantizingTemplate(InputOutputTemplate):
728
- quantum: Union[float, int] = 0.1 # Now supports both int and float
 
 
 
729
 
730
- def reference_fields_to_target_and_references(
731
- self, reference_fields: Dict[str, object]
732
- ) -> str:
733
- if isinstance(self.quantum, int):
734
- # When quantum is an int, format quantized values as ints
735
- quantized_outputs = {
736
- key: f"{int(round(value / self.quantum) * self.quantum)}"
737
- for key, value in reference_fields.items()
738
- }
739
- else:
740
- # When quantum is a float, format quantized values with precision based on quantum
741
- quantum_str = f"{self.quantum:.10f}".rstrip("0").rstrip(".")
742
- quantized_outputs = {
743
- key: f"{round(value / self.quantum) * self.quantum:{quantum_str}}"
744
- for key, value in reference_fields.items()
745
- }
746
- return super().reference_fields_to_target_and_references(quantized_outputs)
747
 
748
 
749
  class MultiLabelTemplate(InputOutputTemplate):
@@ -753,9 +731,9 @@ class MultiLabelTemplate(InputOutputTemplate):
753
  output_format: str = "{labels}"
754
  empty_label: str = "None"
755
 
756
- def reference_fields_to_target_and_references(
757
- self, reference_fields: Dict[str, object]
758
- ) -> str:
759
  labels = reference_fields[self.labels_field]
760
  if not isinstance(labels, list):
761
  raise UnitxtError(
@@ -765,18 +743,29 @@ class MultiLabelTemplate(InputOutputTemplate):
765
  if len(labels) == 0:
766
  labels = [self.empty_label]
767
  labels_str = self.labels_separator.join(labels)
768
- return super().reference_fields_to_target_and_references(
769
- {self.labels_field: labels_str}
770
- )
771
 
772
 
773
  class MultiReferenceTemplate(InputOutputTemplate):
774
  references_field: str = "references"
775
  random_reference: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
  def reference_fields_to_target_and_references(
778
  self, reference_fields: Dict[str, object]
779
- ) -> List[str]:
780
  references = reference_fields[self.references_field]
781
  if not isoftype(references, List[str]):
782
  raise UnitxtError(
@@ -825,12 +814,12 @@ class SpanLabelingBaseTemplate(MultiLabelTemplate):
825
  if self.labels_support is None or span[3] in self.labels_support:
826
  yield span[2], span[3]
827
 
828
- def reference_fields_to_target_and_references(
829
- self, reference_fields: Dict[str, object]
830
- ) -> Dict[str, object]:
831
  span_labels_pairs = self.extract_span_label_pairs(reference_fields)
832
  targets = self.span_label_pairs_to_targets(span_labels_pairs)
833
- return super().reference_fields_to_target_and_references({"labels": targets})
834
 
835
  @abstractmethod
836
  def span_label_pairs_to_targets(self, pairs):
 
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator
12
  from .random_utils import new_random_generator
13
+ from .serializers import (
14
+ DialogSerializer,
15
+ ImageSerializer,
16
+ ListSerializer,
17
+ MultiTypeSerializer,
18
+ NumberQuantizingSerializer,
19
+ Serializer,
20
+ TableSerializer,
21
+ )
22
  from .settings_utils import get_constants
23
  from .type_utils import isoftype
24
 
 
55
  instruction: str = NonPositionalField(default="")
56
  target_prefix: str = NonPositionalField(default="")
57
  title_fields: List[str] = NonPositionalField(default_factory=list)
58
+ serializer: Serializer = NonPositionalField(
59
+ default_factory=lambda: MultiTypeSerializer(
60
+ serializers=[
61
+ ImageSerializer(),
62
+ TableSerializer(),
63
+ DialogSerializer(),
64
+ ListSerializer(),
65
+ ]
66
+ )
67
+ )
68
 
69
  def input_fields_to_instruction_and_target_prefix(self, input_fields):
70
  instruction = self.apply_formatting(
71
+ input_fields, "input field", self.instruction, "instruction"
72
  )
73
  target_prefix = self.apply_formatting(
74
  input_fields,
75
  "input field",
76
  self.target_prefix,
77
  "target_prefix",
 
78
  )
79
  return instruction, target_prefix
80
 
 
83
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
84
  return input_fields, reference_fields
85
 
86
    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        """Hook applied to the input fields before they are serialized.

        The base implementation returns ``input_fields`` unchanged.
        """
        return input_fields
88
+
89
    def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
        """Hook applied to the reference fields before they are serialized.

        The base implementation returns ``reference_fields`` unchanged.
        """
        return reference_fields
91
+
92
  def process(
93
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
94
  ) -> Dict[str, Any]:
 
102
 
103
  input_fields = instance.get("input_fields")
104
  reference_fields = instance.get("reference_fields")
105
+
106
+ if stream_name != constants.inference_stream:
107
+ input_fields, reference_fields = self.preprocess_input_and_reference_fields(
108
+ input_fields, reference_fields
109
+ )
110
+
111
+ input_fields = self.preprocess_input_fields(input_fields)
112
 
113
  self.set_titles(input_fields)
114
+
115
+ serialized_inputs = self.serialize(input_fields, instance)
116
+
117
+ source = self.input_fields_to_source(serialized_inputs)
118
  instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
119
+ serialized_inputs
120
  )
121
 
122
  result = {
 
128
  }
129
 
130
  if stream_name == constants.inference_stream:
131
+ return self.post_process_instance(result)
132
 
133
  if reference_fields is None:
134
  raise ValueError("Should have reference_fields")
135
 
136
+ reference_fields = self.preprocess_reference_fields(reference_fields)
137
+
138
+ serialized_references = self.serialize(
139
+ reference_fields, instance
140
+ ) # Dict[str, str]
141
+
142
  target, references = self.reference_fields_to_target_and_references(
143
+ serialized_references
144
  )
145
 
146
  result["target"] = target
147
  result["references"] = references
148
 
149
+ return self.post_process_instance(result)
150
+
151
    def post_process_instance(self, instance):
        """Hook applied to the result dict just before ``process`` returns it.

        The base implementation returns ``instance`` unchanged.
        """
        return instance
153
+
154
+ def serialize(
155
+ self, data: Dict[str, Any], instance: Dict[str, Any]
156
+ ) -> Dict[str, str]:
157
+ return {k: self.serializer.serialize(v, instance) for k, v in data.items()}
158
 
159
  @abstractmethod
160
  def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
 
170
  ) -> Tuple[str, List[str]]:
171
  pass
172
 
 
 
 
 
 
 
173
  def apply_formatting(
174
+ self, data: Dict[str, Any], data_type: str, format_str: str, format_name: str
175
  ) -> str:
 
 
176
  try:
177
  if format_str is None:
178
  raise UnitxtError(
179
+ f"Required field '{format_name}' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
180
  Documentation.ADDING_TEMPLATE,
181
  )
182
  return format_str.format(**data)
 
234
  return random_generator.choice(self.templates)
235
 
236
 
237
class InputFormatTemplate(Template):
    """Template whose 'source' is produced by filling ``input_format`` with the input fields."""

    input_format: str

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        """Render ``input_format`` with the given input fields."""
        source = self.apply_formatting(
            input_fields, "input field", self.input_format, "input_format"
        )
        return source
247
 
248
+
249
+ class OutputFormatTemplate(Template):
250
+ output_format: str = None
251
+
252
  def reference_fields_to_target_and_references(
253
  self, reference_fields: Dict[str, object]
254
  ) -> str:
 
257
  "reference field",
258
  self.output_format,
259
  "output_format",
 
260
  )
261
  references = [target]
262
  return target, references
263
 
264
 
265
class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate):
    """Build 'source' from the input fields, and 'target' and 'references' from the reference fields.

    The formatting strings given as arguments are used to glue the input fields into one
    string ('source'), and the reference fields into one string ('target') and a list of
    strings ('references'). All behavior is inherited from the two parent templates.
    """
272
+
273
+
274
  class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
275
  reference: str
276
 
 
282
  "reference field",
283
  self.output_format,
284
  "output_format",
 
285
  )
286
  reference = self.apply_formatting(
287
  reference_fields,
288
  "reference field",
289
  self.reference,
290
  "reference",
 
291
  )
292
  return target, [reference]
293
 
 
412
  input_fields[dialog_fields.dialog_field] = dialog_str
413
  return input_fields
414
 
415
    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        """Pre-serialization hook: return the input fields as transformed by ``process_dialog``."""
        return self.process_dialog(input_fields)
 
 
417
 
418
 
419
class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
    # Pure combination of the two parent templates; adds no behavior of its own.
    pass
 
 
 
 
 
 
 
 
421
 
422
 
423
  class PairwiseComparativeRatingTemplate(InputOutputTemplate):
 
476
  return input_fields, reference_fields
477
 
478
 
479
+ class MultipleChoiceTemplate(InputFormatTemplate):
480
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
481
 
 
482
  target_prefix: str = ""
483
  choices_field: str = "choices"
484
  target_field: str = "label"
 
520
  "XX",
521
  ]
522
 
523
+ def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str:
524
  choices = data[self.choices_field]
525
  enumrated_choices = []
526
  for i, choice in enumerate(choices):
 
532
  )
533
  return enumrated_choices
534
 
535
    def inputs_to_numerals(self, input_fields: Dict[str, Any]) -> Tuple[str, str]:
        """Return the choices rendered with the ``"{choice_numeral}"`` format only.

        This is whatever ``inputs_to_choices`` returns for that format.
        """
        # NOTE(review): inputs_to_choices builds and returns a list, so the
        # Tuple[str, str] return annotation looks stale - confirm and update.
        return self.inputs_to_choices(input_fields, "{choice_numeral}")
537
 
538
  def prepare_multiple_choice_inputs(
539
+ self, input_fields: Dict[str, Any]
540
+ ) -> Dict[str, Any]:
541
  choices = self.inputs_to_choices(input_fields, self.source_choice_format)
542
  return {
543
  "numerals": self.inputs_to_numerals(input_fields),
 
545
  self.choices_field: self.choices_separator.join(choices),
546
  }
547
 
548
    def preprocess_input_fields(self, input_fields: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-serialization hook: return the result of ``prepare_multiple_choice_inputs``."""
        return self.prepare_multiple_choice_inputs(input_fields)
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
+ def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> int:
552
  target = reference_fields[self.target_field]
553
 
554
  if not isinstance(target, int):
 
561
  ) from e
562
  return target
563
 
564
+ def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
 
 
565
  target = reference_fields[self.target_field]
566
 
567
  if not isinstance(target, int):
 
583
  Documentation.ADDING_TEMPLATE,
584
  ) from e
585
 
586
+ return {self.target_field: target}
587
+
588
+ def reference_fields_to_target_and_references(
589
+ self, reference_fields: Dict[str, object]
590
+ ) -> str:
591
+ target = reference_fields[self.target_field]
592
  return target, [target]
593
 
594
+ def preprocess_input_and_reference_fields(
595
+ self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
596
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
597
+ if self.shuffle_choices:
598
+ target_index = self.outputs_to_target_index(reference_fields)
599
+ original_label_choice = reference_fields[self.choices_field][target_index]
600
+ choices = input_fields[self.choices_field]
601
+ random_seed = {**input_fields}
602
 
603
+ random_generator = new_random_generator(random_seed)
604
+ random_generator.shuffle(choices)
605
+ input_fields[self.choices_field] = choices
606
 
607
+ reference_fields[self.choices_field] = choices
608
+ reference_fields[self.target_field] = choices.index(original_label_choice)
 
609
 
610
+ return input_fields, reference_fields
 
611
 
612
+ def post_process_instance(self, instance):
613
+ instance["input_fields"]["options"] = self.inputs_to_choices(
614
+ instance["input_fields"], self.target_choice_format
615
  )
 
616
  return instance
617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
619
+ class YesNoTemplate(InputFormatTemplate):
620
  """A template for generating binary Yes/No questions asking whether an input text is of a specific class.
621
 
622
  input_format:
 
642
  yes_answer: str = "Yes"
643
  no_answer: str = "No"
644
 
 
 
 
 
 
 
 
 
 
 
 
645
  def reference_fields_to_target_and_references(
646
  self, reference_fields: Dict[str, object]
647
  ) -> str:
 
685
  def process_dict(
686
  self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
687
  ) -> str:
 
688
  pairs = []
689
  for key, val in data.items():
690
  key_val = [key, str(val)] if use_keys else [str(val)]
691
  pairs.append(key_val_sep.join(key_val))
692
  return pairs_sep.join(pairs)
693
 
694
+ def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
 
 
695
  return self.process_dict(
696
  input_fields,
697
  key_val_sep=self.key_val_separator,
 
712
 
713
 
714
class OutputQuantizingTemplate(InputOutputTemplate):
    """Input/output template that adds a ``NumberQuantizingSerializer`` built from ``quantum``."""

    serializer: MultiTypeSerializer = NonPositionalField(
        default_factory=MultiTypeSerializer
    )
    quantum: Union[float, int] = 0.1

    def prepare(self):
        """Run the parent preparation, then add the quantizing serializer to ``serializer``."""
        super().prepare()
        quantizer = NumberQuantizingSerializer(quantum=self.quantum)
        self.serializer.add_serializers([quantizer])
 
 
 
 
 
 
 
 
 
 
 
 
725
 
726
 
727
  class MultiLabelTemplate(InputOutputTemplate):
 
731
  output_format: str = "{labels}"
732
  empty_label: str = "None"
733
 
734
+ def preprocess_reference_fields(
735
+ self, reference_fields: Dict[str, Any]
736
+ ) -> Dict[str, Any]:
737
  labels = reference_fields[self.labels_field]
738
  if not isinstance(labels, list):
739
  raise UnitxtError(
 
743
  if len(labels) == 0:
744
  labels = [self.empty_label]
745
  labels_str = self.labels_separator.join(labels)
746
+ return {self.labels_field: labels_str}
 
 
747
 
748
 
749
  class MultiReferenceTemplate(InputOutputTemplate):
750
  references_field: str = "references"
751
  random_reference: bool = False
752
+ serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)
753
+
754
+ def serialize(
755
+ self, data: Dict[str, Any], instance: Dict[str, Any]
756
+ ) -> Dict[str, str]:
757
+ result = {}
758
+ for k, v in data.items():
759
+ if k == self.references_field:
760
+ v = [self.serializer.serialize(item, instance) for item in v]
761
+ else:
762
+ v = self.serializer.serialize(v, instance)
763
+ result[k] = v
764
+ return result
765
 
766
  def reference_fields_to_target_and_references(
767
  self, reference_fields: Dict[str, object]
768
+ ) -> Tuple[str, List[str]]:
769
  references = reference_fields[self.references_field]
770
  if not isoftype(references, List[str]):
771
  raise UnitxtError(
 
814
  if self.labels_support is None or span[3] in self.labels_support:
815
  yield span[2], span[3]
816
 
817
+ def preprocess_reference_fields(
818
+ self, reference_fields: Dict[str, Any]
819
+ ) -> Dict[str, Any]:
820
  span_labels_pairs = self.extract_span_label_pairs(reference_fields)
821
  targets = self.span_label_pairs_to_targets(span_labels_pairs)
822
+ return super().preprocess_reference_fields({"labels": targets})
823
 
824
  @abstractmethod
825
  def span_label_pairs_to_targets(self, pairs):
type_utils.py CHANGED
@@ -4,48 +4,75 @@ import io
4
  import itertools
5
  import re
6
  import typing
 
7
 
8
  from .utils import safe_eval
9
 
10
- _supported_types_strings = [
11
- "Any",
12
- "List[...]",
13
- "Dict[...]",
14
- "Tuple[...]",
15
- "Union[...]",
16
- "Optional[...]",
17
- "int",
18
- "float",
19
- "dict",
20
- "double",
21
- "str",
22
- ]
 
 
 
 
 
 
 
 
23
 
24
  Type = typing.Any
25
 
26
 
27
  class UnsupportedTypeError(ValueError):
28
  def __init__(self, type_object):
29
- supported_types = ", ".join(_supported_types_strings)
30
  super().__init__(
31
  f"Type: '{type_object!s}' is not supported type. Use one of {supported_types}"
32
  )
33
 
34
 
 
 
 
 
35
  _generics = [
36
- typing.List[typing.Any],
37
- typing.Dict[typing.Any, typing.Any],
38
- typing.Tuple[typing.Any],
39
- typing.Union[typing.Any, typing.Any],
40
- typing.Optional[typing.Any],
41
- typing.Any,
 
42
  ]
43
 
44
  _generics_types = [type(t) for t in _generics]
45
 
46
 
 
 
 
 
 
 
 
 
47
  def is_type(object):
48
- return isinstance(object, (type, *_generics_types))
 
 
 
 
 
49
 
50
 
51
  def is_type_dict(object):
@@ -215,34 +242,31 @@ def parse_type_string(type_string: str) -> typing.Any:
215
  and basic Python data types. It also defines a list of safe tokens that are allowed
216
  in the type string.
217
  """
218
- safe_context = {
219
- "Any": typing.Any,
220
- "List": typing.List,
221
- "Dict": typing.Dict,
222
- "Tuple": typing.Tuple,
223
- "Union": typing.Union,
224
- "int": int,
225
- "str": str,
226
- "float": float,
227
- "bool": bool,
228
- "Optional": typing.Optional,
229
- }
230
-
231
  type_string = format_type_string(type_string)
232
 
233
- safe_tokens = ["[", "]", ",", " "]
234
- return safe_eval(type_string, safe_context, safe_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
 
237
  def to_type_string(typing_type):
238
- if not is_type(typing_type):
239
- raise UnsupportedTypeError(typing_type)
240
- type_string = (
241
- str(typing_type)
242
- .replace("typing.", "")
243
- .replace("<class '", "")
244
- .replace("'>", "")
245
- )
246
  assert parse_type_string(type_string), "Is not parsed well"
247
  return type_string
248
 
@@ -447,9 +471,9 @@ def infer_type_string(obj: typing.Any) -> str:
447
  def isoftype(object, typing_type):
448
  """Checks if an object is of a certain typing type, including nested types.
449
 
450
- This function supports simple types (like `int`, `str`), typing types
451
- (like `List[int]`, `Tuple[str, int]`, `Dict[str, int]`), and nested typing
452
- types (like `List[List[int]]`, `Tuple[List[str], int]`, `Dict[str, List[int]]`).
453
 
454
  Args:
455
  object: The object to check.
@@ -457,19 +481,21 @@ def isoftype(object, typing_type):
457
 
458
  Returns:
459
  bool: True if the object is of the specified type, False otherwise.
460
-
461
- Examples:
462
- .. highlight:: python
463
- .. code-block:: python
464
-
465
- isoftype(1, int) # True
466
- isoftype([1, 2, 3], typing.List[int]) # True
467
- isoftype([1, 2, 3], typing.List[str]) # False
468
- isoftype([[1, 2], [3, 4]], typing.List[typing.List[int]]) # True
469
  """
470
  if not is_type(typing_type):
471
  raise UnsupportedTypeError(typing_type)
472
 
 
 
 
 
 
 
 
 
 
 
 
473
  if typing_type == typing.Any:
474
  return True
475
 
@@ -477,15 +503,16 @@ def isoftype(object, typing_type):
477
  origin = typing_type.__origin__
478
  type_args = typing.get_args(typing_type)
479
 
 
 
 
480
  if origin is typing.Union:
481
  return any(isoftype(object, sub_type) for sub_type in type_args)
482
 
483
  if not isinstance(object, origin):
484
  return False
485
-
486
  if origin is list or origin is set:
487
  return all(isoftype(element, type_args[0]) for element in object)
488
-
489
  if origin is dict:
490
  return all(
491
  isoftype(key, type_args[0]) and isoftype(value, type_args[1])
@@ -496,11 +523,77 @@ def isoftype(object, typing_type):
496
  isoftype(element, type_arg)
497
  for element, type_arg in zip(object, type_args)
498
  )
499
- return None
500
 
501
  return isinstance(object, typing_type)
502
 
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  # copied from: https://github.com/bojiang/typing_utils/blob/main/typing_utils/__init__.py
505
  # liscened under Apache License 2.0
506
 
 
4
  import itertools
5
  import re
6
  import typing
7
+ from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
8
 
9
  from .utils import safe_eval
10
 
11
# Names allowed in type strings, mapped to the type objects they stand for.
# Used as the evaluation context of parse_type_string, extended by register_type,
# and listed in the message of UnsupportedTypeError.
_registered_types = {
    "Any": typing.Any,
    "List": typing.List,
    "Dict": typing.Dict,
    "Tuple": typing.Tuple,
    "Union": typing.Union,
    "Optional": typing.Optional,
    "Literal": typing.Literal,
    "int": int,
    "str": str,
    "float": float,
    "bool": bool,
}
24
+
25
+
26
def register_type(new_type):
    """Add a ``typing.NewType`` or ``typing.TypedDict`` to the registry under its ``__name__``.

    Any other kind of object fails the assertion.
    """
    is_registrable = is_new_type(new_type) or is_typed_dict(new_type)
    assert is_registrable, "Can register only typing.NewType or typing.TypedDict"
    _registered_types[new_type.__name__] = new_type
31
+
32
 
33
  Type = typing.Any
34
 
35
 
36
class UnsupportedTypeError(ValueError):
    """Raised when an object is not one of the supported types."""

    def __init__(self, type_object):
        # List the currently registered type names so the message shows what is accepted.
        supported_types = ", ".join(_registered_types.keys())
        message = f"Type: '{type_object!s}' is not supported type. Use one of {supported_types}"
        super().__init__(message)
42
 
43
 
44
# Empty TypedDict; is_typed_dict uses type(GenericTypedDict) to recognize TypedDict classes.
class GenericTypedDict(TypedDict):
    pass
46
+
47
+
48
# One sample of each supported typing construct. Only the runtime classes of these
# samples are used (collected in _generics_types below and checked by is_type).
_generics = [
    List[Any],
    Dict[Any, Any],
    Tuple[Any],
    Union[Any, Any],
    Optional[Any],
    Any,
    Literal,
]

# Runtime classes of the samples above.
_generics_types = [type(t) for t in _generics]
59
 
60
 
61
def is_new_type(object):
    """Return True if ``object`` is callable and has a ``__supertype__`` attribute (as ``typing.NewType`` results do)."""
    if not callable(object):
        return False
    return hasattr(object, "__supertype__")
63
+
64
+
65
def is_typed_dict(object):
    """Return True if ``object`` is a TypedDict class, i.e. shares the class of ``GenericTypedDict``."""
    typed_dict_class = type(GenericTypedDict)
    return isinstance(object, typed_dict_class)
67
+
68
+
69
def is_type(object):
    """Check whether ``object`` is a type: a class, a supported generic, Literal, TypedDict or NewType."""
    if isinstance(object, (type, *_generics_types)):
        return True
    if is_new_type(object):
        return True
    return is_typed_dict(object)
76
 
77
 
78
  def is_type_dict(object):
 
242
  and basic Python data types. It also defines a list of safe tokens that are allowed
243
  in the type string.
244
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  type_string = format_type_string(type_string)
246
 
247
+ return safe_eval(
248
+ type_string, context=_registered_types, allowed_tokens=["[", "]", ",", " "]
249
+ )
250
+
251
+
252
def replace_class_names(full_string: str) -> str:
    """Reduce every dotted (fully qualified) name in ``full_string`` to its last component.

    Names defined inside a function (``outer.<locals>.Name``) are reduced to ``Name`` too.
    """
    # First alternative: a name nested under "<locals>."; second: any dotted name.
    # In both cases the captured group is the trailing identifier.
    qualified_name = r"(?:\w+\.)*<locals>\.(\w+)|(?:\w+\.)*(\w+)"
    return re.sub(
        qualified_name,
        lambda match: match.group(1) if match.group(1) else match.group(2),
        full_string,
    )
266
 
267
 
268
def to_type_string(typing_type):
    """Return the string form of ``typing_type`` after asserting that it parses back.

    The string comes from ``strtype``; ``parse_type_string`` must return a truthy
    result for it, otherwise the assertion fails.
    """
    rendered = strtype(typing_type)
    assert parse_type_string(rendered), "Is not parsed well"
    return rendered
272
 
 
471
  def isoftype(object, typing_type):
472
  """Checks if an object is of a certain typing type, including nested types.
473
 
474
+ This function supports simple types, typing types (List[int], Tuple[str, int]),
475
+ nested typing types (List[List[int]], Tuple[List[str], int]), Literal, TypedDict,
476
+ and NewType.
477
 
478
  Args:
479
  object: The object to check.
 
481
 
482
  Returns:
483
  bool: True if the object is of the specified type, False otherwise.
 
 
 
 
 
 
 
 
 
484
  """
485
  if not is_type(typing_type):
486
  raise UnsupportedTypeError(typing_type)
487
 
488
+ if is_new_type(typing_type):
489
+ typing_type = typing_type.__supertype__
490
+
491
+ if is_typed_dict(typing_type):
492
+ if not isinstance(object, dict):
493
+ return False
494
+ for key, expected_type in typing_type.__annotations__.items():
495
+ if key not in object or not isoftype(object[key], expected_type):
496
+ return False
497
+ return True
498
+
499
  if typing_type == typing.Any:
500
  return True
501
 
 
503
  origin = typing_type.__origin__
504
  type_args = typing.get_args(typing_type)
505
 
506
+ if origin is Literal:
507
+ return object in type_args
508
+
509
  if origin is typing.Union:
510
  return any(isoftype(object, sub_type) for sub_type in type_args)
511
 
512
  if not isinstance(object, origin):
513
  return False
 
514
  if origin is list or origin is set:
515
  return all(isoftype(element, type_args[0]) for element in object)
 
516
  if origin is dict:
517
  return all(
518
  isoftype(key, type_args[0]) and isoftype(value, type_args[1])
 
523
  isoftype(element, type_arg)
524
  for element, type_arg in zip(object, type_args)
525
  )
 
526
 
527
  return isinstance(object, typing_type)
528
 
529
 
530
def strtype(typing_type) -> str:
    """Convert a typing type to its string representation.

    Args:
        typing_type (Any): The typing type to be converted. This can include standard types,
            custom types, or types from the `typing` module, such as `Literal`, `Union`,
            `List`, `Dict`, `Tuple`, `TypedDict`, and `NewType`.

    Returns:
        str: The string representation of the provided typing type.

    Raises:
        UnsupportedTypeError: If the provided `typing_type` is not a recognized type.

    Notes:
        - `NewType` and `TypedDict` types are rendered as their name.
        - `Any` is rendered as `"Any"`; `Literal[...]` as its `str()` without the
          `typing.` prefix.
        - A `Union` that contains `None` (at any position) is rendered as
          `Optional[X]`, or `Optional[Union[X, Y, ...]]` when several members remain.
          `None` inside other constructs (e.g. `Tuple[int, None]`) is not treated as
          Optional.
        - `List`, `Dict` and `Tuple` are rendered recursively from their type arguments.
          Set types are rendered as `List[...]`, as before.
    """
    if not is_type(typing_type):
        raise UnsupportedTypeError(typing_type)

    if is_new_type(typing_type) or is_typed_dict(typing_type):
        return typing_type.__name__

    if typing_type == typing.Any:
        return "Any"

    if hasattr(typing_type, "__origin__"):
        origin = typing_type.__origin__
        type_args = typing.get_args(typing_type)

        if origin is Literal:
            return str(typing_type).replace("typing.", "")

        if origin is typing.Union:
            none_type = type(None)
            members = [strtype(arg) for arg in type_args if arg is not none_type]
            if len(members) == len(type_args):
                return "Union[" + ", ".join(members) + "]"
            # Optional accepts exactly one argument, so wrap several members in a Union.
            inner = (
                members[0]
                if len(members) == 1
                else "Union[" + ", ".join(members) + "]"
            )
            return "Optional[" + inner + "]"

        # Only index into the arguments when the generic is actually parameterized.
        if type_args:
            # NOTE(review): sets are rendered as List[...]; the separate "Set[...]"
            # branch was unreachable and has been dropped. "Set" is also absent from
            # the registered names - confirm whether sets should get their own form.
            if origin is list or origin is set:
                return "List[" + strtype(type_args[0]) + "]"
            if origin is dict:
                return (
                    "Dict["
                    + strtype(type_args[0])
                    + ", "
                    + strtype(type_args[1])
                    + "]"
                )
        if origin is tuple:
            return (
                "Tuple["
                + ", ".join([strtype(sub_type) for sub_type in type_args])
                + "]"
            )

    return typing_type.__name__
595
+
596
+
597
  # copied from: https://github.com/bojiang/typing_utils/blob/main/typing_utils/__init__.py
598
  # liscened under Apache License 2.0
599
 
types.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Literal, NewType, TypedDict, Union
2
+
3
+ from .type_utils import register_type
4
+
5
# Named aliases (NewType) over str and over Union[float, int].
Text = NewType("Text", str)
Number = NewType("Number", Union[float, int])


# One dialog turn: the speaker's role and the content of the turn.
class Turn(TypedDict):
    role: Literal["system", "user", "agent"]
    content: Text


# A dialog is a list of turns.
Dialog = NewType("Dialog", List[Turn])


# Dict with a single "image" key; the value type is left open.
class Image(TypedDict):
    image: Any


# Dict with a single "audio" key; the value type is left open.
class Audio(TypedDict):
    audio: Any


# Table given as a header (list of strings) and rows (list of lists).
class Table(TypedDict):
    header: List[str]
    rows: List[List[Any]]


# Register each type defined above with type_utils.
register_type(Text)
register_type(Number)
register_type(Turn)
register_type(Dialog)
register_type(Table)
register_type(Audio)
register_type(Image)
utils.py CHANGED
@@ -2,6 +2,7 @@ import copy
2
  import importlib.util
3
  import json
4
  import os
 
5
  from functools import lru_cache
6
  from typing import Any, Dict
7
 
@@ -87,6 +88,23 @@ def is_module_available(module_name):
87
  return False
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
91
  """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.
92
 
@@ -109,7 +127,9 @@ def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
109
  by restricting the available tokens and not exposing built-in functions.
110
  """
111
  allowed_sub_strings = list(context.keys()) + allowed_tokens
112
- if is_made_of_sub_strings(expression, allowed_sub_strings):
 
 
113
  return eval(expression, {"__builtins__": {}}, context)
114
  raise ValueError(
115
  f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
 
2
  import importlib.util
3
  import json
4
  import os
5
+ import re
6
  from functools import lru_cache
7
  from typing import Any, Dict
8
 
 
88
  return False
89
 
90
 
91
def remove_numerics_and_quoted_texts(input_str):
    """Return ``input_str`` with numeric literals and quoted texts removed.

    Floats and integers are removed first, then quoted texts. Quoted texts are matched
    in one left-to-right pass: triple-quoted texts (``\"\"\"`` or ``'''``, possibly
    spanning several lines) take precedence over single-line double- or single-quoted
    texts. A quote character of the other kind inside a quoted text (``"it's"``)
    therefore does not break the pairing.

    Args:
        input_str: The expression to clean.

    Returns:
        The expression without numerics and quoted texts.
    """
    # Remove floats first to avoid leaving stray periods
    input_str = re.sub(r"\d+\.\d+", "", input_str)

    # Remove integers
    input_str = re.sub(r"\d+", "", input_str)

    # Triple quotes must be tried before single/double quotes: otherwise their
    # delimiters are consumed as empty strings and multi-line content is left behind.
    quoted_text = (
        r'"""[\s\S]*?"""'
        r"|'''[\s\S]*?'''"
        r'|"[^"\n]*"'
        r"|'[^'\n]*'"
    )
    return re.sub(quoted_text, "", input_str)
106
+
107
+
108
  def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
109
  """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.
110
 
 
127
  by restricting the available tokens and not exposing built-in functions.
128
  """
129
  allowed_sub_strings = list(context.keys()) + allowed_tokens
130
+ if is_made_of_sub_strings(
131
+ remove_numerics_and_quoted_texts(expression), allowed_sub_strings
132
+ ):
133
  return eval(expression, {"__builtins__": {}}, context)
134
  raise ValueError(
135
  f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.12.4"
 
1
+ version = "1.13.0"