File size: 5,252 Bytes
0479abb
 
 
 
 
 
 
 
 
6918317
 
 
 
 
 
 
 
0479abb
 
 
6918317
0479abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6918317
 
0479abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6918317
0479abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6918317
0479abb
 
6918317
0479abb
 
 
 
 
 
 
 
6918317
 
 
0479abb
 
 
 
 
 
 
 
 
 
 
 
6918317
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import argparse
import os
from typing import Optional, Union

import nemo.collections.asr as nemo_asr
import pytorch_lightning as ptl
from nemo.utils import exp_manager, logging
from omegaconf import OmegaConf, open_dict


def _apply_dataset_config(
        ds_cfg,
        manifest: Optional[str],
        vocabulary: list,
        batch_size: int,
        num_workers: int,
) -> None:
    """Configure one dataset section (train or validation) of the model config.

    Shared by the train and validation sections so both are set up identically.
    Intended to be called inside ``open_dict(model.cfg)`` (as ``train_model``
    does) so keys missing from the pretrained config can be added.
    """
    ds_cfg.manifest_filepath = manifest
    ds_cfg.labels = vocabulary
    # Turn off NeMo's transcript normalization for this dataset.
    ds_cfg.normalize_transcripts = False
    ds_cfg.batch_size = batch_size
    ds_cfg.num_workers = num_workers
    ds_cfg.pin_memory = True
    ds_cfg.trim_silence = True


def train_model(
        train_manifest: Optional[str] = None,
        val_manifest: Optional[str] = None,
        accelerator: str = "cpu",
        batch_size: int = 1,
        num_epochs: int = 1,
        model_save_path: Optional[str] = None,
        num_workers: int = 10,
        precision: Union[int, str] = 16,
) -> None:
    """Fine-tune the pretrained ``stt_en_quartznet15x5`` model on a Cyrillic vocabulary.

    The English character vocabulary is replaced by a lowercase Cyrillic one that
    includes Kazakh-specific letters. Experiments are tracked under the "kz" tag.

    Args:
        train_manifest: Path to the training manifest JSON file. Required.
        val_manifest: Path to the validation manifest JSON file.
        accelerator: Lightning accelerator name (e.g. "cpu", "gpu").
        batch_size: Batch size used for both training and validation.
        num_epochs: Number of training epochs (``Trainer(max_epochs=...)``).
        model_save_path: If given, the fine-tuned model is written there with
            ``model.save_to`` after training.
        num_workers: Data-loader worker count for each dataset.
        precision: ``Trainer`` precision setting. The default of 16 (fp16 mixed
            precision) is primarily a GPU feature; on CPU, Lightning may warn,
            substitute bf16, or reject it depending on the version. Pass 32 or
            "bf16" there.

    Raises:
        ValueError: If ``train_manifest`` is empty or not provided.
    """
    if not train_manifest:
        # Fail before downloading the pretrained checkpoint: fine-tuning is
        # meaningless without training data.
        raise ValueError("train_manifest is required to fine-tune the model")

    # Loading a STT Quartznet 15x5 model
    model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")

    # New vocabulary for the model. The order defines the label indices, so
    # do not reorder. NOTE(review): "ё" and uppercase letters are absent —
    # confirm transcripts never contain them.
    new_vocabulary = [
        " ",
        "а",
        "б",
        "в",
        "г",
        "д",
        "е",
        "ж",
        "з",
        "и",
        "й",
        "к",
        "л",
        "м",
        "н",
        "о",
        "п",
        "р",
        "с",
        "т",
        "у",
        "ф",
        "х",
        "ц",
        "ч",
        "ш",
        "щ",
        "ъ",
        "ы",
        "ь",
        "э",
        "ю",
        "я",
        "і",
        "ғ",
        "қ",
        "ң",
        "ү",
        "ұ",
        "һ",
        "ә",
        "ө",
    ]

    # open_dict lifts OmegaConf's struct mode so new keys may be assigned.
    with open_dict(model.cfg):
        # Labels and audio sample rate (Hz)
        model.cfg.labels = new_vocabulary
        model.cfg.sample_rate = 16000

        # Train and validation datasets share the same settings.
        _apply_dataset_config(
            model.cfg.train_ds, train_manifest, new_vocabulary, batch_size, num_workers
        )
        _apply_dataset_config(
            model.cfg.validation_ds, val_manifest, new_vocabulary, batch_size, num_workers
        )

        # Optimizer and scheduler hyper-parameters
        model.cfg.optim.lr = 0.001
        model.cfg.optim.betas = [0.8, 0.5]
        model.cfg.optim.weight_decay = 0.001
        model.cfg.optim.sched.warmup_steps = 500
        model.cfg.optim.sched.min_lr = 1e-6

    # Swap the decoder's output vocabulary, then rebuild the data loaders from
    # the updated dataset configs.
    model.change_vocabulary(new_vocabulary=new_vocabulary)
    model.setup_training_data(model.cfg.train_ds)
    model.setup_validation_data(model.cfg.validation_ds)

    # Unfreezing encoders to update the parameters
    model.encoder.unfreeze()
    logging.info("Model encoder has been un-frozen")

    # Setting up data augmentation from the model's spec_augment config
    model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)

    # Report character error rate and log sample predictions.
    # NOTE(review): `_wer` is a private attribute — confirm it exists in the
    # installed NeMo version.
    model._wer.use_cer = True
    model._wer.log_prediction = True

    # Trainer. Logging and checkpointing are disabled here because
    # exp_manager attaches its own logger and checkpoint callback below.
    trainer = ptl.Trainer(
        accelerator=accelerator,
        max_epochs=num_epochs,
        accumulate_grad_batches=1,
        enable_checkpointing=False,
        logger=False,
        log_every_n_steps=100,
        check_val_every_n_epoch=1,
        precision=precision,
    )

    # Setting up model with the trainer
    model.set_trainer(trainer)

    # Experiment tracking; the best checkpoint is chosen by lowest val_wer.
    lang_tag = "kz"
    config = exp_manager.ExpManagerConfig(
        exp_dir=f"experiments/lang-{lang_tag}/",
        name=f"ASR-Model-Language-{lang_tag}",
        checkpoint_callback_params=exp_manager.CallbackParams(monitor="val_wer", mode="min", always_save_nemo=True, save_best_model=True,),
    )
    config = OmegaConf.structured(config)
    exp_manager.exp_manager(trainer, config)

    # Final Configuration
    print("-----------------------------------------------------------")
    print("Updated STT Model Configuration:")
    print(OmegaConf.to_yaml(model.cfg))
    print("-----------------------------------------------------------")

    # Fitting the model
    trainer.fit(model)

    # Saving the model. abspath handles both relative and absolute paths;
    # naive cwd concatenation printed a wrong path for absolute inputs.
    if model_save_path:
        model.save_to(model_save_path)
        print(f"Model saved at path : {os.path.abspath(model_save_path)}")


if __name__ == "__main__":
    # Command-line interface. Every option's destination has the same name as
    # a train_model() keyword argument, so the parsed namespace maps 1:1.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--train_manifest",
        default=None,
        help="Path for train manifest JSON file.",
    )
    cli.add_argument(
        "--val_manifest",
        default=None,
        help="Path for validation manifest JSON file.",
    )
    cli.add_argument(
        "--accelerator",
        default="cpu",
        help="What accelerator type to use (cpu, gpu, tpu, etc.).",
    )
    cli.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size of the dataset to train.",
    )
    cli.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="Number of epochs to train for.",
    )
    cli.add_argument(
        "--model_save_path",
        default=None,
        help="Path for saving a trained model.",
    )
    cli_args = cli.parse_args()

    # Forward the parsed options straight through as keyword arguments.
    train_model(**vars(cli_args))