Upload splitters.py with huggingface_hub
Browse files- splitters.py +16 -1
splitters.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import itertools
|
|
|
2 |
from dataclasses import field
|
3 |
from typing import Dict, List, Optional
|
4 |
|
@@ -78,7 +79,21 @@ class SliceSplit(Splitter):
|
|
78 |
|
79 |
|
80 |
class Sampler(Artifact):
|
81 |
-
sample_size: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
|
84 |
class RandomSampler(Sampler):
|
|
|
1 |
import itertools
|
2 |
+
from abc import abstractmethod
|
3 |
from dataclasses import field
|
4 |
from typing import Dict, List, Optional
|
5 |
|
|
|
79 |
|
80 |
|
81 |
class Sampler(Artifact):
    """Base class for samplers that select instances from a pool.

    Subclasses implement :meth:`sample`. ``sample_size`` is normalized by
    :meth:`set_size` during :meth:`prepare`.
    """

    # Number of instances to sample. ``set_size`` also accepts a digit string
    # and converts it to ``int``; ``None`` means no size has been set.
    sample_size: Optional[int] = None

    def prepare(self):
        """Run the parent preparation, then normalize ``sample_size``."""
        super().prepare()
        self.set_size(self.sample_size)

    def set_size(self, size):
        """Store *size* as ``sample_size``, converting digit strings to ``int``.

        Args:
            size: The new sample size. A ``str`` must consist only of digits
                and is converted with ``int``; any other value (including
                ``None``) is stored unchanged.

        Raises:
            AssertionError: If *size* is a string that is not all digits.
        """
        if isinstance(size, str):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is kept so
            # existing callers observe the same error. The message reports
            # the value being validated, not the previously stored one.
            if not size.isdigit():
                raise AssertionError(
                    f"sample_size must be a natural number, got {size}"
                )
            size = int(size)
        self.sample_size = size

    # NOTE(review): @abstractmethod only prevents instantiation when the
    # metaclass derives from ABCMeta -- confirm that Artifact provides one.
    @abstractmethod
    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        """Return instances selected from *instances_pool*.

        Subclasses must override this; the base implementation does nothing
        and returns ``None``.

        Args:
            instances_pool: The candidate instances to sample from.

        Returns:
            The sampled instances (presumably ``sample_size`` of them --
            verify against the concrete subclass).
        """
|
97 |
|
98 |
|
99 |
class RandomSampler(Sampler):
|