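"""Fine-tune NVIDIA NeMo's pretrained English QuartzNet 15x5 ASR model on Kazakh data.

The script loads the ``stt_en_quartznet15x5`` checkpoint, replaces its
vocabulary with the Kazakh Cyrillic alphabet, and fine-tunes on the supplied
train/validation manifests. NeMo manifests are JSON-lines files in which each
line is a dict like
``{"audio_filepath": "/path/to/audio.wav", "duration": 3.2, "text": "..."}``
(the values shown here are illustrative placeholders).
"""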
import argparse
import os
from typing import Optional

import nemo.collections.asr as nemo_asr
import pytorch_lightning as ptl
from nemo.utils import exp_manager, logging
from omegaconf import OmegaConf, open_dict


def train_model(
    train_manifest: Optional[str] = None,
    val_manifest: Optional[str] = None,
    accelerator: str = "cpu",
    batch_size: int = 1,
    num_epochs: int = 1,
    model_save_path: Optional[str] = None,
) -> None:
    """Fine-tune the pretrained English QuartzNet 15x5 model on Kazakh manifests."""
    # Load the pretrained STT QuartzNet 15x5 model
    model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
    # New vocabulary for the model: the space character plus the letters of
    # the Kazakh Cyrillic alphabet
    new_vocabulary = [
        " ",
        "а",
        "б",
        "в",
        "г",
        "д",
        "е",
        "ж",
        "з",
        "и",
        "й",
        "к",
        "л",
        "м",
        "н",
        "о",
        "п",
        "р",
        "с",
        "т",
        "у",
        "ф",
        "х",
        "ц",
        "ч",
        "ш",
        "щ",
        "ъ",
        "ы",
        "ь",
        "э",
        "ю",
        "я",
        "і",
        "ғ",
        "қ",
        "ң",
        "ү",
        "ұ",
        "һ",
        "ә",
        "ө",
    ]
    # Update the model configuration
    with open_dict(model.cfg):
        # Labels and sample rate
        model.cfg.labels = new_vocabulary
        model.cfg.sample_rate = 16000
        # Train dataset
        model.cfg.train_ds.manifest_filepath = train_manifest
        model.cfg.train_ds.labels = new_vocabulary
        model.cfg.train_ds.normalize_transcripts = False
        model.cfg.train_ds.batch_size = batch_size
        model.cfg.train_ds.num_workers = 10
        model.cfg.train_ds.pin_memory = True
        model.cfg.train_ds.trim_silence = True
        # Validation dataset
        model.cfg.validation_ds.manifest_filepath = val_manifest
        model.cfg.validation_ds.labels = new_vocabulary
        model.cfg.validation_ds.normalize_transcripts = False
        model.cfg.validation_ds.batch_size = batch_size
        model.cfg.validation_ds.num_workers = 10
        model.cfg.validation_ds.pin_memory = True
        model.cfg.validation_ds.trim_silence = True
        # Optimizer and scheduler
        model.cfg.optim.lr = 0.001
        model.cfg.optim.betas = [0.8, 0.5]
        model.cfg.optim.weight_decay = 0.001
        model.cfg.optim.sched.warmup_steps = 500
        model.cfg.optim.sched.min_lr = 1e-6
    # Swap the decoder vocabulary and rebuild the data loaders
    model.change_vocabulary(new_vocabulary=new_vocabulary)
    model.setup_training_data(model.cfg.train_ds)
    model.setup_validation_data(model.cfg.validation_ds)
    # Unfreeze the encoder so its parameters are updated during fine-tuning
    model.encoder.unfreeze()
    logging.info("Model encoder has been un-frozen")
    # Set up spectrogram augmentation from the model config
    model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)
    # Track character error rate (CER) instead of word error rate and log predictions
    model._wer.use_cer = True
    model._wer.log_prediction = True
    # Trainer: 16-bit (mixed) precision is only supported on GPU, so fall
    # back to 32-bit precision on other accelerators
    trainer = ptl.Trainer(
        accelerator=accelerator,
        max_epochs=num_epochs,
        accumulate_grad_batches=1,
        enable_checkpointing=False,  # checkpointing is handled by exp_manager
        logger=False,  # logging is handled by exp_manager
        log_every_n_steps=100,
        check_val_every_n_epoch=1,
        precision=16 if accelerator == "gpu" else 32,
    )
    # Attach the trainer to the model
    model.set_trainer(trainer)
    # Experiment tracking
    LANGUAGE = "kz"
    config = exp_manager.ExpManagerConfig(
        exp_dir=f"experiments/lang-{LANGUAGE}/",
        name=f"ASR-Model-Language-{LANGUAGE}",
        checkpoint_callback_params=exp_manager.CallbackParams(
            monitor="val_wer",
            mode="min",
            always_save_nemo=True,
            save_best_model=True,
        ),
    )
    config = OmegaConf.structured(config)
    exp_manager.exp_manager(trainer, config)
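    # Note: with always_save_nemo=True, exp_manager should also export the
    # best checkpoint as a .nemo file under the experiment directory
    # (experiments/lang-kz/...); the exact layout can vary across NeMo versions.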
    # Print the final configuration
    print("-----------------------------------------------------------")
    print("Updated STT Model Configuration:")
    print(OmegaConf.to_yaml(model.cfg))
    print("-----------------------------------------------------------")
    # Fit the model
    trainer.fit(model)
    # Save the fine-tuned model as a .nemo archive
    if model_save_path:
        model.save_to(model_save_path)
        print(f"Model saved at path: {os.path.join(os.getcwd(), model_save_path)}")


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_manifest", default=None, help="Path to the training manifest JSON file.")
    parser.add_argument("--val_manifest", default=None, help="Path to the validation manifest JSON file.")
    parser.add_argument("--accelerator", default="cpu", help="Accelerator type to use (cpu, gpu, tpu, etc.).")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
    parser.add_argument("--model_save_path", default=None, help="Path for saving the trained model.")
    args = parser.parse_args()

    train_model(
        train_manifest=args.train_manifest,
        val_manifest=args.val_manifest,
        accelerator=args.accelerator,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        model_save_path=args.model_save_path,
    )
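
# Example invocation (illustrative only; the script name and paths below are
# placeholders, not files shipped with this code):
#
#   python finetune_quartznet_kz.py \
#       --train_manifest data/train_manifest.json \
#       --val_manifest data/val_manifest.json \
#       --accelerator gpu \
#       --batch_size 32 \
#       --num_epochs 50 \
#       --model_save_path quartznet15x5_kz.nemo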