Elron committed on
Commit
571af6d
·
verified ·
1 Parent(s): 3da1d9d

Upload splitters.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. splitters.py +18 -3
splitters.py CHANGED
@@ -70,6 +70,7 @@ class SeparateSplit(Splitter):
70
  from_split: str
71
  to_split_names: List[str]
72
  to_split_sizes: List[int]
 
73
 
74
  def verify(self):
75
  assert (
@@ -82,13 +83,14 @@ class SeparateSplit(Splitter):
82
  mapping = {
83
  key: {key: [(None, None)]}
84
  for key in multi_stream.keys()
85
- if key != self.from_split
86
  }
87
  so_far = 0
88
  for name, size in itertools.zip_longest(
89
  self.to_split_names, self.to_split_sizes
90
  ):
91
- mapping[name] = {self.from_split: [(so_far, size)]}
 
92
  if size:
93
  so_far += size
94
  generators = slice_streams(multi_stream, mapping)
@@ -131,6 +133,14 @@ class Sampler(Artifact):
131
  ) -> List[Dict[str, object]]:
132
  pass
133
 
 
 
 
 
 
 
 
 
134
 
135
  class RandomSampler(Sampler):
136
  def sample(
@@ -172,6 +182,7 @@ class DiverseLabelsSampler(Sampler):
172
 
173
  choices: str = "choices"
174
  labels: str = "labels"
 
175
 
176
  def prepare(self):
177
  super().prepare()
@@ -207,6 +218,8 @@ class DiverseLabelsSampler(Sampler):
207
  labels = {}
208
  for examplar in examplars_pool:
209
  label_repr = self.examplar_repr(examplar)
 
 
210
  if label_repr not in labels:
211
  labels[label_repr] = []
212
  labels[label_repr].append(examplar)
@@ -269,7 +282,9 @@ class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
269
  self.local_cache = list(multi_stream[self.source_stream])
270
 
271
  source_stream = self.local_cache
272
-
 
 
273
  sampled_instances = self.sampler.sample(source_stream)
274
  instance[self.target_field] = sampled_instances
275
  return instance
 
70
  from_split: str
71
  to_split_names: List[str]
72
  to_split_sizes: List[int]
73
+ remove_targets_from_source_split: bool = True
74
 
75
  def verify(self):
76
  assert (
 
83
  mapping = {
84
  key: {key: [(None, None)]}
85
  for key in multi_stream.keys()
86
+ if not self.remove_targets_from_source_split or key != self.from_split
87
  }
88
  so_far = 0
89
  for name, size in itertools.zip_longest(
90
  self.to_split_names, self.to_split_sizes
91
  ):
92
+ if self.remove_targets_from_source_split or name != self.from_split:
93
+ mapping[name] = {self.from_split: [(so_far, size)]}
94
  if size:
95
  so_far += size
96
  generators = slice_streams(multi_stream, mapping)
 
133
  ) -> List[Dict[str, object]]:
134
  pass
135
 
136
+ def filter_source_by_instance(
137
+ self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
138
+ ) -> List[Dict[str, object]]:
139
+ if "inputs" not in instance:
140
+ raise ValueError(f"'inputs' field is missing from '{instance}'.")
141
+
142
+ return list(filter(lambda x: x["inputs"] != instance["inputs"], instances_pool))
143
+
144
 
145
  class RandomSampler(Sampler):
146
  def sample(
 
182
 
183
  choices: str = "choices"
184
  labels: str = "labels"
185
+ include_empty_label: bool = True
186
 
187
  def prepare(self):
188
  super().prepare()
 
218
  labels = {}
219
  for examplar in examplars_pool:
220
  label_repr = self.examplar_repr(examplar)
221
+ if label_repr == "[]" and not self.include_empty_label:
222
+ continue
223
  if label_repr not in labels:
224
  labels[label_repr] = []
225
  labels[label_repr].append(examplar)
 
282
  self.local_cache = list(multi_stream[self.source_stream])
283
 
284
  source_stream = self.local_cache
285
+ source_stream = self.sampler.filter_source_by_instance(
286
+ source_stream, instance
287
+ )
288
  sampled_instances = self.sampler.sample(source_stream)
289
  instance[self.target_field] = sampled_instances
290
  return instance