transiteration
committed on
Commit
•
6918317
1
Parent(s):
33c0b37
Update train.py
Browse files
train.py
CHANGED
@@ -7,10 +7,18 @@ from nemo.utils import exp_manager, logging
|
|
7 |
from omegaconf import OmegaConf, open_dict
|
8 |
|
9 |
|
10 |
-
def train_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Loading a STT Quartznet 15x5 model
|
13 |
model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
|
|
|
14 |
# New vocabulary for a model
|
15 |
new_vocabulary = [
|
16 |
" ",
|
@@ -56,7 +64,8 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
|
|
56 |
"ә",
|
57 |
"ө",
|
58 |
]
|
59 |
-
|
|
|
60 |
with open_dict(model.cfg):
|
61 |
# Setting up the labels and sample rate
|
62 |
model.cfg.labels = new_vocabulary
|
@@ -104,7 +113,7 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
|
|
104 |
|
105 |
# Trainer
|
106 |
trainer = ptl.Trainer(
|
107 |
-
accelerator=
|
108 |
max_epochs=num_epochs,
|
109 |
accumulate_grad_batches=1,
|
110 |
enable_checkpointing=False,
|
@@ -133,10 +142,10 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
|
|
133 |
print(OmegaConf.to_yaml(model.cfg))
|
134 |
print("-----------------------------------------------------------")
|
135 |
|
136 |
-
#
|
137 |
trainer.fit(model)
|
138 |
|
139 |
-
#
|
140 |
if model_save_path:
|
141 |
model.save_to(f"{model_save_path}")
|
142 |
print(f"Model saved at path : {os.getcwd() + os.path.sep + model_save_path}")
|
@@ -145,9 +154,9 @@ def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_
|
|
145 |
if __name__ == "__main__":
|
146 |
# Parse command line arguments
|
147 |
parser = argparse.ArgumentParser()
|
148 |
-
parser.add_argument("--train_manifest", help="Path for train manifest JSON file.")
|
149 |
-
parser.add_argument("--val_manifest", help="Path for validation manifest JSON file.")
|
150 |
-
parser.add_argument("--accelerator", help="What accelerator type to use (cpu, gpu, tpu, etc.).")
|
151 |
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
152 |
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
|
153 |
parser.add_argument("--model_save_path", default=None, help="Path for saving a trained model.")
|
@@ -160,4 +169,4 @@ if __name__ == "__main__":
|
|
160 |
batch_size=args.batch_size,
|
161 |
num_epochs=args.num_epochs,
|
162 |
model_save_path=args.model_save_path,
|
163 |
-
)
|
|
|
7 |
from omegaconf import OmegaConf, open_dict
|
8 |
|
9 |
|
10 |
+
def train_model(
|
11 |
+
train_manifest: str = None,
|
12 |
+
val_manifest: str = None,
|
13 |
+
accelerator: str = "cpu",
|
14 |
+
batch_size: int = 1,
|
15 |
+
num_epochs: int = 1,
|
16 |
+
model_save_path: str = None,
|
17 |
+
) -> None:
|
18 |
|
19 |
# Loading a STT Quartznet 15x5 model
|
20 |
model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
|
21 |
+
|
22 |
# New vocabulary for a model
|
23 |
new_vocabulary = [
|
24 |
" ",
|
|
|
64 |
"ә",
|
65 |
"ө",
|
66 |
]
|
67 |
+
|
68 |
+
# Configurations
|
69 |
with open_dict(model.cfg):
|
70 |
# Setting up the labels and sample rate
|
71 |
model.cfg.labels = new_vocabulary
|
|
|
113 |
|
114 |
# Trainer
|
115 |
trainer = ptl.Trainer(
|
116 |
+
accelerator=accelerator,
|
117 |
max_epochs=num_epochs,
|
118 |
accumulate_grad_batches=1,
|
119 |
enable_checkpointing=False,
|
|
|
142 |
print(OmegaConf.to_yaml(model.cfg))
|
143 |
print("-----------------------------------------------------------")
|
144 |
|
145 |
+
# Fitting the model
|
146 |
trainer.fit(model)
|
147 |
|
148 |
+
# Saving the model
|
149 |
if model_save_path:
|
150 |
model.save_to(f"{model_save_path}")
|
151 |
print(f"Model saved at path : {os.getcwd() + os.path.sep + model_save_path}")
|
|
|
154 |
if __name__ == "__main__":
|
155 |
# Parse command line arguments
|
156 |
parser = argparse.ArgumentParser()
|
157 |
+
parser.add_argument("--train_manifest", default = None, help="Path for train manifest JSON file.")
|
158 |
+
parser.add_argument("--val_manifest", default = None, help="Path for validation manifest JSON file.")
|
159 |
+
parser.add_argument("--accelerator", default="cpu", help="What accelerator type to use (cpu, gpu, tpu, etc.).")
|
160 |
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
161 |
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
|
162 |
parser.add_argument("--model_save_path", default=None, help="Path for saving a trained model.")
|
|
|
169 |
batch_size=args.batch_size,
|
170 |
num_epochs=args.num_epochs,
|
171 |
model_save_path=args.model_save_path,
|
172 |
+
)
|