File size: 9,468 Bytes
970dac4 357b16c 970dac4 24df49f 970dac4 7cdc7d0 970dac4 4aee30b 7cdc7d0 970dac4 4aee30b 970dac4 4aee30b 970dac4 7cdc7d0 970dac4 7cdc7d0 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 7cdc7d0 24df49f 7cdc7d0 cc5f321 7cdc7d0 cc5f321 7cdc7d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
"""Dialog Serializers.

Dialog serializers take structured dialog data and turn it into
text that can be fed to the model.

The format of the dialog is:

.. code-block:: text

    dialog = [
        {"user": "hello", "system": "hi"},
        {"user": "kkk", "system": ""},
        {"user": "kkk", "system": ""},
    ]
"""
from typing import Any, Dict, List, Optional
from .formats import SystemFormat
from .operators import InstanceFieldOperator
class SerializeDialog(InstanceFieldOperator):
    """Serializes dialog data for feeding into a model.

    This class takes structured dialog data (a list of turns, each a dict with
    ``"user"`` and ``"system"`` keys) and converts it into text. Each turn is
    rendered with the ``demo_format`` of a ``SystemFormat``, in which
    ``{source}`` is mapped to ``{user}``, ``{target}`` to ``{system}``, and
    ``{target_prefix}`` is removed. The rendered turns are concatenated.
    Optionally, the last system response is left out of the text and stored in
    a separate field.

    Args:
        field (str):
            The field in the input data that contains the dialog.
        to_field (Optional[str]):
            The field in the output data where the serialized dialog will be stored.
        format (Optional[SystemFormat]):
            Format whose ``demo_format`` renders each turn. When ``None``, the
            format is read from ``instance["recipe_metadata"]["format"]``.
        last_response_to_field (Optional[str]):
            When set, the system response of the last turn is omitted from the
            serialized dialog and written to this field of the instance instead.
        context_field (Optional[str]):
            Field that contains additional context to be prepended to the dialog.
        context_separator (str):
            String placed between the context and the dialog. Defaults to ``" "``.
        slice_first_and_last_turns_format (bool):
            When True, the text before ``{user}`` is dropped from the first
            turn's format and the text after ``{system}`` is dropped from the
            last turn's format. Defaults to True.
    """

    format: SystemFormat = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_separator: str = " "
    slice_first_and_last_turns_format: bool = True

    def standardize_format(self, demo_format):
        """Map a demo format's placeholders onto dialog-turn placeholders.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}``, and
        ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format):
        """Drop everything before ``{user}`` (raises ValueError if it is absent)."""
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format):
        """Drop everything after ``{system}`` (raises ValueError if it is absent)."""
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_response(self, turn_format):
        """Drop everything after ``{user}``, including the system response placeholder."""
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format, step, length):
        """Return the format for turn ``step`` of a dialog with ``length`` turns.

        The first turn and the last turn may be trimmed according to
        ``slice_first_and_last_turns_format``. When ``last_response_to_field``
        is set, the last turn is also cut right after ``{user}`` so that its
        system response is not rendered.
        """
        if step == 0 and self.slice_first_and_last_turns_format:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            if self.slice_first_and_last_turns_format:
                turn_format = self.slice_last_turn(turn_format)
            if self.last_response_to_field is not None:
                turn_format = self.slice_last_response(turn_format)
        return turn_format

    def get_general_turn_format(self, instance):
        """Return the standardized per-turn format for ``instance``.

        Uses ``self.format`` if set; otherwise the format stored in
        ``instance["recipe_metadata"]["format"]``.
        """
        general_format = (
            instance["recipe_metadata"]["format"]
            if self.format is None
            else self.format
        )
        return self.standardize_format(general_format.demo_format)

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Render ``structured_dialog`` into a single string.

        Args:
            structured_dialog: The turns to serialize. Each turn is a dict whose
                keys fill the turn format's placeholders (at least ``"user"``
                and ``"system"``).
            instance: The full instance. It is read for ``context_field`` and
                ``recipe_metadata``. It is written to when
                ``last_response_to_field`` is set.

        Returns:
            str: The optional context and separator, followed by the rendered turns.

        Raises:
            ValueError: If ``last_response_to_field`` is set but the dialog is
                empty, because there is no last response to store.
        """
        prefix = (
            ""
            if self.context_field is None
            else instance[self.context_field] + self.context_separator
        )
        general_turn_format = self.get_general_turn_format(instance)
        num_turns = len(structured_dialog)
        rendered_turns = [
            self.get_turn_format(general_turn_format, i, num_turns).format(**turn)
            for i, turn in enumerate(structured_dialog)
        ]
        if self.last_response_to_field is not None:
            if not structured_dialog:
                raise ValueError(
                    f"Cannot store the last response in '{self.last_response_to_field}': the dialog is empty."
                )
            # The last turn's response was sliced out of its format above, so
            # it is kept here instead.
            instance[self.last_response_to_field] = structured_dialog[-1]["system"]
        return prefix + "".join(rendered_turns)
class SerializeOpenAiFormatDialog(SerializeDialog):
    """Serializes dialog data given in the OpenAI chat format.

    The input is a list of ``{"role": ..., "content": ...}`` messages whose
    roles are ``user`` or ``assistant`` (case-insensitive). The messages are
    validated, and consecutive messages with the same role are merged. The
    result is converted into ``{"user": ..., "system": ...}`` turns, which are
    then serialized by ``SerializeDialog``.

    Args:
        field (str):
            The field in the input data that contains the dialog.
        to_field (Optional[str]):
            The field in the output data where the serialized dialog will be stored.
        last_response_to_field (Optional[str]):
            When set, the last system response is omitted from the serialized
            dialog and stored in this field instead.
        context_field (Optional[str]):
            Field that contains additional context to be prepended to the dialog.
        is_last_turn_user_only (bool):
            Not read by this class or by ``SerializeDialog``. Kept so that
            existing configurations that pass it keep working.
    """

    # NOTE(review): not read by any method here; retained for backward compatibility.
    is_last_turn_user_only: bool = True

    @staticmethod
    def validate_openai_dialog_format(dialog: List[Dict[str, str]]) -> None:
        """Validates that the given dialog follows the correct OpenAI format.

        The function checks that:
        1. The dialog is a non-empty list of dictionaries.
        2. Each dictionary contains the keys 'role' and 'content'.
        3. Both 'role' and 'content' values are strings.
        4. The 'role' value is either 'user' or 'assistant' (case-insensitive).
        5. The first 'role' is 'user'.

        Args:
            dialog (List[Dict[str, str]]): The dialog to validate.

        Raises:
            ValueError: If the dialog does not meet the format requirements.
        """
        if not isinstance(dialog, list):
            raise ValueError("Dialog must be a list of dictionaries.")
        # Checked up front so the first-entry check below cannot hit an IndexError.
        if len(dialog) == 0:
            raise ValueError("Dialog must contain at least one entry.")
        for i, entry in enumerate(dialog):
            if not isinstance(entry, dict):
                raise ValueError(
                    f"Entry {i} is not a dictionary: {entry}. Each entry in the dialog must be a dictionary."
                )
            if "role" not in entry:
                raise ValueError(
                    f"Entry {i} is missing the 'role' key: {entry}. Each dictionary must have a 'role' key."
                )
            if "content" not in entry:
                raise ValueError(
                    f"Entry {i} is missing the 'content' key: {entry}. Each dictionary must have a 'content' key."
                )
            if not isinstance(entry["role"], str):
                raise ValueError(
                    f"Entry {i} has a non-string 'role': {entry['role']}. The 'role' value must be a string."
                )
            if not isinstance(entry["content"], str):
                raise ValueError(
                    f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
                )
            if entry["role"].lower() not in {"user", "assistant"}:
                raise ValueError(
                    f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
                )
        first_entry = dialog[0]
        if first_entry["role"].lower() != "user":
            raise ValueError(
                f"First entry role is expected to be 'user' It is {first_entry['role']}."
            )

    @staticmethod
    def merge_dialog_entries(dialog: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Merges consecutive dialog entries with the same role.

        Roles are compared case-insensitively, the same way
        ``validate_openai_dialog_format`` accepts them. Merged contents are
        joined with a single space. The input entries are never modified:
        each merged entry is a shallow copy of the first entry in its run.

        Args:
            dialog (List[Dict[str, str]]): The input dialog list where each
                dictionary has a string 'role' and a string 'content'.

        Returns:
            List[Dict[str, str]]: A new list where consecutive entries with the
            same role are merged.
        """
        merged_dialog: List[Dict[str, str]] = []
        for entry in dialog:
            if (
                merged_dialog
                and entry["role"].lower() == merged_dialog[-1]["role"].lower()
            ):
                merged_dialog[-1]["content"] += " " + entry["content"]
            else:
                # Copy so that the later "+=" never mutates the caller's dicts.
                merged_dialog.append(dict(entry))
        return merged_dialog

    def transform_dialog_to_standard_format(
        self, dialog: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Transforms a dialog from OpenAI format to a simplified format.

        Each dictionary in the result contains 'user' and 'system' keys with
        their respective contents. Consecutive entries with the same role are
        merged. A trailing user message without a reply becomes a turn whose
        'system' value is the empty string.

        Args:
            dialog (List[Dict[str, str]]): The input dialog in OpenAI format.

        Returns:
            List[Dict[str, str]]: The transformed dialog.

        Raises:
            ValueError: If the dialog fails ``validate_openai_dialog_format``.
        """
        self.validate_openai_dialog_format(dialog)
        merged_dialog = self.merge_dialog_entries(dialog)
        # After validation and case-insensitive merging, roles strictly alternate
        # starting with "user": even positions are user messages and odd
        # positions are assistant messages.
        result = [
            {"user": user_entry["content"], "system": system_entry["content"]}
            for user_entry, system_entry in zip(
                merged_dialog[0::2], merged_dialog[1::2]
            )
        ]
        if len(merged_dialog) % 2 != 0:
            result.append({"user": merged_dialog[-1]["content"], "system": ""})
        return result

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Convert the OpenAI-format dialog to user/system turns and serialize it."""
        standard_format_dialog = self.transform_dialog_to_standard_format(
            structured_dialog
        )
        return super().process_instance_value(standard_format_dialog, instance)
|