Elron commited on
Commit
fc6c7eb
·
1 Parent(s): 34f08fc

Upload operator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operator.py +18 -3
operator.py CHANGED
@@ -3,6 +3,7 @@ from dataclasses import field
3
  from typing import Any, Dict, Generator, List, Optional, Union
4
 
5
  from .artifact import Artifact
 
6
  from .random_utils import nested_seed
7
  from .stream import MultiStream, Stream
8
 
@@ -83,9 +84,14 @@ class SourceOperator(StreamSource):
83
 
84
  """
85
 
 
 
86
  def __call__(self) -> MultiStream:
87
  with nested_seed():
88
- return self.process()
 
 
 
89
 
90
  @abstractmethod
91
  def process(self) -> MultiStream:
@@ -102,8 +108,13 @@ class StreamInitializerOperator(StreamSource):
102
 
103
  """
104
 
 
 
105
  def __call__(self, *args, **kwargs) -> MultiStream:
106
  with nested_seed():
 
 
 
107
  return self.process(*args, **kwargs)
108
 
109
  @abstractmethod
@@ -118,6 +129,8 @@ class MultiStreamOperator(StreamingOperator):
118
  A multi-stream operator is a type of `StreamingOperator` that operates on an entire MultiStream object at once. It takes a `MultiStream` as input and produces a `MultiStream` as output. The `process` method should be implemented by subclasses to define the specific operations to be performed on the input `MultiStream`.
119
  """
120
 
 
 
121
  def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
122
  with nested_seed():
123
  return self._process_multi_stream(multi_stream)
@@ -125,6 +138,8 @@ class MultiStreamOperator(StreamingOperator):
125
  def _process_multi_stream(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
126
  result = self.process(multi_stream)
127
  assert isinstance(result, MultiStream), "MultiStreamOperator must return a MultiStream"
 
 
128
  return result
129
 
130
  @abstractmethod
@@ -198,7 +213,7 @@ class SingleStreamReducer(StreamingOperator):
198
  return result
199
 
200
  @abstractmethod
201
- def process(self, stream: Stream) -> Any:
202
  pass
203
 
204
 
@@ -296,7 +311,7 @@ class InstanceOperatorWithGlobalAccess(StreamingOperator):
296
 
297
  if self.cache_accessible_streams:
298
  for stream in self.accessible_streams.values():
299
- stream.set_caching(True)
300
 
301
  for stream_name, stream in multi_stream.items():
302
  stream = Stream(self.generator, gen_kwargs={"stream": stream, "multi_stream": self.accessible_streams})
 
3
  from typing import Any, Dict, Generator, List, Optional, Union
4
 
5
  from .artifact import Artifact
6
+ from .dataclass import NonPositionalField
7
  from .random_utils import nested_seed
8
  from .stream import MultiStream, Stream
9
 
 
84
 
85
  """
86
 
87
+ caching: bool = NonPositionalField(default=None)
88
+
89
  def __call__(self) -> MultiStream:
90
  with nested_seed():
91
+ multi_stream = self.process()
92
+ if self.caching is not None:
93
+ multi_stream.set_caching(self.caching)
94
+ return multi_stream
95
 
96
  @abstractmethod
97
  def process(self) -> MultiStream:
 
108
 
109
  """
110
 
111
+ caching: bool = NonPositionalField(default=None)
112
+
113
  def __call__(self, *args, **kwargs) -> MultiStream:
114
  with nested_seed():
115
+ multi_stream = self.process(*args, **kwargs)
116
+ if self.caching is not None:
117
+ multi_stream.set_caching(self.caching)
118
  return self.process(*args, **kwargs)
119
 
120
  @abstractmethod
 
129
  A multi-stream operator is a type of `StreamingOperator` that operates on an entire MultiStream object at once. It takes a `MultiStream` as input and produces a `MultiStream` as output. The `process` method should be implemented by subclasses to define the specific operations to be performed on the input `MultiStream`.
130
  """
131
 
132
+ caching: bool = NonPositionalField(default=None)
133
+
134
  def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
135
  with nested_seed():
136
  return self._process_multi_stream(multi_stream)
 
138
  def _process_multi_stream(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
139
  result = self.process(multi_stream)
140
  assert isinstance(result, MultiStream), "MultiStreamOperator must return a MultiStream"
141
+ if self.caching is not None:
142
+ result.set_caching(self.caching)
143
  return result
144
 
145
  @abstractmethod
 
213
  return result
214
 
215
  @abstractmethod
216
+ def process(self, stream: Stream) -> Stream:
217
  pass
218
 
219
 
 
311
 
312
  if self.cache_accessible_streams:
313
  for stream in self.accessible_streams.values():
314
+ stream.caching = True
315
 
316
  for stream_name, stream in multi_stream.items():
317
  stream = Stream(self.generator, gen_kwargs={"stream": stream, "multi_stream": self.accessible_streams})