Elron commited on
Commit
5c545d2
·
1 Parent(s): 335fd08

Upload fusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. fusion.py +34 -39
fusion.py CHANGED
@@ -1,26 +1,28 @@
1
- from typing import List, Optional, Generator
2
- from dataclasses import asdict
3
- import random
4
  from abc import abstractmethod
 
 
5
 
6
- from .stream import MultiStream, Stream
7
- from .operator import SourceOperator, StreamSource
8
- from .card import TaskCard, ICLCard
9
  from .common import CommonRecipe
 
 
 
 
10
 
11
  class BaseFusion(SourceOperator):
12
  """
13
  BaseFusion operator that combines multiple streams into one.
14
-
15
  Args:
16
  include_splits: List of splits to include. If None, all splits are included.
17
  """
 
18
  include_splits: Optional[List[str]] = None
19
-
20
  @abstractmethod
21
  def fusion_generator(self, split) -> Generator:
22
  pass
23
-
24
  def splits(self) -> Generator:
25
  splits = []
26
  for origin in self.origins:
@@ -29,25 +31,28 @@ class BaseFusion(SourceOperator):
29
  if self.include_splits is None or s in self.include_splits:
30
  splits.append(s)
31
  return splits
32
-
33
 
34
- def process(self, ) -> MultiStream:
 
 
35
  result = {}
36
  for split in self.splits():
37
- result[split] = Stream(self.fusion_generator, gen_kwargs={'split': split})
38
  return MultiStream(result)
39
 
 
40
  class FixedFusion(BaseFusion):
41
  """
42
  FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.
43
-
44
  Args:
45
  orgins: List of StreamSource objects.
46
  examples_per_task: Number of examples per task. If None, all examples are returned.
47
  splits: List of splits to include. If None, all splits are included.
48
  """
 
49
  examples_per_task: Optional[int] = None
50
-
51
  def fusion_generator(self, split) -> Generator:
52
  for origin in self.orgins:
53
  iterator = iter(origin()[split])
@@ -56,74 +61,64 @@ class FixedFusion(BaseFusion):
56
  yield next(iterator)
57
  else:
58
  yield from iterator
59
-
60
 
61
  class WeightedFusion(BaseFusion):
62
  """
63
- Fusion operator that combines multiple streams based
64
-
65
  Args:
66
  orgins: List of StreamSource objects.
67
  weights: List of weights for each origin.
68
  total_examples: Total number of examples to return. If None, all examples are returned.
69
  """
 
70
  origins: List[StreamSource] = None
71
  weights: List[float] = None
72
  total_examples: int = None
73
-
74
  def verify(self):
75
  super().verify()
76
  assert self.origins is not None, "origins must be specified"
77
  assert self.weights is not None, "weights must be specified"
78
  assert len(self.origins) == len(self.weights), "origins and weights must have the same length"
79
-
80
  def fusion_generator(self, split) -> Generator:
81
  iterators = [iter(origin()[split]) for origin in self.origins]
82
  total_examples = 0
83
- while (self.total_examples is None or total_examples <= self.total_examples) \
84
- and len(iterators) > 0:
85
  iterator = random.choices(population=iterators, weights=self.weights)[0]
86
  try:
87
  yield next(iterator)
88
  total_examples += 1
89
  except StopIteration:
90
  iterators.remove(iterator)
91
-
 
92
  class TasksFusion(SourceOperator):
93
  """
94
  TasksFusion operator that combines multiple tasks into one.
95
-
96
  Args:
97
  tasks: List of TaskCard objects.
98
  config: ICLCard object.
99
  examples_per_task: Number of examples per task. If None, all examples are returned.
100
  include_splits: List of splits to include. If None, all splits are included.
101
  """
 
102
  tasks: List[TaskCard]
103
  config: ICLCard
104
  examples_per_task: Optional[int] = None
105
  include_splits: Optional[List[str]] = None
106
-
107
  def prepare(self):
108
  self.recipes = []
109
  for task in self.tasks:
110
- recipe = CommonRecipe(
111
- card=task,
112
- **asdict(self.config)
113
- )
114
-
115
  self.fusion = FixedFusion(
116
- origins=self.recipes,
117
- examples_per_task=self.examples_per_task,
118
- include_splits=self.include_splits
119
  )
120
 
121
  def process(self) -> MultiStream:
122
  return self.fusion()
123
-
124
-
125
-
126
-
127
-
128
-
129
-
 
 
 
 
1
  from abc import abstractmethod
2
+ from dataclasses import asdict
3
+ from typing import Generator, List, Optional
4
 
5
+ from .card import ICLCard, TaskCard
 
 
6
  from .common import CommonRecipe
7
+ from .operator import SourceOperator, StreamSource
8
+ from .random_utils import random
9
+ from .stream import MultiStream, Stream
10
+
11
 
12
  class BaseFusion(SourceOperator):
13
  """
14
  BaseFusion operator that combines multiple streams into one.
15
+
16
  Args:
17
  include_splits: List of splits to include. If None, all splits are included.
18
  """
19
+
20
  include_splits: Optional[List[str]] = None
21
+
22
  @abstractmethod
23
  def fusion_generator(self, split) -> Generator:
24
  pass
25
+
26
  def splits(self) -> Generator:
27
  splits = []
28
  for origin in self.origins:
 
31
  if self.include_splits is None or s in self.include_splits:
32
  splits.append(s)
33
  return splits
 
34
 
35
+ def process(
36
+ self,
37
+ ) -> MultiStream:
38
  result = {}
39
  for split in self.splits():
40
+ result[split] = Stream(self.fusion_generator, gen_kwargs={"split": split})
41
  return MultiStream(result)
42
 
43
+
44
  class FixedFusion(BaseFusion):
45
  """
46
  FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.
47
+
48
  Args:
49
  orgins: List of StreamSource objects.
50
  examples_per_task: Number of examples per task. If None, all examples are returned.
51
  splits: List of splits to include. If None, all splits are included.
52
  """
53
+
54
  examples_per_task: Optional[int] = None
55
+
56
  def fusion_generator(self, split) -> Generator:
57
  for origin in self.orgins:
58
  iterator = iter(origin()[split])
 
61
  yield next(iterator)
62
  else:
63
  yield from iterator
64
+
65
 
66
  class WeightedFusion(BaseFusion):
67
  """
68
+ Fusion operator that combines multiple streams based
69
+
70
  Args:
71
  orgins: List of StreamSource objects.
72
  weights: List of weights for each origin.
73
  total_examples: Total number of examples to return. If None, all examples are returned.
74
  """
75
+
76
  origins: List[StreamSource] = None
77
  weights: List[float] = None
78
  total_examples: int = None
79
+
80
  def verify(self):
81
  super().verify()
82
  assert self.origins is not None, "origins must be specified"
83
  assert self.weights is not None, "weights must be specified"
84
  assert len(self.origins) == len(self.weights), "origins and weights must have the same length"
85
+
86
  def fusion_generator(self, split) -> Generator:
87
  iterators = [iter(origin()[split]) for origin in self.origins]
88
  total_examples = 0
89
+ while (self.total_examples is None or total_examples <= self.total_examples) and len(iterators) > 0:
 
90
  iterator = random.choices(population=iterators, weights=self.weights)[0]
91
  try:
92
  yield next(iterator)
93
  total_examples += 1
94
  except StopIteration:
95
  iterators.remove(iterator)
96
+
97
+
98
  class TasksFusion(SourceOperator):
99
  """
100
  TasksFusion operator that combines multiple tasks into one.
101
+
102
  Args:
103
  tasks: List of TaskCard objects.
104
  config: ICLCard object.
105
  examples_per_task: Number of examples per task. If None, all examples are returned.
106
  include_splits: List of splits to include. If None, all splits are included.
107
  """
108
+
109
  tasks: List[TaskCard]
110
  config: ICLCard
111
  examples_per_task: Optional[int] = None
112
  include_splits: Optional[List[str]] = None
113
+
114
  def prepare(self):
115
  self.recipes = []
116
  for task in self.tasks:
117
+ recipe = CommonRecipe(card=task, **asdict(self.config))
118
+
 
 
 
119
  self.fusion = FixedFusion(
120
+ origins=self.recipes, examples_per_task=self.examples_per_task, include_splits=self.include_splits
 
 
121
  )
122
 
123
  def process(self) -> MultiStream:
124
  return self.fusion()