transiteration commited on
Commit
6918317
1 Parent(s): 33c0b37

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +18 -9
train.py CHANGED
@@ -7,10 +7,18 @@ from nemo.utils import exp_manager, logging
7
  from omegaconf import OmegaConf, open_dict
8
 
9
 
10
- def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_size: int, num_epochs: int, model_save_path: str = None,) -> None:
 
 
 
 
 
 
 
11
 
12
  # Loading a STT Quartznet 15x5 model
13
  model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
 
14
  # New vocabulary for a model
15
  new_vocabulary = [
16
  " ",
@@ -56,7 +64,8 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
56
  "ә",
57
  "ө",
58
  ]
59
-
 
60
  with open_dict(model.cfg):
61
  # Setting up the labels and sample rate
62
  model.cfg.labels = new_vocabulary
@@ -104,7 +113,7 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
104
 
105
  # Trainer
106
  trainer = ptl.Trainer(
107
- accelerator="gpu",
108
  max_epochs=num_epochs,
109
  accumulate_grad_batches=1,
110
  enable_checkpointing=False,
@@ -133,10 +142,10 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
133
  print(OmegaConf.to_yaml(model.cfg))
134
  print("-----------------------------------------------------------")
135
 
136
- # # Fitting the model
137
  trainer.fit(model)
138
 
139
- # # Saving the model
140
  if model_save_path:
141
  model.save_to(f"{model_save_path}")
142
  print(f"Model saved at path : {os.getcwd() + os.path.sep + model_save_path}")
@@ -145,9 +154,9 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
145
  if __name__ == "__main__":
146
  # Parse command line arguments
147
  parser = argparse.ArgumentParser()
148
- parser.add_argument("--train_manifest", help="Path for train manifest JSON file.")
149
- parser.add_argument("--val_manifest", help="Path for validation manifest JSON file.")
150
- parser.add_argument("--accelerator", help="What accelerator type to use (cpu, gpu, tpu, etc.).")
151
  parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
152
  parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
153
  parser.add_argument("--model_save_path", default=None, help="Path for saving a trained model.")
@@ -160,4 +169,4 @@ if __name__ == "__main__":
160
  batch_size=args.batch_size,
161
  num_epochs=args.num_epochs,
162
  model_save_path=args.model_save_path,
163
- )
 
7
  from omegaconf import OmegaConf, open_dict
8
 
9
 
10
+ def train_model(
11
+ train_manifest: str = None,
12
+ val_manifest: str = None,
13
+ accelerator: str = "cpu",
14
+ batch_size: int = 1,
15
+ num_epochs: int = 1,
16
+ model_save_path: str = None,
17
+ ) -> None:
18
 
19
  # Loading a STT Quartznet 15x5 model
20
  model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
21
+
22
  # New vocabulary for a model
23
  new_vocabulary = [
24
  " ",
 
64
  "ә",
65
  "ө",
66
  ]
67
+
68
+ # Configurations
69
  with open_dict(model.cfg):
70
  # Setting up the labels and sample rate
71
  model.cfg.labels = new_vocabulary
 
113
 
114
  # Trainer
115
  trainer = ptl.Trainer(
116
+ accelerator=accelerator,
117
  max_epochs=num_epochs,
118
  accumulate_grad_batches=1,
119
  enable_checkpointing=False,
 
142
  print(OmegaConf.to_yaml(model.cfg))
143
  print("-----------------------------------------------------------")
144
 
145
+ # Fitting the model
146
  trainer.fit(model)
147
 
148
+ # Saving the model
149
  if model_save_path:
150
  model.save_to(f"{model_save_path}")
151
  print(f"Model saved at path : {os.getcwd() + os.path.sep + model_save_path}")
 
154
  if __name__ == "__main__":
155
  # Parse command line arguments
156
  parser = argparse.ArgumentParser()
157
+ parser.add_argument("--train_manifest", default = None, help="Path for train manifest JSON file.")
158
+ parser.add_argument("--val_manifest", default = None, help="Path for validation manifest JSON file.")
159
+ parser.add_argument("--accelerator", default="cpu", help="What accelerator type to use (cpu, gpu, tpu, etc.).")
160
  parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
161
  parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
162
  parser.add_argument("--model_save_path", default=None, help="Path for saving a trained model.")
 
169
  batch_size=args.batch_size,
170
  num_epochs=args.num_epochs,
171
  model_save_path=args.model_save_path,
172
+ )