Elron commited on
Commit
6e18947
·
1 Parent(s): 8fb5471

Upload generator_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generator_utils.py +57 -0
generator_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import itertools
3
+
4
+
5
+ class ReusableGenerator:
6
+ def __init__(self, generator, gen_argv=[], gen_kwargs={}):
7
+ self._generator = generator
8
+ self._gen_kwargs = gen_kwargs
9
+ self._gen_argv = gen_argv
10
+
11
+ def get_generator(self):
12
+ return self._generator
13
+
14
+ def get_gen_kwargs(self):
15
+ return self._gen_kwargs
16
+
17
+ def construct(self):
18
+ return self._generator(*self._gen_argv, **self._gen_kwargs)
19
+
20
+ def __iter__(self):
21
+ return iter(self.construct())
22
+
23
+ def __call__(self):
24
+ yield from self.construct()
25
+
26
+ def __repr__(self):
27
+ return f"{self.__class__.__name__}({self._generator.__name__}, gen_argv={self._gen_argv}, gen_kwargs={self._gen_kwargs})"
28
+
29
+
30
+ if __name__ == "__main__":
31
+ from itertools import chain, islice
32
+
33
+ # Creating objects of MyIterable
34
+ iterable1 = ReusableGenerator(range, gen_argv=[1, 4])
35
+ iterable2 = ReusableGenerator(range, gen_argv=[4, 7])
36
+
37
+ # Using itertools.chain
38
+ chained = list(chain(iterable1, iterable2))
39
+ print(chained) # Prints: [1, 2, 3, 4, 5, 6]
40
+
41
+ # Using itertools.islice
42
+ sliced = list(islice(ReusableGenerator(range, gen_argv=[1, 7]), 1, 4))
43
+ print(sliced) # Prints: [2, 3, 4]
44
+
45
+ # now same test with generators
46
+ def generator(start, end):
47
+ for i in range(start, end):
48
+ yield i
49
+
50
+ iterable1 = ReusableGenerator(generator, gen_argv=[1, 4])
51
+ iterable2 = ReusableGenerator(generator, gen_argv=[4, 7])
52
+
53
+ chained = list(chain(iterable1, iterable2))
54
+ print(chained) # Prints: [1, 2, 3, 4, 5, 6]
55
+
56
+ sliced = list(islice(ReusableGenerator(generator, gen_argv=[1, 7]), 1, 4))
57
+ print(sliced) # Prints: [2, 3, 4]