marinone94 commited on
Commit
dbe5e43
·
1 Parent(s): b7db389

WIP: mix datasets

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -87,16 +87,17 @@ if hf_token is not None:
87
  with open("/root/.huggingface/token", "w") as f:
88
  f.write(hf_token)
89
  logger.info("Huggingface API key set")
90
- except PermissionError:
91
  logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
92
  else:
93
  logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
94
 
95
- wandb.login(key=wandb_token, relogin=True, timeout=5)
96
- wandb.init(project="whisper", entity="pn-aa")
97
 
98
  logger.info("Wandb API key set, logging to wandb")
99
 
 
100
  @dataclass
101
  class ModelArguments:
102
  """
@@ -300,7 +301,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
300
  model_input_name = self.processor.model_input_names[0]
301
  input_features = [{model_input_name: feature[model_input_name]} for feature in features]
302
  label_features = [{"input_ids": feature["labels"]} for feature in features]
303
- lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
304
 
305
  batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
306
 
@@ -313,15 +314,19 @@ class DataCollatorSpeechSeq2SeqWithPadding:
313
  # cut bos token here as it's append later anyways
314
  if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
315
  labels = labels[:, 1:]
316
- lang_token_ids = self.processor.tokenizer(lang_features).input_ids
317
- # Replace language and task if they are in the beginning, otherwise add them
318
- if (labels[:, 1] == self.task_id).all().cpu().item():
319
- labels[:, 0] = lang_token_ids
320
- labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
321
- else:
322
- # convert task id to tensor of labels dim to concatenate
323
- task_id = torch.full_like(labels[:, 0], self.task_id)
324
- labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
 
 
 
 
325
 
326
  batch["labels"] = labels
327
 
@@ -358,30 +363,54 @@ def notify_me(recipient, message=None):
358
  smtp_obj.quit()
359
 
360
 
361
- def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="train", streaming=True, **kwargs):
 
 
 
 
 
 
 
 
362
  """
363
  Utility function to load a dataset in streaming mode. For datasets with multiple splits,
364
  each split is loaded individually and then splits combined by taking alternating examples from
365
  each (interleaving).
366
  """
367
- column_names = None
368
- if "column_names" in kwargs:
369
- column_names = kwargs.pop("column_names").split(",")
 
370
 
371
  if "," in dataset_names or "+" in split:
372
  # load multiple splits separated by the `+` symbol with streaming mode
373
  dataset_splits = []
374
- for dataset_name, dataset_config_name, split_names, lang in zip(
375
- dataset_names.split(","), dataset_config_names.split(","), split.split(","), kwargs.pop("language").split(",")
376
  ):
377
  for split_name in split_names.split("+"):
378
- dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
 
 
 
379
  raw_datasets_features = list(dataset.features.keys())
380
- if column_names[0] not in raw_datasets_features:
381
- if len(column_names) == 1 or column_names[1] not in raw_datasets_features:
382
  raise ValueError("Column name not found in dataset.")
383
- dataset = dataset.rename_columns(column_names[1], column_names[0])
384
- dataset["language"] = lang
 
 
 
 
 
 
 
 
 
 
 
 
385
  dataset_splits.append(dataset)
386
 
387
  # interleave multiple splits to form one dataset
@@ -460,6 +489,14 @@ def main():
460
  # Set seed before initializing model.
461
  set_seed(training_args.seed)
462
 
 
 
 
 
 
 
 
 
463
  # 4. Load dataset
464
  logger.info("*** Load dataset ***")
465
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
@@ -471,8 +508,10 @@ def main():
471
  split=data_args.train_split_name,
472
  use_auth_token=hf_token if model_args.use_auth_token else None,
473
  streaming=data_args.streaming,
474
- column_names=data_args.text_column_name,
475
- language=data_args.language_train
 
 
476
  )
477
 
478
  if training_args.do_eval:
@@ -482,7 +521,10 @@ def main():
482
  split=data_args.eval_split_name,
483
  use_auth_token=hf_token if model_args.use_auth_token else None,
484
  streaming=data_args.streaming,
485
- language=data_args.language_eval
 
 
 
486
  )
487
 
488
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
@@ -518,12 +560,6 @@ def main():
518
  if training_args.gradient_checkpointing:
519
  config.update({"use_cache": False})
520
 
521
- feature_extractor = AutoFeatureExtractor.from_pretrained(
522
- model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
523
- cache_dir=model_args.cache_dir,
524
- revision=model_args.model_revision,
525
- use_auth_token=hf_token if model_args.use_auth_token else None,
526
- )
527
  tokenizer = AutoTokenizer.from_pretrained(
528
  model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
529
  cache_dir=model_args.cache_dir,
@@ -548,21 +584,19 @@ def main():
548
  if model_args.freeze_encoder:
549
  model.freeze_encoder()
550
 
551
- if data_args.language is not None and len(data_args.language.split(",")) == 1:
552
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
553
  # If more than a langugae is specified, it will be specified in the data collator
554
- tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
555
- elif data_args.language is not None and len(data_args.language.split(",")) > 1:
556
  # make sure language and task are not stored in the model config
557
  model.config.forced_decoder_ids = None
558
 
559
  # 6. Resample speech dataset if necessary
560
- logger.info("*** Resample dataset ***")
561
- dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
562
- if dataset_sampling_rate != feature_extractor.sampling_rate:
563
- raw_datasets = raw_datasets.cast_column(
564
- data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
565
- )
566
 
567
  # 7. Preprocessing the datasets.
568
  # We need to read the audio files as arrays and tokenize the targets.
@@ -606,7 +640,7 @@ def main():
606
  return batch
607
 
608
  with training_args.main_process_first(desc="dataset map pre-processing"):
609
- raw_datasets_features.remove("language")
610
  vectorized_datasets = raw_datasets.map(
611
  prepare_dataset,
612
  remove_columns=raw_datasets_features,
@@ -765,8 +799,8 @@ def main():
765
  kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
766
  else:
767
  kwargs["dataset"] = data_args.dataset_name
768
- if "common_voice" in data_args.dataset_name:
769
- kwargs["language"] = data_args.dataset_config_name[:2]
770
  if model_args.model_index_name is not None:
771
  kwargs["model_name"] = model_args.model_index_name
772
 
 
87
  with open("/root/.huggingface/token", "w") as f:
88
  f.write(hf_token)
89
  logger.info("Huggingface API key set")
90
+ except (PermissionError, OSError):
91
  logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
92
  else:
93
  logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
94
 
95
+ # wandb.login(key=wandb_token, relogin=True, timeout=5)
96
+ # wandb.init(project="whisper", entity="pn-aa")
97
 
98
  logger.info("Wandb API key set, logging to wandb")
99
 
100
+
101
  @dataclass
102
  class ModelArguments:
103
  """
 
301
  model_input_name = self.processor.model_input_names[0]
302
  input_features = [{model_input_name: feature[model_input_name]} for feature in features]
303
  label_features = [{"input_ids": feature["labels"]} for feature in features]
304
+ # lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
305
 
306
  batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
307
 
 
314
  # cut bos token here as it's append later anyways
315
  if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
316
  labels = labels[:, 1:]
317
+ # lang_token_ids = self.processor.tokenizer(lang_features).input_ids
318
+ # # Replace language and task if they are in the beginning, otherwise add them
319
+ # if (labels[:, 1] == self.task_id).all().cpu().item():
320
+ # labels[:, 0] = lang_token_ids
321
+ # labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
322
+ # else:
323
+ # # convert task id to tensor of labels dim to concatenate
324
+ # task_id = torch.full_like(labels[:, 0], self.task_id)
325
+ # labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
326
+
327
+ # Set language and task to pad token
328
+ labels[:, 0] = torch.full_like(labels[:, 0], -100)
329
+ labels[:, 1] = torch.full_like(labels[:, 1], -100)
330
 
331
  batch["labels"] = labels
332
 
 
363
  smtp_obj.quit()
364
 
365
 
366
+ def load_maybe_streaming_dataset(
367
+ dataset_names,
368
+ dataset_config_names,
369
+ split="train",
370
+ streaming=True,
371
+ audio_column_name=None,
372
+ sampling_rate=None,
373
+ **kwargs
374
+ ):
375
  """
376
  Utility function to load a dataset in streaming mode. For datasets with multiple splits,
377
  each split is loaded individually and then splits combined by taking alternating examples from
378
  each (interleaving).
379
  """
380
+ text_column_names = None
381
+ if "text_column_name" in kwargs:
382
+ text_column_names = kwargs.pop("text_column_name").split(",")
383
+ text_col_name_ref = text_column_names[0]
384
 
385
  if "," in dataset_names or "+" in split:
386
  # load multiple splits separated by the `+` symbol with streaming mode
387
  dataset_splits = []
388
+ for dataset_name, dataset_config_name, split_names in zip(
389
+ dataset_names.split(","), dataset_config_names.split(","), split.split(",")
390
  ):
391
  for split_name in split_names.split("+"):
392
+ if dataset_config_name:
393
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
394
+ else:
395
+ dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
396
  raw_datasets_features = list(dataset.features.keys())
397
+ if text_col_name_ref not in raw_datasets_features:
398
+ if len(text_column_names) == 1:
399
  raise ValueError("Column name not found in dataset.")
400
+ flag = False
401
+ for text_column_name in text_column_names:
402
+ if text_column_name in raw_datasets_features:
403
+ dataset = dataset.rename_column(text_column_name, text_col_name_ref)
404
+ flag = True
405
+ break
406
+ if flag is False:
407
+ raise ValueError("None of the text column names provided found in dataset."
408
+ f"Text columns: {text_column_names}"
409
+ f"Dataset columns: {raw_datasets_features}")
410
+ if audio_column_name is not None and sampling_rate is not None:
411
+ dataset = dataset.cast_column(
412
+ audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
413
+ )
414
  dataset_splits.append(dataset)
415
 
416
  # interleave multiple splits to form one dataset
 
489
  # Set seed before initializing model.
490
  set_seed(training_args.seed)
491
 
492
+ # Load feature extractor
493
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
494
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
495
+ cache_dir=model_args.cache_dir,
496
+ revision=model_args.model_revision,
497
+ use_auth_token=hf_token if model_args.use_auth_token else None,
498
+ )
499
+
500
  # 4. Load dataset
501
  logger.info("*** Load dataset ***")
502
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
 
508
  split=data_args.train_split_name,
509
  use_auth_token=hf_token if model_args.use_auth_token else None,
510
  streaming=data_args.streaming,
511
+ text_column_name=data_args.text_column_name,
512
+ audio_column_name=data_args.audio_column_name,
513
+ sampling_rate=feature_extractor.sampling_rate,
514
+ # language=data_args.language_train
515
  )
516
 
517
  if training_args.do_eval:
 
521
  split=data_args.eval_split_name,
522
  use_auth_token=hf_token if model_args.use_auth_token else None,
523
  streaming=data_args.streaming,
524
+ text_column_name=data_args.text_column_name,
525
+ audio_column_name=data_args.audio_column_name,
526
+ sampling_rate=feature_extractor.sampling_rate,
527
+ # language=data_args.language_eval
528
  )
529
 
530
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
560
  if training_args.gradient_checkpointing:
561
  config.update({"use_cache": False})
562
 
 
 
 
 
 
 
563
  tokenizer = AutoTokenizer.from_pretrained(
564
  model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
565
  cache_dir=model_args.cache_dir,
 
584
  if model_args.freeze_encoder:
585
  model.freeze_encoder()
586
 
587
+ if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
588
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
589
  # If more than a langugae is specified, it will be specified in the data collator
590
+ tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
591
+ elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
592
  # make sure language and task are not stored in the model config
593
  model.config.forced_decoder_ids = None
594
 
595
  # 6. Resample speech dataset if necessary
596
+ # logger.info("*** Resample dataset ***")
597
+ # dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
598
+ # if dataset_sampling_rate != feature_extractor.sampling_rate:
599
+
 
 
600
 
601
  # 7. Preprocessing the datasets.
602
  # We need to read the audio files as arrays and tokenize the targets.
 
640
  return batch
641
 
642
  with training_args.main_process_first(desc="dataset map pre-processing"):
643
+ # raw_datasets_features.remove("language")
644
  vectorized_datasets = raw_datasets.map(
645
  prepare_dataset,
646
  remove_columns=raw_datasets_features,
 
799
  kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
800
  else:
801
  kwargs["dataset"] = data_args.dataset_name
802
+ # if "common_voice" in data_args.dataset_name:
803
+ # kwargs["language"] = data_args.dataset_config_name[:2]
804
  if model_args.model_index_name is not None:
805
  kwargs["model_name"] = model_args.model_index_name
806
 
test_run_nordic.sh CHANGED
@@ -1,7 +1,7 @@
1
  python $1run_speech_recognition_seq2seq_streaming.py \
2
  --model_name_or_path="openai/whisper-tiny" \
3
  --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
4
- --dataset_train_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk,sv_se,da_dk,nb_no" \
5
  --language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
6
  --train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
7
  --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
@@ -25,7 +25,7 @@ python $1run_speech_recognition_seq2seq_streaming.py \
25
  --generation_max_length="225" \
26
  --length_column_name="input_length" \
27
  --max_duration_in_seconds="30" \
28
- --text_column_name="sentence,text" \
29
  --freeze_feature_encoder="False" \
30
  --report_to="wandb" \
31
  --metric_for_best_model="wer" \
 
1
  python $1run_speech_recognition_seq2seq_streaming.py \
2
  --model_name_or_path="openai/whisper-tiny" \
3
  --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
4
+ --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
5
  --language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
6
  --train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
7
  --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
 
25
  --generation_max_length="225" \
26
  --length_column_name="input_length" \
27
  --max_duration_in_seconds="30" \
28
+ --text_column_name="sentence,text,raw_transcription" \
29
  --freeze_feature_encoder="False" \
30
  --report_to="wandb" \
31
  --metric_for_best_model="wer" \
test_run_nordic_cv.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python $1run_speech_recognition_seq2seq_streaming.py \
2
+ --model_name_or_path="openai/whisper-tiny" \
3
+ --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
4
+ --dataset_train_config_name="sv-SE,da,nn-NO" \
5
+ --language_train="swedish,danish,norwegian" \
6
+ --train_split_name="train+validation,train+validation,train+validation" \
7
+ --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
8
+ --dataset_eval_config_name="sv-SE,da,nn-NO" \
9
+ --language_eval="swedish,danish,norwegian" \
10
+ --eval_split_name="test" \
11
+ --model_index_name="Whisper Tiny Swedish" \
12
+ --max_train_samples="64" \
13
+ --max_eval_samples="32" \
14
+ --max_steps="500" \
15
+ --output_dir="./" \
16
+ --per_device_train_batch_size="8" \
17
+ --per_device_eval_batch_size="4" \
18
+ --logging_steps="25" \
19
+ --learning_rate="1e-5" \
20
+ --warmup_steps="500" \
21
+ --evaluation_strategy="steps" \
22
+ --eval_steps="1000" \
23
+ --save_strategy="steps" \
24
+ --save_steps="1000" \
25
+ --generation_max_length="225" \
26
+ --length_column_name="input_length" \
27
+ --max_duration_in_seconds="30" \
28
+ --text_column_name="sentence,text" \
29
+ --freeze_feature_encoder="False" \
30
+ --metric_for_best_model="wer" \
31
+ --greater_is_better="False" \
32
+ --load_best_model_at_end \
33
+ --gradient_checkpointing \
34
+ --overwrite_output_dir \
35
+ --do_train \
36
+ --do_eval \
37
+ --predict_with_generate \
38
+ --do_normalize_eval \
39
+ --streaming \
40
+ --use_auth_token \
41
+ --push_to_hub