winglian committed
Commit 5294653 · unverified · 1 Parent(s): 98c25e1

PoSE context length ext (#1567)

* PoSE wip
* fixes for pose splitting
* set pose context len so we can pick that up separately from the usable training context len
* support min sample len and define num chunks
* fix chunk splitting
* support for curriculum/ordered learning with pose
* fix sequence len sort
* add curriculum_sampling to pydantic

src/axolotl/core/trainer_builder.py CHANGED
@@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "path under the model to access the layers"},
     )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )


 class AxolotlTrainer(Trainer):
@@ -347,6 +351,8 @@ class AxolotlTrainer(Trainer):
                 lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
+        if self.args.curriculum_sampling:
+            return SequentialSampler(self.train_dataset)
         return super()._get_train_sampler()

     def _get_eval_sampler(
@@ -1193,6 +1199,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             False if self.cfg.ddp else None
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = None
         if self.cfg.use_wandb:
             report_to = "wandb"
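
For context on the trainer change above: when curriculum_sampling is set, the trainer returns torch's SequentialSampler instead of the default shuffling sampler, so examples are consumed in the dataset's stored order (with PoSE, that order comes from the sort on sequence_len added below). A minimal sketch of that sampler behavior, using a toy stand-in dataset rather than axolotl code:

# SequentialSampler yields indices 0..N-1 in order, so a dataset pre-sorted by
# length is consumed shortest-first instead of being reshuffled each epoch.
from torch.utils.data import SequentialSampler

toy_dataset = ["short", "a bit longer", "the longest example"]  # illustrative only
print(list(SequentialSampler(toy_dataset)))  # [0, 1, 2]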
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -503,9 +503,17 @@ class AxolotlInputConfig(
     unfrozen_parameters: Optional[List[str]] = None

     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
+    curriculum_sampling: Optional[bool] = None
+
+    # for PoSE context length extension
+    use_pose: Optional[bool] = None
+    pose_split_on_token_ids: Optional[List[int]] = None
+    pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None

     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
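
A hedged sketch of how the new options fit together: the keys are the fields added above, but the dict form and every value below are illustrative placeholders, not recommended settings.

# Illustrative config fragment (hypothetical values, not recommendations).
pose_config = {
    "sequence_len": 8192,            # usable training context length
    "min_sample_len": 256,           # drop samples shorter than this
    "use_pose": True,                # enable PoSE position-id skipping
    "pose_max_context_len": 65536,   # extended context length to simulate
    "pose_split_on_token_ids": [2],  # optional: token ids to split segments on
    "pose_num_chunks": 2,            # chunks per sample when not splitting on tokens
    "curriculum_sampling": True,     # consume the length-sorted dataset in order
}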
src/axolotl/utils/trainer.py CHANGED
@@ -1,9 +1,10 @@
 """Module containing the Trainer class and related functions"""
 import math
 import os
+import random
 from contextlib import contextmanager
 from functools import partial
-from typing import List
+from typing import List, Optional

 import numpy as np
 import torch
@@ -98,17 +99,89 @@ def add_position_ids(sample):
     return sample


+def add_pose_position_ids(
+    sample,
+    max_context_len=32768,
+    split_on_token_ids: Optional[List[int]] = None,
+    chunks: int = 2,
+):
+    """
+    use the PoSE technique to extend the context length by randomly skipping
+    positions in the context. We only want to skip right before tokens in
+    the split_on_token_ids list. We should attempt to randomly distribute
+    the skips, but we don't need the final position_ids to be the full
+    context_len. There may be multiple turns in the context, so we want to
+    make sure we take into account the maximum possible number of skips
+    remaining in each sample.
+    """
+
+    input_ids = sample["input_ids"]
+    sample_len = len(input_ids)
+    max_skips = max_context_len - sample_len
+
+    if split_on_token_ids is None:
+        split_on_token_ids = []
+
+    if split_on_token_ids:
+        split_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+        ]
+    else:
+        chunk_len = sample_len // chunks
+        split_indices = [i * chunk_len for i in range(1, chunks)]
+    split_indices.append(len(input_ids))  # make sure we go to the end of the sample
+    if split_indices[0] < 2:
+        # drop the first split index if it's too close to the beginning
+        split_indices = split_indices[1:]
+
+    position_ids = []
+    prev_index = 0
+    total_skips = 0
+
+    for split_index in split_indices:
+        num_skips = (
+            random.randint(0, max_skips)  # nosec B311
+            if prev_index != 0 and max_skips
+            else 0
+        )
+        max_skips -= num_skips
+        total_skips += num_skips
+
+        segment_position_ids = list(
+            range(prev_index + total_skips, split_index + total_skips)
+        )
+
+        position_ids.extend(segment_position_ids)
+        prev_index = split_index
+
+    sample["sequence_len"] = position_ids[-1]
+    position_ids = torch.tensor(position_ids)
+
+    sample["position_ids"] = position_ids
+    sample["length"] = len(position_ids)
+    assert len(position_ids) == len(input_ids)
+
+    return sample
+
+
 def add_length(sample):
     sample["length"] = len(sample["input_ids"])
     return sample


-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )


 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
                 desc="Group By Length",
             )

-        if cfg.sample_packing:
+        if cfg.use_pose:
+            pose_kwargs = {}
+            if cfg.pose_num_chunks is not None:
+                pose_kwargs["chunks"] = cfg.pose_num_chunks
+            pose_fn = partial(
+                add_pose_position_ids,
+                max_context_len=cfg.pose_max_context_len,
+                split_on_token_ids=cfg.pose_split_on_token_ids,
+                **pose_kwargs,
+            )
+            train_dataset = train_dataset.map(
+                pose_fn,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+                desc="Add position_id column (PoSE)",
+            )
+            train_dataset = train_dataset.sort("sequence_len")
+            if cfg.eval_sample_packing is not False:
+                if eval_dataset:
+                    eval_dataset = eval_dataset.map(
+                        pose_fn,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
+                        desc="Add position_id column (PoSE)",
+                    )
+        elif cfg.sample_packing:
             train_dataset = train_dataset.map(
                 add_position_ids,
                 num_proc=cfg.dataset_processes,
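
To make the PoSE transform above concrete, here is a toy walk-through. It assumes axolotl with this commit installed; the token ids and the printed values are illustrative only, since the size of each skip is random.

# Toy example of add_pose_position_ids (hypothetical values).
from axolotl.utils.trainer import add_pose_position_ids

sample = {"input_ids": list(range(8))}  # an 8-token sample
out = add_pose_position_ids(sample, max_context_len=32, chunks=2)

# With chunks=2 the sample splits at index 4; a random skip of up to
# max_context_len - len(input_ids) = 24 positions is inserted before the
# second chunk, so position_ids might be tensor([0, 1, 2, 3, 10, 11, 12, 13])
# while input_ids are unchanged. sequence_len records the largest position
# reached (13 in that case) and length stays 8.
print(out["position_ids"], out["sequence_len"], out["length"])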