balaramas pranavkarande commited on
Commit
617b44a
1 Parent(s): 9cb3e5e

Optimized to generate translation on only 1 sample (#2)

Browse files

- Optimized to generate translation on only 1 sample (5de52431f33de0ae0fd50fbf42eca7127ed572c5)


Co-authored-by: Pranav karande <[email protected]>

Files changed (2) hide show
  1. app.py +3 -4
  2. generate.py +426 -0
app.py CHANGED
@@ -10,7 +10,6 @@ import sys
10
  import os
11
  import subprocess
12
  from pydub import AudioSegment
13
- from huggingface_hub import snapshot_download
14
 
15
  def install_fairseq():
16
  try:
@@ -67,7 +66,7 @@ def run_my_code(input_text, language):
67
 
68
  print("------Performing translation...")
69
 
70
- translation_result = subprocess.run(["fairseq-generate", data_root, "--config-yaml", "config_st.yaml", "--gen-subset", "tst-COMMON_st", "--task", "speech_to_text", "--path", model_checkpoint, "--max-tokens", "50000", "--beam", "5", "--scoring", "sacrebleu"], capture_output=True, text=True)
71
  translation_result_text = translation_result.stdout
72
 
73
  lines = translation_result_text.split("\n")
@@ -91,7 +90,7 @@ install_fairseq()
91
  # gr.inputs.Dropdown(list(LANGUAGE_CODES.keys()), default="Hindi", label="From English to Languages X..."),
92
  # ]
93
 
94
- #input_textbox = gr.inputs.Textbox(label="test2.wav")
95
  #input=gr.inputs.Audio(source="microphone", type="filepath", label="Record something (in English)...")
96
  #audio=convert_audio_to_16k_wav(input)
97
  output_textbox = gr.outputs.Textbox(label="Output Text")
@@ -99,7 +98,7 @@ output_textbox = gr.outputs.Textbox(label="Output Text")
99
  # Create a Gradio interface
100
  iface = gr.Interface(
101
  fn=run_my_code,
102
- inputs=[gr.inputs.Audio(source="microphone", type="filepath", label="Record something (in English)..."), gr.inputs.Radio(["Hindi", "French"], label="Language")],
103
  outputs=output_textbox,
104
  title="English to Hindi Translator")
105
 
 
10
  import os
11
  import subprocess
12
  from pydub import AudioSegment
 
13
 
14
  def install_fairseq():
15
  try:
 
66
 
67
  print("------Performing translation...")
68
 
69
+ translation_result = subprocess.run(["python", "generate.py", data_root, "--config-yaml", "config_st.yaml", "--gen-subset", "tst-COMMON_st", "--task", "speech_to_text", "--path", model_checkpoint], capture_output=True, text=True)
70
  translation_result_text = translation_result.stdout
71
 
72
  lines = translation_result_text.split("\n")
 
90
  # gr.inputs.Dropdown(list(LANGUAGE_CODES.keys()), default="Hindi", label="From English to Languages X..."),
91
  # ]
92
 
93
+ input_textbox = gr.inputs.Textbox(label="test2.wav")
94
  #input=gr.inputs.Audio(source="microphone", type="filepath", label="Record something (in English)...")
95
  #audio=convert_audio_to_16k_wav(input)
96
  output_textbox = gr.outputs.Textbox(label="Output Text")
 
98
  # Create a Gradio interface
99
  iface = gr.Interface(
100
  fn=run_my_code,
101
+ inputs=[input_textbox, gr.inputs.Radio(["Hindi", "French"], label="Language")],
102
  outputs=output_textbox,
103
  title="English to Hindi Translator")
104
 
generate.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Translate pre-processed data with a trained model.
8
+ """
9
+
10
+ import ast
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ from argparse import Namespace
16
+ from itertools import chain
17
+
18
+ import numpy as np
19
+ import torch
20
+ from omegaconf import DictConfig
21
+
22
+ from fairseq import checkpoint_utils, options, scoring, tasks, utils
23
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
24
+ from fairseq.logging import progress_bar
25
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
26
+
27
+
28
+ def main(cfg: DictConfig):
29
+
30
+ if isinstance(cfg, Namespace):
31
+ cfg = convert_namespace_to_omegaconf(cfg)
32
+
33
+ assert cfg.common_eval.path is not None, "--path required for generation!"
34
+ assert (
35
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
36
+ ), "--sampling requires --nbest to be equal to --beam"
37
+ assert (
38
+ cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
39
+ ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
40
+
41
+ if cfg.common_eval.results_path is not None:
42
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
43
+ output_path = os.path.join(
44
+ cfg.common_eval.results_path,
45
+ "generate-{}.txt".format(cfg.dataset.gen_subset),
46
+ )
47
+ with open(output_path, "w", buffering=1, encoding="utf-8") as h:
48
+ return _main(cfg, h)
49
+ else:
50
+ return _main(cfg, sys.stdout)
51
+
52
+
53
+ def get_symbols_to_strip_from_output(generator):
54
+ if hasattr(generator, "symbols_to_strip_from_output"):
55
+ return generator.symbols_to_strip_from_output
56
+ else:
57
+ return {generator.eos}
58
+
59
+
60
+ def _main(cfg: DictConfig, output_file):
61
+ logging.basicConfig(
62
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
63
+ datefmt="%Y-%m-%d %H:%M:%S",
64
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
65
+ stream=output_file,
66
+ )
67
+ logger = logging.getLogger("fairseq_cli.generate")
68
+
69
+ utils.import_user_module(cfg.common)
70
+
71
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
72
+ cfg.dataset.max_tokens = 12000
73
+ logger.info(cfg)
74
+
75
+ # Fix seed for stochastic decoding
76
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
77
+ np.random.seed(cfg.common.seed)
78
+ utils.set_torch_seed(cfg.common.seed)
79
+
80
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
81
+
82
+ # Load dataset splits
83
+ task = tasks.setup_task(cfg.task)
84
+
85
+ # Set dictionaries
86
+ try:
87
+ src_dict = getattr(task, "source_dictionary", None)
88
+ except NotImplementedError:
89
+ src_dict = None
90
+ tgt_dict = task.target_dictionary
91
+
92
+ overrides = ast.literal_eval(cfg.common_eval.model_overrides)
93
+
94
+ # Load ensemble
95
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
96
+ models, saved_cfg = checkpoint_utils.load_model_ensemble(
97
+ utils.split_paths(cfg.common_eval.path),
98
+ arg_overrides=overrides,
99
+ task=task,
100
+ suffix=cfg.checkpoint.checkpoint_suffix,
101
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
102
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
103
+ )
104
+
105
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
106
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
107
+
108
+ if cfg.generation.lm_path is not None:
109
+ overrides["data"] = cfg.task.data
110
+
111
+ try:
112
+ lms, _ = checkpoint_utils.load_model_ensemble(
113
+ [cfg.generation.lm_path], arg_overrides=overrides, task=None
114
+ )
115
+ except:
116
+ logger.warning(
117
+ f"Failed to load language model! Please make sure that the language model dict is the same "
118
+ f"as target dict and is located in the data dir ({cfg.task.data})"
119
+ )
120
+ raise
121
+
122
+ assert len(lms) == 1
123
+ else:
124
+ lms = [None]
125
+
126
+ # Optimize ensemble for generation
127
+ for model in chain(models, lms):
128
+ if model is None:
129
+ continue
130
+ if cfg.common.fp16:
131
+ model.half()
132
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
133
+ model.cuda()
134
+ model.prepare_for_inference_(cfg)
135
+
136
+ # Load alignment dictionary for unknown word replacement
137
+ # (None if no unknown word replacement, empty if no path to align dictionary)
138
+ align_dict = utils.load_align_dict(cfg.generation.replace_unk)
139
+
140
+ # Load dataset (possibly sharded)
141
+ itr = task.get_batch_iterator(
142
+ dataset=task.dataset(cfg.dataset.gen_subset),
143
+ max_tokens=cfg.dataset.max_tokens,
144
+ max_sentences=cfg.dataset.batch_size,
145
+ max_positions=utils.resolve_max_positions(
146
+ task.max_positions(), *[m.max_positions() for m in models]
147
+ ),
148
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
149
+ #required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
150
+ seed=cfg.common.seed,
151
+ num_shards=cfg.distributed_training.distributed_world_size,
152
+ shard_id=cfg.distributed_training.distributed_rank,
153
+ num_workers=cfg.dataset.num_workers,
154
+ data_buffer_size=cfg.dataset.data_buffer_size,
155
+ ).next_epoch_itr(shuffle=False)
156
+ print("Hello world", itr.n)
157
+ progress = progress_bar.progress_bar(
158
+ itr,
159
+ log_format=cfg.common.log_format,
160
+ log_interval=cfg.common.log_interval,
161
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
162
+ )
163
+
164
+ # Initialize generator
165
+ gen_timer = StopwatchMeter()
166
+
167
+ extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
168
+ generator = task.build_generator(
169
+ models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
170
+ )
171
+
172
+ # Handle tokenization and BPE
173
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
174
+ bpe = task.build_bpe(cfg.bpe)
175
+
176
+ def decode_fn(x):
177
+ if bpe is not None:
178
+ x = bpe.decode(x)
179
+ if tokenizer is not None:
180
+ x = tokenizer.decode(x)
181
+ return x
182
+
183
+ scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
184
+
185
+ num_sentences = 0
186
+ has_target = True
187
+ wps_meter = TimeMeter()
188
+ for sample in progress:
189
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
190
+ if "net_input" not in sample:
191
+ continue
192
+
193
+ prefix_tokens = None
194
+ if cfg.generation.prefix_size > 0:
195
+ prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
196
+
197
+ constraints = None
198
+ if "constraints" in sample:
199
+ constraints = sample["constraints"]
200
+
201
+ gen_timer.start()
202
+ hypos = task.inference_step(
203
+ generator,
204
+ models,
205
+ sample,
206
+ prefix_tokens=prefix_tokens,
207
+ constraints=constraints,
208
+ )
209
+ # for ijkl in hypos:
210
+ # if ("tokens" not in ijkl[0]):
211
+ # print("Hello there bruh")
212
+ # print(ijkl)
213
+ # print(type(hypos))
214
+ # print(hypos[0])
215
+ #hypos = [ijkl for ijkl in hypos if ijkl != []]
216
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
217
+ gen_timer.stop(num_generated_tokens)
218
+
219
+ for i, sample_id in enumerate(sample["id"].tolist()):
220
+ has_target = sample["target"] is not None
221
+
222
+ # Remove padding
223
+ if "src_tokens" in sample["net_input"]:
224
+ src_tokens = utils.strip_pad(
225
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
226
+ )
227
+ else:
228
+ src_tokens = None
229
+
230
+ target_tokens = None
231
+ if has_target:
232
+ target_tokens = (
233
+ utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
234
+ )
235
+
236
+ # Either retrieve the original sentences or regenerate them from tokens.
237
+ if align_dict is not None:
238
+ src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
239
+ sample_id
240
+ )
241
+ target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
242
+ sample_id
243
+ )
244
+ else:
245
+ if src_dict is not None:
246
+ src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
247
+ else:
248
+ src_str = ""
249
+ if has_target:
250
+ target_str = tgt_dict.string(
251
+ target_tokens,
252
+ cfg.common_eval.post_process,
253
+ escape_unk=True,
254
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
255
+ generator
256
+ ),
257
+ )
258
+
259
+ src_str = decode_fn(src_str)
260
+ if has_target:
261
+ target_str = decode_fn(target_str)
262
+
263
+ if not cfg.common_eval.quiet:
264
+ if src_dict is not None:
265
+ print("S-{}\t{}".format(sample_id, src_str), file=output_file)
266
+ if has_target:
267
+ print("T-{}\t{}".format(sample_id, target_str), file=output_file)
268
+
269
+ # Process top predictions
270
+
271
+ for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
272
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
273
+ hypo_tokens=hypo["tokens"].int().cpu(),
274
+ src_str=src_str,
275
+ alignment=hypo["alignment"],
276
+ align_dict=align_dict,
277
+ tgt_dict=tgt_dict,
278
+ remove_bpe=cfg.common_eval.post_process,
279
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
280
+ )
281
+ detok_hypo_str = decode_fn(hypo_str)
282
+ if not cfg.common_eval.quiet:
283
+ score = hypo["score"] / math.log(2) # convert to base 2
284
+ # original hypothesis (after tokenization and BPE)
285
+ print(
286
+ "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
287
+ file=output_file,
288
+ )
289
+ # detokenized hypothesis
290
+ print(
291
+ "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
292
+ file=output_file,
293
+ )
294
+ print(
295
+ "P-{}\t{}".format(
296
+ sample_id,
297
+ " ".join(
298
+ map(
299
+ lambda x: "{:.4f}".format(x),
300
+ # convert from base e to base 2
301
+ hypo["positional_scores"]
302
+ .div_(math.log(2))
303
+ .tolist(),
304
+ )
305
+ ),
306
+ ),
307
+ file=output_file,
308
+ )
309
+
310
+ if cfg.generation.print_alignment == "hard":
311
+ print(
312
+ "A-{}\t{}".format(
313
+ sample_id,
314
+ " ".join(
315
+ [
316
+ "{}-{}".format(src_idx, tgt_idx)
317
+ for src_idx, tgt_idx in alignment
318
+ ]
319
+ ),
320
+ ),
321
+ file=output_file,
322
+ )
323
+ if cfg.generation.print_alignment == "soft":
324
+ print(
325
+ "A-{}\t{}".format(
326
+ sample_id,
327
+ " ".join(
328
+ [",".join(src_probs) for src_probs in alignment]
329
+ ),
330
+ ),
331
+ file=output_file,
332
+ )
333
+
334
+ if cfg.generation.print_step:
335
+ print(
336
+ "I-{}\t{}".format(sample_id, hypo["steps"]),
337
+ file=output_file,
338
+ )
339
+
340
+ if cfg.generation.retain_iter_history:
341
+ for step, h in enumerate(hypo["history"]):
342
+ _, h_str, _ = utils.post_process_prediction(
343
+ hypo_tokens=h["tokens"].int().cpu(),
344
+ src_str=src_str,
345
+ alignment=None,
346
+ align_dict=None,
347
+ tgt_dict=tgt_dict,
348
+ remove_bpe=None,
349
+ )
350
+ print(
351
+ "E-{}_{}\t{}".format(sample_id, step, h_str),
352
+ file=output_file,
353
+ )
354
+
355
+ # Score only the top hypothesis
356
+ if has_target and j == 0:
357
+ if (
358
+ align_dict is not None
359
+ or cfg.common_eval.post_process is not None
360
+ ):
361
+ # Convert back to tokens for evaluation with unk replacement and/or without BPE
362
+ target_tokens = tgt_dict.encode_line(
363
+ target_str, add_if_not_exist=True
364
+ )
365
+ hypo_tokens = tgt_dict.encode_line(
366
+ detok_hypo_str, add_if_not_exist=True
367
+ )
368
+ if hasattr(scorer, "add_string"):
369
+ scorer.add_string(target_str, detok_hypo_str)
370
+ else:
371
+ scorer.add(target_tokens, hypo_tokens)
372
+
373
+ wps_meter.update(num_generated_tokens)
374
+ progress.log({"wps": round(wps_meter.avg)})
375
+ num_sentences += (
376
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
377
+ )
378
+
379
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
380
+ logger.info(
381
+ "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
382
+ num_sentences,
383
+ gen_timer.n,
384
+ gen_timer.sum,
385
+ num_sentences / gen_timer.sum,
386
+ 1.0 / gen_timer.avg,
387
+ )
388
+ )
389
+ if has_target:
390
+ if cfg.bpe and not cfg.generation.sacrebleu:
391
+ if cfg.common_eval.post_process:
392
+ logger.warning(
393
+ "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
394
+ )
395
+ else:
396
+ logger.warning(
397
+ "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
398
+ )
399
+ # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
400
+ print(
401
+ "Generate {} with beam={}: {}".format(
402
+ cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
403
+ ),
404
+ file=output_file,
405
+ )
406
+
407
+ return scorer
408
+
409
+
410
+ def cli_main():
411
+ parser = options.get_generation_parser()
412
+ # TODO: replace this workaround with refactoring of `AudioPretraining`
413
+ parser.add_argument(
414
+ "--arch",
415
+ "-a",
416
+ metavar="ARCH",
417
+ default="wav2vec2",
418
+ help="Model architecture. For constructing tasks that rely on "
419
+ "model args (e.g. `AudioPretraining`)",
420
+ )
421
+ args = options.parse_args_and_arch(parser)
422
+ main(args)
423
+
424
+
425
+ if __name__ == "__main__":
426
+ cli_main()