Upload splitters.py with huggingface_hub
Browse files
- splitters.py +18 -3
splitters.py
CHANGED
@@ -70,6 +70,7 @@ class SeparateSplit(Splitter):
|
|
70 |
from_split: str
|
71 |
to_split_names: List[str]
|
72 |
to_split_sizes: List[int]
|
|
|
73 |
|
74 |
def verify(self):
|
75 |
assert (
|
@@ -82,13 +83,14 @@ class SeparateSplit(Splitter):
|
|
82 |
mapping = {
|
83 |
key: {key: [(None, None)]}
|
84 |
for key in multi_stream.keys()
|
85 |
-
if key != self.from_split
|
86 |
}
|
87 |
so_far = 0
|
88 |
for name, size in itertools.zip_longest(
|
89 |
self.to_split_names, self.to_split_sizes
|
90 |
):
|
91 |
-
|
|
|
92 |
if size:
|
93 |
so_far += size
|
94 |
generators = slice_streams(multi_stream, mapping)
|
@@ -131,6 +133,14 @@ class Sampler(Artifact):
|
|
131 |
) -> List[Dict[str, object]]:
|
132 |
pass
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
class RandomSampler(Sampler):
|
136 |
def sample(
|
@@ -172,6 +182,7 @@ class DiverseLabelsSampler(Sampler):
|
|
172 |
|
173 |
choices: str = "choices"
|
174 |
labels: str = "labels"
|
|
|
175 |
|
176 |
def prepare(self):
|
177 |
super().prepare()
|
@@ -207,6 +218,8 @@ class DiverseLabelsSampler(Sampler):
|
|
207 |
labels = {}
|
208 |
for examplar in examplars_pool:
|
209 |
label_repr = self.examplar_repr(examplar)
|
|
|
|
|
210 |
if label_repr not in labels:
|
211 |
labels[label_repr] = []
|
212 |
labels[label_repr].append(examplar)
|
@@ -269,7 +282,9 @@ class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
|
|
269 |
self.local_cache = list(multi_stream[self.source_stream])
|
270 |
|
271 |
source_stream = self.local_cache
|
272 |
-
|
|
|
|
|
273 |
sampled_instances = self.sampler.sample(source_stream)
|
274 |
instance[self.target_field] = sampled_instances
|
275 |
return instance
|
|
|
70 |
from_split: str
|
71 |
to_split_names: List[str]
|
72 |
to_split_sizes: List[int]
|
73 |
+
remove_targets_from_source_split: bool = True
|
74 |
|
75 |
def verify(self):
|
76 |
assert (
|
|
|
83 |
mapping = {
|
84 |
key: {key: [(None, None)]}
|
85 |
for key in multi_stream.keys()
|
86 |
+
if not self.remove_targets_from_source_split or key != self.from_split
|
87 |
}
|
88 |
so_far = 0
|
89 |
for name, size in itertools.zip_longest(
|
90 |
self.to_split_names, self.to_split_sizes
|
91 |
):
|
92 |
+
if self.remove_targets_from_source_split or name != self.from_split:
|
93 |
+
mapping[name] = {self.from_split: [(so_far, size)]}
|
94 |
if size:
|
95 |
so_far += size
|
96 |
generators = slice_streams(multi_stream, mapping)
|
|
|
133 |
) -> List[Dict[str, object]]:
|
134 |
pass
|
135 |
|
136 |
+
def filter_source_by_instance(
|
137 |
+
self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
|
138 |
+
) -> List[Dict[str, object]]:
|
139 |
+
if "inputs" not in instance:
|
140 |
+
raise ValueError(f"'inputs' field is missing from '{instance}'.")
|
141 |
+
|
142 |
+
return list(filter(lambda x: x["inputs"] != instance["inputs"], instances_pool))
|
143 |
+
|
144 |
|
145 |
class RandomSampler(Sampler):
|
146 |
def sample(
|
|
|
182 |
|
183 |
choices: str = "choices"
|
184 |
labels: str = "labels"
|
185 |
+
include_empty_label: bool = True
|
186 |
|
187 |
def prepare(self):
|
188 |
super().prepare()
|
|
|
218 |
labels = {}
|
219 |
for examplar in examplars_pool:
|
220 |
label_repr = self.examplar_repr(examplar)
|
221 |
+
if label_repr == "[]" and not self.include_empty_label:
|
222 |
+
continue
|
223 |
if label_repr not in labels:
|
224 |
labels[label_repr] = []
|
225 |
labels[label_repr].append(examplar)
|
|
|
282 |
self.local_cache = list(multi_stream[self.source_stream])
|
283 |
|
284 |
source_stream = self.local_cache
|
285 |
+
source_stream = self.sampler.filter_source_by_instance(
|
286 |
+
source_stream, instance
|
287 |
+
)
|
288 |
sampled_instances = self.sampler.sample(source_stream)
|
289 |
instance[self.target_field] = sampled_instances
|
290 |
return instance
|