File size: 5,407 Bytes
7cdc7d0 fe70438 7cdc7d0 fe70438 7cdc7d0 fe70438 7cdc7d0 fe70438 7cdc7d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import csv
import io
from abc import abstractmethod
from typing import Any, Dict, List, Union
from .dataclass import AbstractField, Field
from .operators import InstanceFieldOperator
from .settings_utils import get_constants
from .type_utils import isoftype, to_type_string
from .types import Dialog, Image, Number, Table, Video
# Constants object obtained from settings_utils; ImageSerializer reads
# `constants.image_tag` from it when building image placeholder tags.
constants = get_constants()
class Serializer(InstanceFieldOperator):
    """Base class for field operators that serialize a value.

    Concrete subclasses implement :meth:`serialize`; the field-operator hook
    :meth:`process_instance_value` forwards to it unchanged.
    """

    @abstractmethod
    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return the serialized form of ``value`` for the given ``instance``."""

    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
        """Serialize ``value`` by delegating to :meth:`serialize`."""
        serialized = self.serialize(value, instance)
        return serialized
class DefaultSerializer(Serializer):
    """Serializer that applies ``str()`` to whatever value it receives."""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return ``str(value)``; ``instance`` is not consulted."""
        text = str(value)
        return text
class SingleTypeSerializer(InstanceFieldOperator):
    """Serializer bound to a single value type.

    Subclasses set ``serialized_type`` and implement :meth:`serialize`.
    Incoming values are checked against ``serialized_type`` before being
    handed to :meth:`serialize`.
    """

    # Type accepted by this serializer. Declared as an AbstractField, so
    # concrete subclasses are expected to assign it.
    serialized_type: object = AbstractField()

    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
        """Type-check ``value`` and serialize it.

        Args:
            value: The field value to serialize.
            instance: The instance the value belongs to; passed through to
                :meth:`serialize`.

        Returns:
            Whatever :meth:`serialize` returns for ``value``.

        Raises:
            ValueError: If ``value`` is not of ``serialized_type``.
        """
        if not isoftype(value, self.serialized_type):
            raise ValueError(
                f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {to_type_string(value)}"
            )
        return self.serialize(value, instance)

    @abstractmethod
    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return the serialized form of a value of ``serialized_type``.

        Declared abstract (mirroring ``Serializer.serialize``) because
        :meth:`process_instance_value` calls it, so every concrete subclass
        must provide an implementation.
        """
class DefaultListSerializer(Serializer):
    """Serializer that comma-joins lists and stringifies anything else."""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return ``", "``-joined items for a list, otherwise ``str(value)``."""
        if not isinstance(value, list):
            return str(value)
        return ", ".join(map(str, value))
class ListSerializer(SingleTypeSerializer):
    """Single-type serializer for ``list`` values."""

    serialized_type = list

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return the items of ``value`` converted with ``str`` and joined by ``", "``."""
        parts = [str(element) for element in value]
        return ", ".join(parts)
class DialogSerializer(SingleTypeSerializer):
    """Single-type serializer for ``Dialog`` values."""

    serialized_type = Dialog

    def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
        """Render each turn as ``"<role>: <content>"``, one turn per line.

        Each turn is indexed by the ``"role"`` and ``"content"`` keys.
        """
        rendered_turns = []
        for turn in value:
            role, content = turn["role"], turn["content"]
            rendered_turns.append(f"{role}: {content}")
        return "\n".join(rendered_turns)
class NumberSerializer(SingleTypeSerializer):
    """Single-type serializer for ``Number`` values."""

    serialized_type = Number

    def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
        """Format integers as-is and floats with one decimal place.

        Raises:
            ValueError: If ``value`` is neither an ``int`` nor a ``float``.
        """
        if not isinstance(value, (int, float)):
            raise ValueError("Unsupported type for NumberSerializer")
        # Integers keep their exact textual form; floats are fixed to one decimal.
        return str(value) if isinstance(value, int) else format(value, ".1f")
class NumberQuantizingSerializer(NumberSerializer):
    """Number serializer that snaps values to the nearest multiple of ``quantum``."""

    serialized_type = Number
    # Step size values are rounded to. An ``int`` quantum yields integer output.
    quantum: Union[float, int] = 0.1

    def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
        """Return ``value`` rounded to the nearest multiple of ``quantum`` as a string.

        Args:
            value: The number to quantize.
            instance: The instance the value belongs to (unused here).

        Returns:
            The string form of the quantized number: an integer string when
            ``quantum`` is an ``int``, otherwise the string of a float.

        Raises:
            ValueError: If ``value`` is not a ``Number``.
        """
        if not isoftype(value, Number):
            raise ValueError("Unsupported type for NumberSerializer")

        # Number of whole quanta closest to the value.
        steps = round(value / self.quantum)

        if isinstance(self.quantum, int):
            # Stay in integer arithmetic: going through the float reciprocal
            # ``1 / quantum`` and truncating with int() can land one below the
            # intended multiple when the float result is slightly under it.
            return str(int(steps) * self.quantum)

        # Dividing by the reciprocal (rather than multiplying by the quantum)
        # is kept from the original float path so outputs are unchanged.
        return str(steps / (1 / self.quantum))
class TableSerializer(SingleTypeSerializer):
    """Single-type serializer that renders a ``Table`` as CSV text."""

    serialized_type = Table

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        """Write ``value["header"]`` followed by ``value["rows"]`` as CSV.

        Lines are terminated with ``"\\n"`` and surrounding whitespace
        (including the final newline) is stripped from the result.
        """
        buffer = io.StringIO()
        csv_writer = csv.writer(buffer, lineterminator="\n")
        # Header first, then every data row, in a single pass.
        csv_writer.writerows([value["header"], *value["rows"]])
        return buffer.getvalue().strip()
class ImageSerializer(SingleTypeSerializer):
    """Single-type serializer that moves an ``Image`` into the instance media store."""

    serialized_type = Image

    def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
        """Register the image under ``instance["media"]["images"]`` and return a tag.

        The image payload and format are appended to the instance's image
        list (created on demand), ``value["image"]`` is overwritten with the
        ``media/images/<idx>`` reference, and a tag pointing at that reference
        is returned.
        """
        # Create the nested media containers only when they are missing.
        stored_images = instance.setdefault("media", {}).setdefault("images", [])
        idx = len(stored_images)
        stored_images.append({"image": value["image"], "format": value["format"]})
        # Replace the in-place payload with its reference path.
        value["image"] = f"media/images/{idx}"
        return f'<{constants.image_tag} src="media/images/{idx}">'
class VideoSerializer(ImageSerializer):
    """Single-type serializer that serializes a ``Video`` image by image."""

    serialized_type = Video

    def serialize(self, value: Video, instance: Dict[str, Any]) -> str:
        """Serialize every image in ``value`` with the parent class and concatenate the tags."""
        # Bind the parent implementation once, outside the comprehension,
        # since zero-argument super() is not usable inside a comprehension
        # scope on older Python versions.
        serialize_frame = super().serialize
        return "".join([serialize_frame(frame, instance) for frame in value])
class MultiTypeSerializer(Serializer):
    """Serializer that dispatches to the first matching single-type serializer.

    Values that match none of the configured serializers fall back to ``str()``.
    """

    # Candidate serializers, tried in order; earlier entries take precedence.
    serializers: List[SingleTypeSerializer] = Field(
        default_factory=lambda: [
            ImageSerializer(),
            VideoSerializer(),
            TableSerializer(),
            DialogSerializer(),
        ]
    )

    def verify(self):
        """Run the parent verification, then validate the configured serializers."""
        super().verify()
        self._verify_serializers(self.serializers)

    def _verify_serializers(self, serializers):
        """Raise ``ValueError`` unless ``serializers`` is a ``List[SingleTypeSerializer]``."""
        if isoftype(serializers, List[SingleTypeSerializer]):
            return
        raise ValueError(
            "MultiTypeSerializer requires the list of serializers to be List[SingleTypeSerializer]."
        )

    def add_serializers(self, serializers: List[SingleTypeSerializer]):
        """Validate ``serializers`` and place them ahead of the existing ones."""
        self._verify_serializers(serializers)
        combined = serializers + self.serializers
        self.serializers = combined

    def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
        """Serialize ``value`` with the first serializer whose type it matches.

        Returns ``str(value)`` when no configured serializer accepts the value.
        """
        matching = next(
            (
                candidate
                for candidate in self.serializers
                if isoftype(value, candidate.serialized_type)
            ),
            None,
        )
        if matching is None:
            return str(value)
        return matching.serialize(value, instance)
|