Elron commited on
Commit
970dac4
·
verified ·
1 Parent(s): 39d46c6

Upload dialog_operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dialog_operators.py +88 -0
dialog_operators.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dialog Serializers.
2
+
3
+ Dialog serializers are the way to take dialog data and turn it into
4
+ text that can be fed to the model.
5
+
6
+ The format of the dialog is:
7
+
8
+ dialog = [
9
+ {"user": "hello", "system": "hi"},
10
+ {"user": "kkk", "system": ""},
11
+ {"user": "kkk", "system": ""},
12
+ ]
13
+ """
14
+ from typing import Any, Dict, List, Optional
15
+
16
+ from .formats import SystemFormat
17
+ from .operators import InstanceFieldOperator
18
+
19
+
20
+ class SerializeDialog(InstanceFieldOperator):
21
+ """Serializes dialog data for feeding into a model.
22
+
23
+ This class takes structured dialog data and converts it into a text format
24
+ according to a specified template. It allows for the inclusion or exclusion
25
+ of system responses and can operate on a per-turn basis or aggregate the entire
26
+ dialog.
27
+
28
+ Attributes:
29
+ field (str): The field in the input data that contains the dialog.
30
+ to_field (Optional[str]): The field in the output data where the serialized dialog will be stored.
31
+ last_user_turn_to_field (Optional[str]): Field to store the last user turn.
32
+ last_system_turn_to_field (Optional[str]): Field to store the last system turn.
33
+ context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
34
+ """
35
+
36
+ format: Optional[SystemFormat] = None
37
+ last_response_to_field: Optional[str] = None
38
+ context_field: Optional[str] = None
39
+ context_seperator: str = " "
40
+
41
+ def standartize_format(self, demo_format):
42
+ turn_format = demo_format.replace("{source}", "{user}")
43
+ turn_format = turn_format.replace("{target}", "{system}")
44
+ return turn_format.replace("{target_prefix}", "")
45
+
46
+ def slice_first_turn(self, turn_format):
47
+ return turn_format[turn_format.index("{user}") :]
48
+
49
+ def slice_last_turn(self, turn_format):
50
+ return turn_format[: turn_format.index("{system}") + len("{system}")]
51
+
52
+ def slice_last_reponse(self, turn_format):
53
+ return turn_format[: turn_format.index("{user}") + len("{user}")]
54
+
55
+ def get_turn_format(self, turn_format, step, length):
56
+ if step == 0:
57
+ turn_format = self.slice_first_turn(turn_format)
58
+ if step == length - 1:
59
+ turn_format = self.slice_last_turn(turn_format)
60
+ if self.last_response_to_field is not None:
61
+ turn_format = self.slice_last_reponse(turn_format)
62
+ return turn_format
63
+
64
+ def get_general_turn_format(self, instance):
65
+ general_format = (
66
+ instance["recipe_metadata"]["format"]
67
+ if self.format is None
68
+ else self.format
69
+ )
70
+ return self.standartize_format(general_format.demo_format)
71
+
72
+ def process_instance_value(
73
+ self, structred_dialog: List[Dict[str, str]], instance: Dict[str, Any]
74
+ ):
75
+ dialog = (
76
+ ""
77
+ if self.context_field is None
78
+ else instance[self.context_field] + self.context_seperator
79
+ )
80
+ general_turn_format = self.get_general_turn_format(instance)
81
+ for i, turn in enumerate(structred_dialog):
82
+ turn_format = self.get_turn_format(
83
+ general_turn_format, i, len(structred_dialog)
84
+ )
85
+ dialog += turn_format.format(**turn)
86
+ if self.last_response_to_field is not None:
87
+ instance[self.last_response_to_field] = turn["system"]
88
+ return dialog