Upload stream.py with huggingface_hub
Browse files
stream.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
from copy import deepcopy
|
2 |
from typing import Dict, Iterable
|
3 |
|
4 |
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
@@ -31,11 +30,11 @@ class Stream(Dataclass):
|
|
31 |
"""
|
32 |
if self.caching:
|
33 |
return Dataset.from_generator
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
|
40 |
def _get_stream(self):
|
41 |
"""Private method to get the stream based on the initiator function.
|
@@ -102,12 +101,20 @@ class MultiStream(dict):
|
|
102 |
|
103 |
def to_dataset(self) -> DatasetDict:
|
104 |
return DatasetDict(
|
105 |
-
{
|
|
|
|
|
|
|
106 |
)
|
107 |
|
108 |
def to_iterable_dataset(self) -> IterableDatasetDict:
|
109 |
return IterableDatasetDict(
|
110 |
-
{
|
|
|
|
|
|
|
|
|
|
|
111 |
)
|
112 |
|
113 |
def __setitem__(self, key, value):
|
@@ -116,17 +123,19 @@ class MultiStream(dict):
|
|
116 |
super().__setitem__(key, value)
|
117 |
|
118 |
@classmethod
|
119 |
-
def from_generators(
|
|
|
|
|
120 |
"""Creates a MultiStream from a dictionary of ReusableGenerators.
|
121 |
|
122 |
Args:
|
123 |
generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
|
124 |
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
|
|
|
125 |
|
126 |
Returns:
|
127 |
MultiStream: A MultiStream object.
|
128 |
"""
|
129 |
-
|
130 |
assert all(isinstance(v, ReusableGenerator) for v in generators.values())
|
131 |
return cls(
|
132 |
{
|
@@ -141,17 +150,19 @@ class MultiStream(dict):
|
|
141 |
)
|
142 |
|
143 |
@classmethod
|
144 |
-
def from_iterables(
|
|
|
|
|
145 |
"""Creates a MultiStream from a dictionary of iterables.
|
146 |
|
147 |
Args:
|
148 |
iterables (Dict[str, Iterable]): A dictionary of iterables.
|
149 |
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
|
|
|
150 |
|
151 |
Returns:
|
152 |
MultiStream: A MultiStream object.
|
153 |
"""
|
154 |
-
|
155 |
return cls(
|
156 |
{
|
157 |
key: Stream(
|
|
|
|
|
1 |
from typing import Dict, Iterable
|
2 |
|
3 |
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
|
30 |
"""
|
31 |
if self.caching:
|
32 |
return Dataset.from_generator
|
33 |
+
|
34 |
+
if self.copying:
|
35 |
+
return CopyingReusableGenerator
|
36 |
+
|
37 |
+
return ReusableGenerator
|
38 |
|
39 |
def _get_stream(self):
|
40 |
"""Private method to get the stream based on the initiator function.
|
|
|
101 |
|
102 |
def to_dataset(self) -> DatasetDict:
|
103 |
return DatasetDict(
|
104 |
+
{
|
105 |
+
key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key})
|
106 |
+
for key in self.keys()
|
107 |
+
}
|
108 |
)
|
109 |
|
110 |
def to_iterable_dataset(self) -> IterableDatasetDict:
|
111 |
return IterableDatasetDict(
|
112 |
+
{
|
113 |
+
key: IterableDataset.from_generator(
|
114 |
+
self.get_generator, gen_kwargs={"key": key}
|
115 |
+
)
|
116 |
+
for key in self.keys()
|
117 |
+
}
|
118 |
)
|
119 |
|
120 |
def __setitem__(self, key, value):
|
|
|
123 |
super().__setitem__(key, value)
|
124 |
|
125 |
@classmethod
|
126 |
+
def from_generators(
|
127 |
+
cls, generators: Dict[str, ReusableGenerator], caching=False, copying=False
|
128 |
+
):
|
129 |
"""Creates a MultiStream from a dictionary of ReusableGenerators.
|
130 |
|
131 |
Args:
|
132 |
generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
|
133 |
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
|
134 |
+
copying (bool, optional): Whether the data should be copyied or not. Defaults to False.
|
135 |
|
136 |
Returns:
|
137 |
MultiStream: A MultiStream object.
|
138 |
"""
|
|
|
139 |
assert all(isinstance(v, ReusableGenerator) for v in generators.values())
|
140 |
return cls(
|
141 |
{
|
|
|
150 |
)
|
151 |
|
152 |
@classmethod
|
153 |
+
def from_iterables(
|
154 |
+
cls, iterables: Dict[str, Iterable], caching=False, copying=False
|
155 |
+
):
|
156 |
"""Creates a MultiStream from a dictionary of iterables.
|
157 |
|
158 |
Args:
|
159 |
iterables (Dict[str, Iterable]): A dictionary of iterables.
|
160 |
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
|
161 |
+
copying (bool, optional): Whether the data should be copyied or not. Defaults to False.
|
162 |
|
163 |
Returns:
|
164 |
MultiStream: A MultiStream object.
|
165 |
"""
|
|
|
166 |
return cls(
|
167 |
{
|
168 |
key: Stream(
|