ajayarora1235 commited on
Commit
4738a88
1 Parent(s): 15a72fe

test voicecraft merge

Browse files
Files changed (38) hide show
  1. app.py +184 -24
  2. lib/voicecraft/LICENSE-CODE +437 -0
  3. lib/voicecraft/LICENSE-MODEL +42 -0
  4. lib/voicecraft/README.md +160 -0
  5. lib/voicecraft/config.py +86 -0
  6. lib/voicecraft/data/__init__.py +0 -0
  7. lib/voicecraft/data/gigaspeech.py +156 -0
  8. lib/voicecraft/data/phonemize_encodec_encode_hf.py +206 -0
  9. lib/voicecraft/data/tokenizer.py +149 -0
  10. lib/voicecraft/demo/84_121550_000074_000000.wav +0 -0
  11. lib/voicecraft/demo/temp/84_121550_000074_000000.txt +1 -0
  12. lib/voicecraft/demo/temp/mfa_alignments/84_121550_000074_000000.csv +109 -0
  13. lib/voicecraft/edit_utils.py +49 -0
  14. lib/voicecraft/environment.yml +417 -0
  15. lib/voicecraft/inference_speech_editing.ipynb +0 -0
  16. lib/voicecraft/inference_speech_editing_scale.py +226 -0
  17. lib/voicecraft/inference_tts.ipynb +312 -0
  18. lib/voicecraft/inference_tts_scale.py +190 -0
  19. lib/voicecraft/main.py +45 -0
  20. lib/voicecraft/models/codebooks_patterns.py +538 -0
  21. lib/voicecraft/models/modules/__init__.py +0 -0
  22. lib/voicecraft/models/modules/activation.py +653 -0
  23. lib/voicecraft/models/modules/embedding.py +98 -0
  24. lib/voicecraft/models/modules/sampling.py +63 -0
  25. lib/voicecraft/models/modules/scaling.py +1406 -0
  26. lib/voicecraft/models/modules/transformer.py +698 -0
  27. lib/voicecraft/models/modules/utils.py +37 -0
  28. lib/voicecraft/models/voicecraft.py +1406 -0
  29. lib/voicecraft/pretrained_models/.gitkeep +0 -0
  30. lib/voicecraft/start-jupyter.bat +29 -0
  31. lib/voicecraft/start-jupyter.sh +21 -0
  32. lib/voicecraft/steps/__init__.py +0 -0
  33. lib/voicecraft/steps/optim.py +1123 -0
  34. lib/voicecraft/steps/trainer.py +467 -0
  35. lib/voicecraft/steps/trainer_utils.py +628 -0
  36. lib/voicecraft/z_scripts/e830M.sh +71 -0
  37. requirements.txt +12 -1
  38. run.sh +4 -1
app.py CHANGED
@@ -1,4 +1,15 @@
1
  import subprocess, torch, os, traceback, sys, warnings, shutil, numpy as np
 
 
 
 
 
 
 
 
 
 
 
2
  from mega import Mega
3
  os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
4
  import threading
@@ -1463,6 +1474,118 @@ def ilariaTTS(text, ttsvoice):
1463
  aud_path = save_to_wav('./temp_ilaria.mp3')
1464
  return aud_path, aud_path
1465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1466
  def upload_to_dataset(files, dir):
1467
  if dir == '':
1468
  dir = './dataset'
@@ -1561,30 +1684,67 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="pink", secondary_hue="rose")
1561
  refresh_button2 = gr.Button("Refresh", variant="primary", size='sm')
1562
  record_button.change(fn=save_to_wav, inputs=[record_button], outputs=[input_audio0])
1563
  record_button.change(fn=change_choices2, inputs=[], outputs=[input_audio0])
1564
- # with gr.Row():
1565
- # with gr.Accordion('ElevenLabs / Google TTS', open=False):
1566
- # with gr.Column():
1567
- # lang = gr.Radio(label='Chinese & Japanese do not work with ElevenLabs currently.',choices=['en','it','es','fr','pt','zh-CN','de','hi','ja'], value='en')
1568
- # api_box = gr.Textbox(label="Enter your API Key for ElevenLabs, or leave empty to use GoogleTTS", value='')
1569
- # elevenid=gr.Dropdown(label="Voice:", choices=eleven_voices)
1570
- # with gr.Column():
1571
- # tfs = gr.Textbox(label="Input your Text", interactive=True, value="This is a test.")
1572
- # tts_button = gr.Button(value="Speak")
1573
- # tts_button.click(fn=elevenTTS, inputs=[api_box,tfs, elevenid, lang], outputs=[record_button, input_audio0])
1574
- # with gr.Row():
1575
- # with gr.Accordion('Wav2Lip', open=False, visible=False):
1576
- # with gr.Row():
1577
- # size = gr.Radio(label='Resolution:',choices=['Half','Full'])
1578
- # face = gr.UploadButton("Upload A Character",type='file')
1579
- # faces = gr.Dropdown(label="OR Choose one:", choices=['None','Ben Shapiro','Andrew Tate'])
1580
- # with gr.Row():
1581
- # preview = gr.Textbox(label="Status:",interactive=False)
1582
- # face.upload(fn=success_message,inputs=[face], outputs=[preview, faces])
1583
- # with gr.Row():
1584
- # animation = gr.Video(type='filepath')
1585
- # refresh_button2.click(fn=change_choices2, inputs=[], outputs=[input_audio0, animation])
1586
- # with gr.Row():
1587
- # animate_button = gr.Button('Animate')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1588
 
1589
  with gr.Column():
1590
  vc_output2 = gr.Audio(
 
1
  import subprocess, torch, os, traceback, sys, warnings, shutil, numpy as np
2
+
3
+ import pandas as pd
4
+ import torchaudio
5
+ from lib.voicecraft.data.tokenizer import (
6
+ AudioTokenizer,
7
+ TextTokenizer,
8
+ )
9
+ import whisper
10
+ import os
11
+ import time
12
+
13
  from mega import Mega
14
  os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
15
  import threading
 
1474
  aud_path = save_to_wav('./temp_ilaria.mp3')
1475
  return aud_path, aud_path
1476
 
1477
+ def transcribe_btn_click(model_choice, audio_choice, transcribed_text):
1478
+ model = whisper.load_model(model_choice) # pass the value of model_choice to whisper.load_model()
1479
+ result = model.transcribe(audio_choice) # pass the value of audio_choice to model.transcribe()
1480
+ print("transcribe text: " + result["text"])
1481
+
1482
+ # point to the original file or record the file
1483
+ # write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file
1484
+ orig_audio = audio_choice
1485
+ orig_transcript = result["text"]
1486
+ # move the audio and transcript to temp folder
1487
+ temp_folder = "./demo/temp"
1488
+ os.makedirs(temp_folder, exist_ok=True)
1489
+ os.system(f"cp {orig_audio} {temp_folder}")
1490
+ filename = os.path.splitext(orig_audio.split("/")[-1])[0]
1491
+ with open(f"{temp_folder}/{filename}.txt", "w") as f:
1492
+ f.write(orig_transcript)
1493
+ # run MFA to get the alignment
1494
+ align_temp = f"{temp_folder}/mfa_alignments"
1495
+ os.makedirs(align_temp, exist_ok=True)
1496
+
1497
+ if os.path.exists(f"{align_temp}/{filename}.csv"):
1498
+ pass
1499
+ print("mfa.cvs file exists already")
1500
+ else:
1501
+ print(align_temp + " is None")
1502
+ os.system(f"mfa align -j 1 --output_format csv --clean {temp_folder} english_us_arpa english_us_arpa {align_temp}")
1503
+
1504
+
1505
+ # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue
1506
+ # or try a larger model
1507
+ # os.system(f"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000")
1508
+ print("yes")
1509
+ global audio_fn
1510
+ audio_fn = f"{temp_folder}/{filename}.wav"
1511
+ global transcript_fn
1512
+ transcript_fn = f"{temp_folder}/{filename}.txt"
1513
+ global align_fn
1514
+ align_fn = f"{align_temp}/{filename}.csv"
1515
+
1516
+ df = pd.read_csv(align_fn)
1517
+ # Select the first three columns
1518
+ df = df.iloc[:, :3]
1519
+
1520
+ # Convert DataFrame to HTML
1521
+ html = df.to_html(index=False)
1522
+
1523
+ return [result["text"], html]
1524
+
1525
+
1526
+ def run(seed, stop_repetition, sample_batch_size, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
1527
+ temperature, kvcache, cutoff_value, target_transcript, silence_tokens):
1528
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
1529
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
1530
+ # take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt
1531
+ cut_off_sec = cutoff_value # NOTE: according to forced-alignment file, the word "common" stop as 3.01 sec, this should be different for different audio
1532
+ target_transcript = target_transcript
1533
+ info = torchaudio.info(audio_fn)
1534
+ audio_dur = info.num_frames / info.sample_rate
1535
+
1536
+ assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
1537
+ prompt_end_frame = int(cut_off_sec * info.sample_rate)
1538
+
1539
+ # # load model, tokenizer, and other necessary files
1540
+ # # original file loaded it each time. here we load it only once
1541
+ # global model_loaded
1542
+ # f model_loaded==False:
1543
+ from lib.voicecraft.models import voicecraft
1544
+ voicecraft_name = "giga830M.pth"
1545
+ ckpt_fn = f"./pretrained_models/{voicecraft_name}"
1546
+ encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
1547
+ if not os.path.exists(ckpt_fn):
1548
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
1549
+ os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
1550
+ if not os.path.exists(encodec_fn):
1551
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
1552
+ os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
1553
+
1554
+ ckpt = torch.load(ckpt_fn, map_location="cpu")
1555
+ model = voicecraft.VoiceCraft(ckpt["config"])
1556
+ model.load_state_dict(ckpt["model"])
1557
+ model.to(config.device)
1558
+ model.eval()
1559
+
1560
+ phn2num = ckpt['phn2num']
1561
+
1562
+ text_tokenizer = TextTokenizer(backend="espeak")
1563
+ audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu
1564
+
1565
+ # # run the model to get the output
1566
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition,
1567
+ 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
1568
+ "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
1569
+ from lib.voicecraft.inference_tts_scale import inference_one_sample
1570
+ concated_audio, gen_audio = inference_one_sample(model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
1571
+ audio_fn, target_transcript, config.device, decode_config,
1572
+ prompt_end_frame)
1573
+
1574
+ # save segments for comparison
1575
+ concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
1576
+ # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
1577
+
1578
+ output_dir = "./demo/generated_tts"
1579
+ os.makedirs(output_dir, exist_ok=True)
1580
+ seg_save_fn_gen = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav"
1581
+ seg_save_fn_concat = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav"
1582
+
1583
+
1584
+ torchaudio.save(seg_save_fn_gen, gen_audio, int(codec_audio_sr))
1585
+ torchaudio.save(seg_save_fn_concat, concated_audio, int(codec_audio_sr))
1586
+
1587
+ return [seg_save_fn_concat, seg_save_fn_gen]
1588
+
1589
  def upload_to_dataset(files, dir):
1590
  if dir == '':
1591
  dir = './dataset'
 
1684
  refresh_button2 = gr.Button("Refresh", variant="primary", size='sm')
1685
  record_button.change(fn=save_to_wav, inputs=[record_button], outputs=[input_audio0])
1686
  record_button.change(fn=change_choices2, inputs=[], outputs=[input_audio0])
1687
+
1688
+ with gr.Row():
1689
+ with gr.Column():
1690
+ input_audio = gr.Audio(label="Input Audio", type="filepath")
1691
+ transcribe_btn_model = gr.Radio(value="base.en", interactive=True, label="what whisper model to download",
1692
+ choices=["tiny.en", "base.en", "small.en", "medium.en", "large"],
1693
+ info="VRAM usage: tiny.en 1 GB, base.en 1GB, small.en 2GB, medium.en 5GB, large 10GB.")
1694
+ transcribed_text = gr.Textbox(label="transcibed text + mfa",
1695
+ info="write down the transcript for the file, or run whisper model to get the transcript. Takes time to download whisper models on first run")
1696
+ transcribe_info_text = gr.TextArea(label="How to use",
1697
+ value="running everything for the first time will download necessary models (4GB for main encoder + model) \n load a voice and choose your whisper model, base works most of the time. \n transcription and mfa takes ~50s on a 3090 for a 7s audio clip, rerun this when uploading a new audio clip only\nchoose the END value of the cut off word \n")
1698
+ transcribe_btn = gr.Button(value="transcribe and create mfa")
1699
+ seed = gr.Number(label='seed', interactive=True, value=1)
1700
+ stop_repitition = gr.Radio(label="stop_repitition", interactive=True, choices=[1, 2, 3], value=3,
1701
+ info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1")
1702
+ sample_batch_size = gr.Radio(label="sample_batch_size", interactive=True, choices=[4, 3, 2], value=4,
1703
+ info="if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4")
1704
+ left_margin = gr.Number(label='left_margin', interactive=True, value=0.08, step=0.01,
1705
+ info=" not used for TTS, only for speech editing")
1706
+ right_margin = gr.Number(label='right_margin', interactive=True, value=0.08, step=0.01,
1707
+ info=" not used for TTS, only for speech editing")
1708
+ codecaudio_sr = gr.Number(label='codec_audio_sr', interactive=True, value=16000)
1709
+ codec_sr = gr.Number(label='codec', interactive=True, value=50)
1710
+ top_k = gr.Number(label='top_k', interactive=True, value=0)
1711
+ top_p = gr.Number(label='top_p', interactive=True, value=0.8)
1712
+ temperature = gr.Number(label='temperature', interactive=True, value=1)
1713
+ kvcache = gr.Number(label='kvcache', interactive=True, value=1,
1714
+ info='set to 0 to use less VRAM, results may be worse and slower inference')
1715
+ silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
1716
+
1717
+ with gr.Column():
1718
+ output_audio_con = gr.Audio(label="Output Audio concatenated")
1719
+ output_audio_gen = gr.Audio(label="Output Audio generated")
1720
+ cutoff_value = gr.Number(label="cutoff_time", interactive=True, step=0.01)
1721
+ run_btn = gr.Button(value="run")
1722
+ target_transcript = gr.Textbox(label="target transcript")
1723
+ cvs_file_html = gr.HTML()
1724
+
1725
+ transcribe_btn.click(fn=transcribe_btn_click, inputs=[transcribe_btn_model, input_audio, transcribed_text],
1726
+ outputs=[transcribed_text, cvs_file_html])
1727
+
1728
+ run_btn.click(fn=run,
1729
+ inputs=[
1730
+ seed,
1731
+ stop_repitition,
1732
+ sample_batch_size,
1733
+ left_margin,
1734
+ right_margin,
1735
+ codecaudio_sr,
1736
+ codec_sr,
1737
+ top_k,
1738
+ top_p,
1739
+ temperature,
1740
+ kvcache,
1741
+ cutoff_value,
1742
+ target_transcript,
1743
+ silence_tokens],
1744
+ outputs=[
1745
+ output_audio_con,
1746
+ output_audio_gen
1747
+ ])
1748
 
1749
  with gr.Column():
1750
  vc_output2 = gr.Audio(
lib/voicecraft/LICENSE-CODE ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
lib/voicecraft/LICENSE-MODEL ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Coqui Public Model License 1.0.0
2
+ https://coqui.ai/cpml.txt
3
+
4
+ This license allows only non-commercial use of a machine learning model and its outputs.
5
+
6
+ Acceptance
7
+ In order to get any license under these terms, you must agree to them as both strict obligations and conditions to all your licenses.
8
+
9
+ Licenses
10
+ The licensor grants you a copyright license to do everything you might do with the model that would otherwise infringe the licensor's copyright in it, for any non-commercial purpose. The licensor grants you a patent license that covers patent claims the licensor can license, or becomes able to license, that you would infringe by using the model in the form provided by the licensor, for any non-commercial purpose.
11
+
12
+ Non-commercial Purpose
13
+ Non-commercial purposes include any of the following uses of the model or its output, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output.
14
+
15
+ Personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, amateur pursuits, or religious observance.
16
+ Use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development. Use of the model to train other models for commercial use is not a non-commercial purpose.
17
+ Use by any charitable organization for charitable purposes, or for testing or evaluation. Use for revenue-generating activity, including projects directly funded by government grants, is not a non-commercial purpose.
18
+ Notices
19
+ You must ensure that anyone who gets a copy of any part of the model, or any modification of the model, or their output, from you also gets a copy of these terms or the URL for them above.
20
+
21
+ No Other Rights
22
+ These terms do not allow you to sublicense or transfer any of your licenses to anyone else, or prevent the licensor from granting licenses to anyone else. These terms do not imply any other licenses.
23
+
24
+ Patent Defense
25
+ If you make any written claim that the model infringes or contributes to infringement of any patent, your licenses for the model granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company.
26
+
27
+ Violations
28
+ The first time you are notified in writing that you have violated any of these terms, or done anything with the model or its output that is not covered by your licenses, your licenses can nonetheless continue if you come into full compliance with these terms, and take practical steps to correct past violations, within 30 days of receiving notice. Otherwise, all your licenses end immediately.
29
+
30
+ No Liability
31
+ AS FAR AS THE LAW ALLOWS, THE MODEL AND ITS OUTPUT COME AS IS, WITHOUT ANY WARRANTY OR CONDITION, AND THE LICENSOR WILL NOT BE LIABLE TO YOU FOR ANY DAMAGES ARISING OUT OF THESE TERMS OR THE USE OR NATURE OF THE MODEL OR ITS OUTPUT, UNDER ANY KIND OF LEGAL CLAIM. IF THIS PROVISION IS NOT ENFORCEABLE IN YOUR JURISDICTION, YOUR LICENSES ARE VOID.
32
+
33
+ Definitions
34
+ The licensor is the individual or entity offering these terms, and the model is the model the licensor makes available under these terms, including any documentation or similar information about the model.
35
+
36
+ You refers to the individual or entity agreeing to these terms.
37
+
38
+ Your company is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. Control means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect.
39
+
40
+ Your licenses are all the licenses granted to you under these terms.
41
+
42
+ Use means anything you do with the model or its output requiring one of your licenses.
lib/voicecraft/README.md ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
2
+ [Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)
3
+
4
+
5
+ ### TL;DR
6
+ VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
7
+
8
+ To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference.
9
+
10
+ ## News
11
+ :star: 03/28/2024: Model weights are up on HuggingFace🤗 [here](https://huggingface.co/pyp1/VoiceCraft/tree/main)!
12
+
13
+ ## TODO
14
+ - [x] Codebase upload
15
+ - [x] Environment setup
16
+ - [x] Inference demo for speech editing and TTS
17
+ - [x] Training guidance
18
+ - [x] RealEdit dataset and training manifest
19
+ - [x] Model weights (both 330M and 830M, the former seems to be just as good)
20
+ - [ ] Write colab notebooks for better hands-on experience
21
+ - [ ] HuggingFace Spaces demo
22
+ - [ ] Better guidance on training/finetuning
23
+
24
+ ## How to run TTS inference
25
+ There are two ways:
26
+ 1. with docker. see [quickstart](#quickstart)
27
+ 2. without docker. see [envrionment setup](#environment-setup)
28
+
29
+ When you are inside the docker image or you have installed all dependencies, Checkout [`inference_tts.ipynb`](./inference_tts.ipynb).
30
+
31
+ If you want to do model development such as training/finetuning, I recommend following [envrionment setup](#environment-setup) and [training](#training).
32
+
33
+ ## QuickStart
34
+ :star: To try out TTS inference with VoiceCraft, the best way is using docker. Thank [@ubergarm](https://github.com/ubergarm) and [@jayc88](https://github.com/jay-c88) for making this happen.
35
+
36
+ Tested on Linux and Windows and should work with any host with docker installed.
37
+ ```bash
38
+ # 1. clone the repo on in a directory on a drive with plenty of free space
39
+ git clone [email protected]:jasonppy/VoiceCraft.git
40
+ cd VoiceCraft
41
+
42
+ # 2. assumes you have docker installed with nvidia container container-toolkit (windows has this built into the driver)
43
+ # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.13.5/install-guide.html
44
+ # sudo apt-get install -y nvidia-container-toolkit-base || yay -Syu nvidia-container-toolkit || echo etc...
45
+
46
+ # 3. Try to start an existing container otherwise create a new one passing in all GPUs
47
+ ./start-jupyter.sh # linux
48
+ start-jupyter.bat # windows
49
+
50
+ # 4. now open a webpage on the host box to the URL shown at the bottom of:
51
+ docker logs jupyter
52
+
53
+ # 5. optionally look inside from another terminal
54
+ docker exec -it jupyter /bin/bash
55
+ export USER=(your_linux_username_used_above)
56
+ export HOME=/home/$USER
57
+ sudo apt-get update
58
+
59
+ # 6. confirm video card(s) are visible inside container
60
+ nvidia-smi
61
+
62
+ # 7. Now in browser, open inference_tts.ipynb and work through one cell at a time
63
+ echo GOOD LUCK
64
+ ```
65
+
66
+ ## Environment setup
67
+ ```bash
68
+ conda create -n voicecraft python=3.9.16
69
+ conda activate voicecraft
70
+
71
+ pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
72
+ pip install xformers==0.0.22
73
+ pip install torchaudio==2.0.2 torch==2.0.1 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
74
+ apt-get install ffmpeg # if you don't already have ffmpeg installed
75
+ apt-get install espeak-ng # backend for the phonemizer installed below
76
+ pip install tensorboard==2.16.2
77
+ pip install phonemizer==3.2.1
78
+ pip install datasets==2.16.0
79
+ pip install torchmetrics==0.11.1
80
+ # install MFA for getting forced-alignment, this could take a few minutes
81
+ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
82
+ # conda install pocl # above gives an warning for installing pocl, not sure if really need this
83
+
84
+ # to run ipynb
85
+ conda install -n voicecraft ipykernel --no-deps --force-reinstall
86
+ ```
87
+
88
+ If you have encountered version issues when running things, checkout [environment.yml](./environment.yml) for exact matching.
89
+
90
+ ## Inference Examples
91
+ Checkout [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb)
92
+
93
+ ## Training
94
+ To train an VoiceCraft model, you need to prepare the following parts:
95
+ 1. utterances and their transcripts
96
+ 2. encode the utterances into codes using e.g. Encodec
97
+ 3. convert transcripts into phoneme sequence, and a phoneme set (we named it vocab.txt)
98
+ 4. manifest (i.e. metadata)
99
+
100
+ Step 1,2,3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where
101
+ 1. Gigaspeech is downloaded through HuggingFace. Note that you need to sign an agreement in order to download the dataset (it needs your auth token)
102
+ 2. phoneme sequence and encodec codes are also extracted using the script.
103
+
104
+ An example run:
105
+
106
+ ```bash
107
+ conda activate voicecraft
108
+ export CUDA_VISIBLE_DEVICES=0
109
+ cd ./data
110
+ python phonemize_encodec_encode_hf.py \
111
+ --dataset_size xs \
112
+ --download_to path/to/store_huggingface_downloads \
113
+ --save_dir path/to/store_extracted_codes_and_phonemes \
114
+ --encodec_model_path path/to/encodec_model \
115
+ --mega_batch_size 120 \
116
+ --batch_size 32 \
117
+ --max_len 30000
118
+ ```
119
+ where encodec_model_path is avaliable [here](https://huggingface.co/pyp1/VoiceCraft). This model is trained on Gigaspeech XL, it has 56M parameters, 4 codebooks, each codebook has 2048 codes. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decrease the batch_size and/or max_len.
120
+ The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`.
121
+
122
+ As for manifest, please download train.txt and validation.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main), and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. Please also download vocab.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same).
123
+
124
+ Now, you are good to start training!
125
+
126
+ ```bash
127
+ conda activate voicecraft
128
+ cd ./z_scripts
129
+ bash e830M.sh
130
+ ```
131
+
132
+
133
+ ## License
134
+ The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some of the code from other repository that are under different licenses: `./models/codebooks_patterns.py` is under MIT license; `./models/modules`, `./steps/optim.py`, `data/tokenizer.py` are under Apache License, Version 2.0; the phonemizer we used is under GNU 3.0 License.
135
+
136
+ <!-- How to use g2p to convert english text into IPA phoneme sequence
137
+ first install it with `pip install g2p`
138
+ ```python
139
+ from g2p import make_g2p
140
+ transducer = make_g2p('eng', 'eng-ipa')
141
+ transducer("hello").output_string
142
+ # it will output: 'hʌloʊ'
143
+ ``` -->
144
+
145
+ ## Acknowledgement
146
+ We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall-e), and we thank audiocraft team for open-sourcing [encodec](https://github.com/facebookresearch/audiocraft).
147
+
148
+ ## Citation
149
+ ```
150
+ @article{peng2024voicecraft,
151
+ author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David},
152
+ title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
153
+ journal = {arXiv},
154
+ year = {2024},
155
+ }
156
+ ```
157
+
158
+ ## Disclaimer
159
+ Any organization or individual is prohibited from using any technology mentioned in this paper to generate or edit someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
160
+
lib/voicecraft/config.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def MyParser():
5
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6
+ # general training
7
+ parser.add_argument("--seed", type=int, default=1)
8
+ parser.add_argument("--precision", type=str, default="float16")
9
+ parser.add_argument("--num_workers", type=int, default=8)
10
+ parser.add_argument("--resume", action="store_true", default=False)
11
+ parser.add_argument("--tb_write_every_n_steps", type=int, default=100)
12
+ parser.add_argument("--print_every_n_steps", type=int, default=400)
13
+ parser.add_argument("--val_every_n_steps", type=int, default=800)
14
+ parser.add_argument("--lr", type=float, default=0.05)
15
+ parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size, no matter whether using gradient_accumulation_steps, not used if we specified max_num_tokens")
16
+ parser.add_argument("--max_num_tokens", type=int, default=100000, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note this is the final effective batch size per GPU, i.e. gradient accumulated batch size per gpu")
17
+ parser.add_argument("--val_max_num_tokens", type=int, default=None, help="FOR validation")
18
+ parser.add_argument("--num_buckets", type=int, default=6, help='used for dynamic batching, bucketing the samples based on the number of tokens')
19
+ parser.add_argument("--dynamic_batching", type=int, default=0)
20
+ parser.add_argument("--weight_decay", type=float, default=1e-2)
21
+ parser.add_argument("--warmup_fraction", type=float, default=0.01, help="use linear warmup, the proportion of the training steps that are used for warming up")
22
+ parser.add_argument("--num_epochs", type=int, default=10)
23
+ parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore n_epochs and use num_steps as the total number of amount of training, can try e.g. 400000 i.e. 400k steps")
24
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
25
+ parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_(), not used if we use ScaledAdam optimizer")
26
+ parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement")
27
+ parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps")
28
+
29
+ # optimizer focused
30
+ parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler")
31
+ parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer')
32
+ parser.add_argument("--pseudo_epoch_size", type=int, default=3000, help="only use for Eden scheduler.")
33
+ parser.add_argument("--reduce_lr_start_epoch", type=int, default=4)
34
+ parser.add_argument("--clipping_update_period", type=int, default=600)
35
+
36
+
37
+ # path
38
+ parser.add_argument("--exp_dir", type=str, default=None, help="will be combined with dataset name")
39
+ parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'gigaspeech', they are folder name in the data dir also")
40
+ parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file")
41
+ parser.add_argument("--phn_folder_name", type=str, default="phonemes", help="for libritts I also have arpa phns, in which case should be phonemes_arpa")
42
+ parser.add_argument("--encodec_folder_name", type=str, default="encodec_16khz_4codebooks", help="folder where encodec codes are stored")
43
+ parser.add_argument("--manifest_name", type=str, default="manifest", help="metadata filename")
44
+
45
+ # data focused
46
+ parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0")
47
+ parser.add_argument("--audio_max_length", type=float, default=20, help="in second, crop or drop the audio is length is longer than this")
48
+ parser.add_argument("--audio_min_length", type=float, default=2, help="in second, drop the audio if length is shorter than this")
49
+ parser.add_argument("--text_max_length", type=int, default=400, help='if too long, we crop or drop')
50
+ parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop")
51
+ parser.add_argument("--encodec_sr", type=int, default=50, help="for my encodec that takes 16kHz audio with a downsample rate of 320, the codec sample rate is 50Hz, i.e. 50 codes (x n_codebooks) per second")
52
+ parser.add_argument("--drop_long", type=int, default=0, help="if this is true, will drop example whose encodec sequence or phone sequence is too long, rather than cropping, to reduce hellucination")
53
+
54
+ # encodec and token rearrangement
55
+ parser.add_argument('--mask_len_min', type=int, default=1, help='Minimum mask length')
56
+ parser.add_argument('--mask_len_max', type=int, default=600, help='Maximum mask length')
57
+ parser.add_argument("--eos", type=int, default=-1, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4")
58
+ parser.add_argument("--reduced_eog", type=int, default=0, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts")
59
+ parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens")
60
+ parser.add_argument("--n_special", type=int, default=3, help="empty, eog, pad, (eos)")
61
+ parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']")
62
+ parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask a utterance for more than this portion")
63
+ parser.add_argument("--max_n_spans", type=int, default=3, help='maximal number of spans, only use when using multicm3, this is used to decide number of mask_embedding, and max clamp value if use Poisson distribution, if use uniform distribution to sample number of spans if will be uniform(1,max_n_spans)')
64
+ parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has it's benefit, as it make sure that mask:0 always appear the first")
65
+ parser.add_argument("--mask_sample_dist", type=str, default="poisson1", help="uniform or poissonx, e.g. poisson1, meaning the parameter lambda is 1, it will most likely sample 1 masks")
66
+ parser.add_argument("--min_gap", type=int, default=5, help="after sampled starts, delete later one if it closer to the former start than the min_gap")
67
+ parser.add_argument('--n_codebooks', type=int, default=4)
68
+ parser.add_argument('--text_vocab_size', type=int, default=100, help='Size of text vocabulary')
69
+ parser.add_argument('--text_pad_token', type=int, default=100, help='padding of the text tokens, not attended')
70
+ parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary")
71
+ parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook")
72
+ parser.add_argument('--eog', type=int, default=2049, help='End of generation token')
73
+ parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended')
74
+
75
+ # model focused
76
+ parser.add_argument('--d_model', type=int, default=2048, help='Model dimension')
77
+ parser.add_argument('--audio_embedding_dim', type=int, default=2048, help='dimension for encodec continues embedding (before being quantized)')
78
+ parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding')
79
+ parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding')
80
+ parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding')
81
+ parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding')
82
+ parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer')
83
+ parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads')
84
+ parser.add_argument('--num_decoder_layers', type=int, default=16, help='Number of decoder layers')
85
+ parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume')
86
+ return parser
lib/voicecraft/data/__init__.py ADDED
File without changes
lib/voicecraft/data/gigaspeech.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import copy
5
+ import logging
6
+ import shutil
7
+
8
+ class dataset(torch.utils.data.Dataset):
9
+ def __init__(self, args, split):
10
+ super().__init__()
11
+ self.args = args
12
+ self.split = split
13
+ assert self.split in ['train', 'validation', 'test']
14
+ manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
15
+
16
+ with open(manifest_fn, "r") as rf:
17
+ data = [l.strip().split("\t") for l in rf.readlines()]
18
+ lengths_list = [int(item[-1]) for item in data]
19
+ self.data = []
20
+ self.lengths_list = []
21
+ for d, l in zip(data, lengths_list):
22
+ if l >= self.args.encodec_sr*self.args.audio_min_length:
23
+ if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
24
+ continue
25
+ self.data.append(d)
26
+ self.lengths_list.append(l)
27
+ logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")
28
+
29
+ # phoneme vocabulary
30
+ vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
31
+ shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
32
+ with open(vocab_fn, "r") as f:
33
+ temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
34
+ self.phn2num = {item[1]:int(item[0]) for item in temp}
35
+
36
+ self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])
37
+
38
+ def __len__(self):
39
+ return len(self.lengths_list)
40
+
41
+ def _load_phn_enc(self, index):
42
+ item = self.data[index]
43
+ pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
44
+ ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
45
+ try:
46
+ with open(pf, "r") as p, open(ef, "r") as e:
47
+ phns = [l.strip() for l in p.readlines()]
48
+ assert len(phns) == 1, phns
49
+ x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
50
+ encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
51
+
52
+ assert len(encos) == self.args.n_codebooks, ef
53
+ if self.args.special_first:
54
+ y = [[int(n)+self.args.n_special for n in l] for l in encos]
55
+ else:
56
+ y = [[int(n) for n in l] for l in encos]
57
+ except Exception as e:
58
+ logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
59
+ logging.info(f"error message: {e}")
60
+ return [], [[]]
61
+
62
+ return x, y
63
+
64
+ def __getitem__(self, index):
65
+ x, y = self._load_phn_enc(index)
66
+ x_len, y_len = len(x), len(y[0])
67
+
68
+ if x_len == 0 or y_len == 0:
69
+ return {
70
+ "x": None,
71
+ "x_len": None,
72
+ "y": None,
73
+ "y_len": None,
74
+ "y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token
75
+ "extra_mask_start": None # this is only used in VE1
76
+ }
77
+ while y_len < self.args.encodec_sr*self.args.audio_min_length:
78
+ assert not self.args.dynamic_batching
79
+ index = random.choice(range(len(self))) # regenerate an index
80
+ x, y = self._load_phn_enc(index)
81
+ x_len, y_len = len(x), len(y[0])
82
+ if self.args.drop_long:
83
+ while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
84
+ index = random.choice(range(len(self))) # regenerate an index
85
+ x, y = self._load_phn_enc(index)
86
+ x_len, y_len = len(x), len(y[0])
87
+
88
+ ### padding and cropping below ###
89
+ ### padding and cropping below ###
90
+ # adjust the length of encodec codes, pad to max_len or randomly crop
91
+ orig_y_len = copy.copy(y_len)
92
+ max_len = int(self.args.audio_max_length * self.args.encodec_sr)
93
+ if y_len > max_len:
94
+ audio_start = random.choice(range(0, y_len-max_len))
95
+ for i in range(len(y)):
96
+ y[i] = y[i][audio_start:(audio_start+max_len)]
97
+ y_len = max_len
98
+ else:
99
+ audio_start = 0
100
+ if not self.args.dynamic_batching:
101
+ pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
102
+ for i in range(len(y)):
103
+ y[i] = y[i] + pad
104
+
105
+ # adjust text
106
+ # if audio is cropped, and text is longer than max, crop max based on how audio is cropped
107
+ if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio started
108
+ x = x[int(len(x)*audio_start/orig_y_len):]
109
+ if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end
110
+ x = x[:self.args.text_max_length]
111
+
112
+ x_len = len(x)
113
+ if x_len > self.args.text_max_length:
114
+ text_start = random.choice(range(0, x_len - self.args.text_max_length))
115
+ x = x[text_start:text_start+self.args.text_max_length]
116
+ x_len = self.args.text_max_length
117
+ elif self.args.pad_x and x_len <= self.args.text_max_length:
118
+ pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
119
+ x = x + pad
120
+ ### padding and cropping above ###
121
+ ### padding and cropping above ###
122
+
123
+ return {
124
+ "x": torch.LongTensor(x),
125
+ "x_len": x_len,
126
+ "y": torch.LongTensor(y),
127
+ "y_len": y_len
128
+ }
129
+
130
+
131
+ def collate(self, batch):
132
+ out = {key:[] for key in batch[0]}
133
+ for item in batch:
134
+ if item['x'] == None: # deal with load failure
135
+ continue
136
+ for key, val in item.items():
137
+ out[key].append(val)
138
+ res = {}
139
+ if self.args.pad_x:
140
+ res["x"] = torch.stack(out["x"], dim=0)
141
+ else:
142
+ res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
143
+ res["x_lens"] = torch.LongTensor(out["x_len"])
144
+ if self.args.dynamic_batching:
145
+ if out['y'][0].ndim==2:
146
+ res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
147
+ res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
148
+ else:
149
+ assert out['y'][0].ndim==1, out['y'][0].shape
150
+ res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
151
+ else:
152
+ res['y'] = torch.stack(out['y'], dim=0)
153
+ res["y_lens"] = torch.LongTensor(out["y_len"])
154
+ res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
155
+ res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
156
+ return res
lib/voicecraft/data/phonemize_encodec_encode_hf.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ def parse_args():
3
+ parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
4
+ parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
5
+ parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
6
+ parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
7
+ parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
8
+ parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
9
+ parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
10
+ parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
11
+ parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
12
+ parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
13
+ parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
14
+ parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
15
+ parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
16
+ return parser.parse_args()
17
+ if __name__ == "__main__":
18
+ import logging
19
+ formatter = (
20
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
21
+ )
22
+ logging.basicConfig(format=formatter, level=logging.INFO)
23
+ args = parse_args()
24
+
25
+ import os
26
+ import numpy as np
27
+ import torch
28
+ import tqdm
29
+ import time
30
+ from datasets import load_dataset, DownloadConfig
31
+
32
+ from tokenizer import TextTokenizer, tokenize_text
33
+
34
+ # get the path
35
+ phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
36
+ codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
37
+ vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
38
+ os.makedirs(phn_save_root, exist_ok=True)
39
+ os.makedirs(codes_save_root, exist_ok=True)
40
+
41
+
42
+ def sort_by_audio_len(lens):
43
+ inds = np.argsort(lens).tolist()
44
+ logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
45
+ logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
46
+ logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
47
+ logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
48
+ return inds[::-1]
49
+
50
+ def write_array_to_txt_file(array, filename):
51
+ with open(filename, 'w') as f:
52
+ for a in array[:-1]:
53
+ f.write(' '.join(map(str, a))+'\n')
54
+ f.write(' '.join(map(str, array[-1])))
55
+
56
+
57
+ ### phonemization
58
+ # load tokenizer
59
+ # load the encodec model
60
+ from audiocraft.solvers import CompressionSolver
61
+ model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
62
+ model = model.cuda()
63
+ model = model.eval()
64
+ text_tokenizer = TextTokenizer()
65
+
66
+
67
+ # https://github.com/SpeechColab/GigaSpeech
68
+ # there are only four different punctuations
69
+ # need to check whether there are other < started strings
70
+ punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
71
+ gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are savely keep as the original sym when using tokenize_text
72
+ punc2sym.update(gar2sym)
73
+
74
+ word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
75
+ forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
76
+
77
+ dc = DownloadConfig(cache_dir=args.download_to)
78
+ stime = time.time()
79
+ logging.info("loading the dataset...")
80
+ gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
81
+ logging.info(f"time spend on loading the dataset: {time.time() - stime:.2f} seconds")
82
+
83
+ splits = ['validation', 'test', 'train']
84
+
85
+ logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
86
+ logging.info(f"phonemizing...")
87
+ phn_vocab = set()
88
+ all_lens = []
89
+
90
+ # you will see a ton of [WARNING] words_mismatch.py:88......, it's not a issue
91
+ for split in tqdm.tqdm(splits):
92
+ skip = 0
93
+ logging.info(f"now processing split {split}...")
94
+ for item in tqdm.tqdm(gs[split]):
95
+ save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
96
+ text = item['text']
97
+ if sum(word in forbidden_words for word in text.split(" ")):
98
+ logging.info(f"skip {item['segment_id']}, because it contains forbiden words. It's transcript: {text}")
99
+ skip += 1
100
+ continue
101
+ for k, v in punc2sym.items():
102
+ text = text.replace(k, v)
103
+ phn = tokenize_text(text_tokenizer, text)
104
+ phn_seq = " ".join(phn)
105
+ for k, v in word2sym.items():
106
+ phn_seq = phn_seq.replace(k, v)
107
+ phn_vocab.update(phn_seq.split(" "))
108
+ all_lens.append(len(phn_seq.split(" ")))
109
+ with open(save_fn, "w") as f:
110
+ f.write(phn_seq)
111
+ logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
112
+
113
+ print(f"phn vocab size: {len(list(phn_vocab))}")
114
+ print("phn sequence stats: ")
115
+ print(f"longest: {max(all_lens)}")
116
+ print(f"shortest: {min(all_lens)}")
117
+ print(f"median: {np.quantile(all_lens, 0.5)}")
118
+ print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
119
+ print("write vocabulary to ", vocab_fn)
120
+ with open(vocab_fn, "w") as f:
121
+ for i, phn in enumerate(list(phn_vocab)):
122
+ if i < len(list(phn_vocab)) - 1:
123
+ f.write(f"{str(i)} {phn}\n")
124
+ else:
125
+ f.write(f"{str(i)} {phn}")
126
+
127
+ class mydataset(torch.utils.data.Dataset):
128
+ def __init__(self, split):
129
+ super().__init__()
130
+ self.data = gs[split]
131
+ def __len__(self):
132
+ return len(self.data)
133
+ def __getitem__(self, ind):
134
+ try:
135
+ segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
136
+ except:
137
+ return None, None, None, None, None, None
138
+
139
+ return segment_id, audio, sr, text, begin_time, end_time
140
+ def collate(self, batch):
141
+ res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
142
+ for item in batch:
143
+ if item[0] != None:
144
+ res['segment_id'].append(item[0])
145
+ res['audio'].append(item[1])
146
+ res['sr'].append(item[2])
147
+ res['text'].append(item[3])
148
+ res['begin_time'].append(item[4])
149
+ res['end_time'].append(item[5])
150
+ return res
151
+
152
+
153
+ ## encodec codes extraction
154
+ logging.info("encodec encoding...")
155
+ train_dataset = mydataset('train')
156
+ train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
157
+ validation_dataset = mydataset('validation')
158
+ validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
159
+ test_dataset = mydataset('test')
160
+ test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
161
+ splits = ['validation', 'test', 'train']
162
+ loaders = [validation_loader, test_loader, train_loader]
163
+ # splits = ['validation'] # for debug
164
+ # loaders = [validation_loader]
165
+ for split, loader in zip(splits, loaders):
166
+ skip = 0
167
+ logging.info(f"now processing split {split}...")
168
+ mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
169
+ logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
170
+ for m, mega_batch in enumerate(loader):
171
+ logging.info(f"====================================")
172
+ logging.info(f"====================================")
173
+ logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
174
+ lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
175
+ sorted_inds = sort_by_audio_len(lengths)
176
+ for j in range(len(sorted_inds))[::-1]:
177
+ if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
178
+ skip += 1
179
+ del sorted_inds[j]
180
+
181
+ n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
182
+ for n in tqdm.tqdm(range(n_steps), disable=True):
183
+ inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
184
+ audio_batch = [mega_batch['audio'][id] for id in inds_used]
185
+ sr_batch = [mega_batch['sr'][id] for id in inds_used]
186
+ segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
187
+ text_batch = [mega_batch['text'][id] for id in inds_used]
188
+ padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
189
+ all_lens = [lengths[id] for id in inds_used]
190
+ with torch.no_grad():
191
+ if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes
192
+ codes = []
193
+ inwav = padded_wav.cuda()
194
+ codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
195
+ codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
196
+ codes = torch.cat(codes, dim=0)
197
+ else:
198
+ encoded_frames = model.encode(padded_wav.cuda())
199
+ # logging.info(f"encoded_frames: {encoded_frames[0].shape}")
200
+ codes = encoded_frames[0].cpu()
201
+
202
+ for i, length in enumerate(all_lens):
203
+ save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
204
+ actual_len = round(length * args.model_code_sr) # 320 is downsample rate for this model
205
+ cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
206
+ write_array_to_txt_file(cur_code, save_fn)
lib/voicecraft/data/tokenizer.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ # from lhotse.features import FeatureExtractor
24
+ # from lhotse.utils import Seconds, compute_num_frames
25
+ from phonemizer.backend import EspeakBackend
26
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
27
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
28
+ from phonemizer.punctuation import Punctuation
29
+ from phonemizer.separator import Separator
30
+
31
+
32
+
33
+ class TextTokenizer:
34
+ """Phonemize Text."""
35
+
36
+ def __init__(
37
+ self,
38
+ language="en-us",
39
+ backend="espeak",
40
+ separator=Separator(word="_", syllable="-", phone="|"),
41
+ preserve_punctuation=True,
42
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
43
+ with_stress: bool = False,
44
+ tie: Union[bool, str] = False,
45
+ language_switch: LanguageSwitch = "keep-flags",
46
+ words_mismatch: WordMismatch = "ignore",
47
+ ) -> None:
48
+ phonemizer = EspeakBackend(
49
+ language,
50
+ punctuation_marks=punctuation_marks,
51
+ preserve_punctuation=preserve_punctuation,
52
+ with_stress=with_stress,
53
+ tie=tie,
54
+ language_switch=language_switch,
55
+ words_mismatch=words_mismatch,
56
+ )
57
+
58
+ self.backend = phonemizer
59
+ self.separator = separator
60
+
61
+ def to_list(self, phonemized: str) -> List[str]:
62
+ fields = []
63
+ for word in phonemized.split(self.separator.word):
64
+ # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
65
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
66
+ fields.extend(
67
+ [p for p in pp if p != self.separator.phone]
68
+ + [self.separator.word]
69
+ )
70
+ assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
71
+ self.separator.phone
72
+ )
73
+ return fields[:-1]
74
+
75
+ def __call__(self, text, strip=True) -> List[List[str]]:
76
+ if isinstance(text, str):
77
+ text = [text]
78
+
79
+ phonemized = self.backend.phonemize(
80
+ text, separator=self.separator, strip=strip, njobs=1
81
+ )
82
+ return [self.to_list(p) for p in phonemized]
83
+
84
+
85
+ def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
86
+ phonemes = tokenizer([text.strip()])
87
+ return phonemes[0] # k2symbols
88
+
89
+ def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
90
+ assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
91
+ if target_channels == 1:
92
+ wav = wav.mean(0, keepdim=True)
93
+ elif target_channels == 2:
94
+ *shape, _, length = wav.shape
95
+ wav = wav.expand(*shape, target_channels, length)
96
+ elif wav.shape[0] == 1:
97
+ wav = wav.expand(target_channels, -1)
98
+ wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
99
+ return wav
100
+
101
+ class AudioTokenizer:
102
+ """EnCodec audio."""
103
+
104
+ def __init__(
105
+ self,
106
+ device: Any = None,
107
+ signature = None
108
+ ) -> None:
109
+ from audiocraft.solvers import CompressionSolver
110
+ model = CompressionSolver.model_from_checkpoint(signature)
111
+ self.sample_rate = model.sample_rate
112
+ self.channels = model.channels
113
+
114
+ if not device:
115
+ device = torch.device("cpu")
116
+ if torch.cuda.is_available():
117
+ device = torch.device("cuda:0")
118
+
119
+ self._device = device
120
+
121
+ self.codec = model.to(device)
122
+
123
+ @property
124
+ def device(self):
125
+ return self._device
126
+
127
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
128
+ codes = self.codec.encode(wav.to(self.device))
129
+ return [(codes[0], None)]
130
+
131
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
132
+ frames = frames[0][0] # [1,4,T]
133
+ return self.codec.decode(frames)
134
+
135
+
136
+
137
+ def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
138
+ # Load and pre-process the audio waveform
139
+ if offset != -1 and num_frames!=-1:
140
+ wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
141
+ else:
142
+ wav, sr = torchaudio.load(audio_path)
143
+ wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
144
+ wav = wav.unsqueeze(0)
145
+
146
+ # Extract discrete codes from EnCodec
147
+ with torch.no_grad():
148
+ encoded_frames = tokenizer.encode(wav)
149
+ return encoded_frames
lib/voicecraft/demo/84_121550_000074_000000.wav ADDED
Binary file (508 kB). View file
 
lib/voicecraft/demo/temp/84_121550_000074_000000.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,
lib/voicecraft/demo/temp/mfa_alignments/84_121550_000074_000000.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Begin,End,Label,Type,Speaker
2
+ 0.03,0.18,but,words,temp
3
+ 0.18,0.32,when,words,temp
4
+ 0.32,0.48,i,words,temp
5
+ 0.48,0.64,had,words,temp
6
+ 0.64,1.19,approached,words,temp
7
+ 1.22,1.58,so,words,temp
8
+ 1.58,1.91,near,words,temp
9
+ 1.91,2.07,to,words,temp
10
+ 2.07,2.42,them,words,temp
11
+ 2.53,2.61,the,words,temp
12
+ 2.61,3.01,common,words,temp
13
+ 3.05,3.62,object,words,temp
14
+ 3.68,3.93,which,words,temp
15
+ 3.93,4.02,the,words,temp
16
+ 4.02,4.34,sense,words,temp
17
+ 4.34,4.97,deceives,words,temp
18
+ 5.04,5.54,lost,words,temp
19
+ 5.54,6.0,not,words,temp
20
+ 6.0,6.14,by,words,temp
21
+ 6.14,6.67,distance,words,temp
22
+ 6.79,7.05,any,words,temp
23
+ 7.05,7.18,of,words,temp
24
+ 7.18,7.34,its,words,temp
25
+ 7.34,7.87,marks,words,temp
26
+ 0.03,0.06,B,phones,temp
27
+ 0.06,0.09,AH1,phones,temp
28
+ 0.09,0.18,T,phones,temp
29
+ 0.18,0.23,W,phones,temp
30
+ 0.23,0.27,EH1,phones,temp
31
+ 0.27,0.32,N,phones,temp
32
+ 0.32,0.48,AY1,phones,temp
33
+ 0.48,0.49,HH,phones,temp
34
+ 0.49,0.6,AE1,phones,temp
35
+ 0.6,0.64,D,phones,temp
36
+ 0.64,0.7,AH0,phones,temp
37
+ 0.7,0.83,P,phones,temp
38
+ 0.83,0.88,R,phones,temp
39
+ 0.88,0.99,OW1,phones,temp
40
+ 0.99,1.12,CH,phones,temp
41
+ 1.12,1.19,T,phones,temp
42
+ 1.22,1.4,S,phones,temp
43
+ 1.4,1.58,OW1,phones,temp
44
+ 1.58,1.7,N,phones,temp
45
+ 1.7,1.84,IH1,phones,temp
46
+ 1.84,1.91,R,phones,temp
47
+ 1.91,2.01,T,phones,temp
48
+ 2.01,2.07,AH0,phones,temp
49
+ 2.07,2.13,DH,phones,temp
50
+ 2.13,2.3,EH1,phones,temp
51
+ 2.3,2.42,M,phones,temp
52
+ 2.53,2.55,DH,phones,temp
53
+ 2.55,2.61,AH0,phones,temp
54
+ 2.61,2.73,K,phones,temp
55
+ 2.73,2.85,AA1,phones,temp
56
+ 2.85,2.9,M,phones,temp
57
+ 2.9,2.95,AH0,phones,temp
58
+ 2.95,3.01,N,phones,temp
59
+ 3.05,3.22,AA1,phones,temp
60
+ 3.22,3.27,B,phones,temp
61
+ 3.27,3.34,JH,phones,temp
62
+ 3.34,3.48,EH0,phones,temp
63
+ 3.48,3.54,K,phones,temp
64
+ 3.54,3.62,T,phones,temp
65
+ 3.68,3.69,HH,phones,temp
66
+ 3.69,3.76,W,phones,temp
67
+ 3.76,3.8,IH1,phones,temp
68
+ 3.8,3.93,CH,phones,temp
69
+ 3.93,3.95,DH,phones,temp
70
+ 3.95,4.02,AH0,phones,temp
71
+ 4.02,4.12,S,phones,temp
72
+ 4.12,4.21,EH1,phones,temp
73
+ 4.21,4.27,N,phones,temp
74
+ 4.27,4.34,S,phones,temp
75
+ 4.34,4.42,D,phones,temp
76
+ 4.42,4.45,IH0,phones,temp
77
+ 4.45,4.59,S,phones,temp
78
+ 4.59,4.79,IY1,phones,temp
79
+ 4.79,4.87,V,phones,temp
80
+ 4.87,4.97,Z,phones,temp
81
+ 5.04,5.12,L,phones,temp
82
+ 5.12,5.33,AO1,phones,temp
83
+ 5.33,5.42,S,phones,temp
84
+ 5.42,5.54,T,phones,temp
85
+ 5.54,5.7,N,phones,temp
86
+ 5.7,5.89,AA1,phones,temp
87
+ 5.89,6.0,T,phones,temp
88
+ 6.0,6.05,B,phones,temp
89
+ 6.05,6.14,AY1,phones,temp
90
+ 6.14,6.24,D,phones,temp
91
+ 6.24,6.3,IH1,phones,temp
92
+ 6.3,6.38,S,phones,temp
93
+ 6.38,6.45,T,phones,temp
94
+ 6.45,6.51,AH0,phones,temp
95
+ 6.51,6.57,N,phones,temp
96
+ 6.57,6.67,S,phones,temp
97
+ 6.79,6.89,EH1,phones,temp
98
+ 6.89,6.95,N,phones,temp
99
+ 6.95,7.05,IY0,phones,temp
100
+ 7.05,7.13,AH0,phones,temp
101
+ 7.13,7.18,V,phones,temp
102
+ 7.18,7.22,IH0,phones,temp
103
+ 7.22,7.29,T,phones,temp
104
+ 7.29,7.34,S,phones,temp
105
+ 7.34,7.39,M,phones,temp
106
+ 7.39,7.5,AA1,phones,temp
107
+ 7.5,7.58,R,phones,temp
108
+ 7.58,7.7,K,phones,temp
109
+ 7.7,7.87,S,phones,temp
lib/voicecraft/edit_utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_span(orig, new, editType):
2
+ orig_list = orig.split(" ")
3
+ new_list = new.split(" ")
4
+
5
+ flag = False # this indicate whether the actual edit follow the specified editType
6
+ if editType == "deletion":
7
+ assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}"
8
+ diff = len(orig_list) - len(new_list)
9
+ for i, (o, n) in enumerate(zip(orig_list, new_list)):
10
+ if o != n: # assume the index of the first different word is the starting index of the orig_span
11
+
12
+ orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part
13
+ new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part
14
+ flag = True
15
+ break
16
+
17
+
18
+ elif editType == "insertion":
19
+ assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}"
20
+ diff = len(new_list) - len(orig_list)
21
+ for i, (o, n) in enumerate(zip(orig_list, new_list)):
22
+ if o != n: # insertion is just the opposite of deletion
23
+ new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same
24
+ orig_span = [i-1, i]
25
+ flag = True
26
+ break
27
+
28
+ elif editType == "substitution":
29
+ new_span = []
30
+ orig_span = []
31
+ for i, (o, n) in enumerate(zip(orig_list, new_list)):
32
+ if o != n:
33
+ new_span = [i]
34
+ orig_span = [i]
35
+ break
36
+ assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}"
37
+ for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])):
38
+ if o != n:
39
+ new_span.append(len(new_list) - j -1)
40
+ orig_span.append(len(orig_list) - j - 1)
41
+ flag = True
42
+ break
43
+ else:
44
+ raise RuntimeError(f"editType unknown: {editType}")
45
+
46
+ if not flag:
47
+ raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}")
48
+
49
+ return orig_span, new_span
lib/voicecraft/environment.yml ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: voicecraft
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - aom=3.8.2=h59595ed_0
9
+ - asttokens=2.4.1=pyhd8ed1ab_0
10
+ - atk-1.0=2.38.0=hd4edc92_1
11
+ - audioread=3.0.1=py39hf3d152e_1
12
+ - backcall=0.2.0=pyh9f0ad1d_0
13
+ - baumwelch=0.3.7=h00ab1b0_5
14
+ - biopython=1.79=py39hb9d737c_3
15
+ - brotli=1.1.0=hd590300_1
16
+ - brotli-bin=1.1.0=hd590300_1
17
+ - brotli-python=1.1.0=py39h3d6467e_1
18
+ - bzip2=1.0.8=hd590300_5
19
+ - ca-certificates=2024.2.2=hbcca054_0
20
+ - cairo=1.18.0=h3faef2a_0
21
+ - certifi=2024.2.2=pyhd8ed1ab_0
22
+ - cffi=1.16.0=py39h7a31438_0
23
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
24
+ - click=8.1.7=unix_pyh707e725_0
25
+ - colorama=0.4.6=pyhd8ed1ab_0
26
+ - comm=0.2.2=pyhd8ed1ab_0
27
+ - contourpy=1.2.0=py39h7633fee_0
28
+ - cycler=0.12.1=pyhd8ed1ab_0
29
+ - dataclassy=1.0.1=pyhd8ed1ab_0
30
+ - dav1d=1.2.1=hd590300_0
31
+ - debugpy=1.8.1=py39h3d6467e_0
32
+ - decorator=5.1.1=pyhd8ed1ab_0
33
+ - executing=2.0.1=pyhd8ed1ab_0
34
+ - expat=2.6.2=h59595ed_0
35
+ - ffmpeg=6.1.1=gpl_h38e077a_106
36
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
37
+ - font-ttf-inconsolata=3.000=h77eed37_0
38
+ - font-ttf-source-code-pro=2.038=h77eed37_0
39
+ - font-ttf-ubuntu=0.83=h77eed37_1
40
+ - fontconfig=2.14.2=h14ed4e7_0
41
+ - fonts-conda-ecosystem=1=0
42
+ - fonts-conda-forge=1=0
43
+ - fonttools=4.49.0=py39hd1e30aa_0
44
+ - freetype=2.12.1=h267a509_2
45
+ - fribidi=1.0.10=h36c2ea0_0
46
+ - gdk-pixbuf=2.42.10=h829c605_5
47
+ - gettext=0.21.1=h27087fc_0
48
+ - giflib=5.2.1=h0b41bf4_3
49
+ - gmp=6.3.0=h59595ed_1
50
+ - gnutls=3.7.9=hb077bed_0
51
+ - graphite2=1.3.13=h58526e2_1001
52
+ - graphviz=9.0.0=h78e8752_1
53
+ - greenlet=3.0.3=py39h3d6467e_0
54
+ - gtk2=2.24.33=h280cfa0_4
55
+ - gts=0.7.6=h977cf35_4
56
+ - harfbuzz=8.3.0=h3d44ed6_0
57
+ - hdbscan=0.8.33=py39h44dd56e_4
58
+ - icu=73.2=h59595ed_0
59
+ - idna=3.6=pyhd8ed1ab_0
60
+ - importlib-metadata=7.0.2=pyha770c72_0
61
+ - importlib-resources=6.3.0=pyhd8ed1ab_0
62
+ - importlib_metadata=7.0.2=hd8ed1ab_0
63
+ - importlib_resources=6.3.0=pyhd8ed1ab_0
64
+ - ipykernel=6.29.3=pyhd33586a_0
65
+ - jedi=0.19.1=pyhd8ed1ab_0
66
+ - joblib=1.3.2=pyhd8ed1ab_0
67
+ - jupyter_client=8.6.1=pyhd8ed1ab_0
68
+ - jupyter_core=5.7.2=py39hf3d152e_0
69
+ - kaldi=5.5.1068=cpu_h31769b2_2
70
+ - keyutils=1.6.1=h166bdaf_0
71
+ - kiwisolver=1.4.5=py39h7633fee_1
72
+ - kneed=0.8.5=pyhd8ed1ab_0
73
+ - krb5=1.21.2=h659d440_0
74
+ - lame=3.100=h166bdaf_1003
75
+ - lazy_loader=0.3=pyhd8ed1ab_0
76
+ - lcms2=2.16=hb7c19ff_0
77
+ - ld_impl_linux-64=2.40=h41732ed_0
78
+ - lerc=4.0.0=h27087fc_0
79
+ - libabseil=20240116.1=cxx17_h59595ed_2
80
+ - libass=0.17.1=h8fe9dca_1
81
+ - libblas=3.9.0=21_linux64_openblas
82
+ - libbrotlicommon=1.1.0=hd590300_1
83
+ - libbrotlidec=1.1.0=hd590300_1
84
+ - libbrotlienc=1.1.0=hd590300_1
85
+ - libcblas=3.9.0=21_linux64_openblas
86
+ - libclang-cpp15=15.0.7=default_hb11cfb5_4
87
+ - libdeflate=1.19=hd590300_0
88
+ - libdrm=2.4.120=hd590300_0
89
+ - libedit=3.1.20191231=he28a2e2_2
90
+ - libexpat=2.6.2=h59595ed_0
91
+ - libffi=3.4.2=h7f98852_5
92
+ - libflac=1.4.3=h59595ed_0
93
+ - libgcc-ng=13.2.0=h807b86a_5
94
+ - libgd=2.3.3=h119a65a_9
95
+ - libgfortran-ng=13.2.0=h69a702a_5
96
+ - libgfortran5=13.2.0=ha4646dd_5
97
+ - libglib=2.80.0=hf2295e7_0
98
+ - libgomp=13.2.0=h807b86a_5
99
+ - libhwloc=2.9.3=default_h554bfaf_1009
100
+ - libiconv=1.17=hd590300_2
101
+ - libidn2=2.3.7=hd590300_0
102
+ - libjpeg-turbo=3.0.0=hd590300_1
103
+ - liblapack=3.9.0=21_linux64_openblas
104
+ - liblapacke=3.9.0=21_linux64_openblas
105
+ - libllvm14=14.0.6=hcd5def8_4
106
+ - libllvm15=15.0.7=hb3ce162_4
107
+ - libllvmspirv15=15.0.0=h0cdce71_1
108
+ - libnsl=2.0.1=hd590300_0
109
+ - libogg=1.3.4=h7f98852_1
110
+ - libopenblas=0.3.26=pthreads_h413a1c8_0
111
+ - libopenvino=2024.0.0=h2e90f83_1
112
+ - libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1
113
+ - libopenvino-auto-plugin=2024.0.0=hd5fc58b_1
114
+ - libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1
115
+ - libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1
116
+ - libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1
117
+ - libopenvino-ir-frontend=2024.0.0=h3ecfda7_1
118
+ - libopenvino-onnx-frontend=2024.0.0=h757c851_1
119
+ - libopenvino-paddle-frontend=2024.0.0=h757c851_1
120
+ - libopenvino-pytorch-frontend=2024.0.0=h59595ed_1
121
+ - libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1
122
+ - libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1
123
+ - libopus=1.3.1=h7f98852_1
124
+ - libpciaccess=0.18=hd590300_0
125
+ - libpng=1.6.43=h2797004_0
126
+ - libpq=16.2=h33b98f1_0
127
+ - libprotobuf=4.25.3=h08a7969_0
128
+ - librosa=0.10.1=pyhd8ed1ab_0
129
+ - librsvg=2.56.3=he3f83f7_1
130
+ - libsndfile=1.2.2=hc60ed4a_1
131
+ - libsodium=1.0.18=h36c2ea0_1
132
+ - libsqlite=3.45.2=h2797004_0
133
+ - libstdcxx-ng=13.2.0=h7e041cc_5
134
+ - libtasn1=4.19.0=h166bdaf_0
135
+ - libtiff=4.6.0=ha9c0a0a_2
136
+ - libunistring=0.9.10=h7f98852_0
137
+ - libuuid=2.38.1=h0b41bf4_0
138
+ - libva=2.21.0=hd590300_0
139
+ - libvorbis=1.3.7=h9c3ff4c_0
140
+ - libvpx=1.14.0=h59595ed_0
141
+ - libwebp=1.3.2=h658648e_1
142
+ - libwebp-base=1.3.2=hd590300_0
143
+ - libxcb=1.15=h0b41bf4_0
144
+ - libxcrypt=4.4.36=hd590300_1
145
+ - libxml2=2.12.5=h232c23b_0
146
+ - libzlib=1.2.13=hd590300_5
147
+ - llvm-spirv-15=15.0.0=h0cdce71_1
148
+ - mad=0.15.1b=h9c3ff4c_1
149
+ - markdown-it-py=3.0.0=pyhd8ed1ab_0
150
+ - matplotlib-base=3.8.3=py39he9076e7_0
151
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
152
+ - mdurl=0.1.2=pyhd8ed1ab_0
153
+ - montreal-forced-aligner=2.2.17=pyhd8ed1ab_0
154
+ - mpg123=1.32.4=h59595ed_0
155
+ - msgpack-python=1.0.7=py39h7633fee_0
156
+ - munkres=1.1.4=pyh9f0ad1d_0
157
+ - ncurses=6.4=h59595ed_2
158
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
159
+ - nettle=3.9.1=h7ab15ed_0
160
+ - ngram=1.3.14=h924138e_2
161
+ - numba=0.59.0=py39h615d6bd_1
162
+ - numpy=1.26.4=py39h474f0d3_0
163
+ - ocl-icd=2.3.2=hd590300_0
164
+ - openfst=1.8.2=h924138e_2
165
+ - openh264=2.4.1=h59595ed_0
166
+ - openjpeg=2.5.2=h488ebb8_0
167
+ - openssl=3.2.1=hd590300_0
168
+ - p11-kit=0.24.1=hc5aa10d_0
169
+ - packaging=24.0=pyhd8ed1ab_0
170
+ - pandas=2.2.1=py39hddac248_0
171
+ - pango=1.52.1=ha41ecd1_0
172
+ - parso=0.8.3=pyhd8ed1ab_0
173
+ - patsy=0.5.6=pyhd8ed1ab_0
174
+ - pcre2=10.43=hcad00b1_0
175
+ - pexpect=4.9.0=pyhd8ed1ab_0
176
+ - pgvector-python=0.2.5=pyhe093146_0
177
+ - pickleshare=0.7.5=py_1003
178
+ - pillow=10.2.0=py39had0adad_0
179
+ - pip=24.0=pyhd8ed1ab_0
180
+ - pixman=0.43.2=h59595ed_0
181
+ - platformdirs=4.2.0=pyhd8ed1ab_0
182
+ - pocl=5.0=h03a6ac1_2
183
+ - pocl-core=5.0=hdaecddf_2
184
+ - pocl-cpu=5.0=he901f76_2
185
+ - pocl-cpu-minimal=5.0=h5ccd973_2
186
+ - pocl-cuda=5.0=hdaecddf_2
187
+ - pocl-remote=5.0=h5ccd973_2
188
+ - pooch=1.8.1=pyhd8ed1ab_0
189
+ - postgresql=16.2=h7387d8b_0
190
+ - prompt-toolkit=3.0.42=pyha770c72_0
191
+ - prompt_toolkit=3.0.42=hd8ed1ab_0
192
+ - psutil=5.9.8=py39hd1e30aa_0
193
+ - psycopg2=2.9.9=py39h89197e3_0
194
+ - pthread-stubs=0.4=h36c2ea0_1001
195
+ - ptyprocess=0.7.0=pyhd3deb0d_0
196
+ - pugixml=1.14=h59595ed_0
197
+ - pure_eval=0.2.2=pyhd8ed1ab_0
198
+ - pycparser=2.21=pyhd8ed1ab_0
199
+ - pygments=2.17.2=pyhd8ed1ab_0
200
+ - pyparsing=3.1.2=pyhd8ed1ab_0
201
+ - pysocks=1.7.1=pyha2e5f31_6
202
+ - pysoundfile=0.12.1=pypyhd8ed1ab_1
203
+ - python=3.9.18=h0755675_1_cpython
204
+ - python-tzdata=2024.1=pyhd8ed1ab_0
205
+ - python_abi=3.9=4_cp39
206
+ - pytz=2024.1=pyhd8ed1ab_0
207
+ - pyyaml=6.0.1=py39hd1e30aa_1
208
+ - pyzmq=25.1.2=py39h8c080ef_0
209
+ - readline=8.2=h8228510_1
210
+ - requests=2.31.0=pyhd8ed1ab_0
211
+ - rich=13.7.1=pyhd8ed1ab_0
212
+ - rich-click=1.7.4=pyhd8ed1ab_0
213
+ - scikit-learn=1.2.2=py39hc236052_2
214
+ - scipy=1.12.0=py39h474f0d3_2
215
+ - seaborn=0.13.2=hd8ed1ab_0
216
+ - seaborn-base=0.13.2=pyhd8ed1ab_0
217
+ - setuptools=69.2.0=pyhd8ed1ab_0
218
+ - six=1.16.0=pyh6c4a22f_0
219
+ - snappy=1.1.10=h9fff704_0
220
+ - sox=14.4.2=ha5cc309_1018
221
+ - soxr=0.1.3=h0b41bf4_3
222
+ - soxr-python=0.3.7=py39h44dd56e_0
223
+ - sqlalchemy=2.0.28=py39hd1e30aa_0
224
+ - sqlite=3.45.2=h2c6b66d_0
225
+ - stack_data=0.6.2=pyhd8ed1ab_0
226
+ - statsmodels=0.14.1=py39h44dd56e_0
227
+ - svt-av1=1.8.0=h59595ed_0
228
+ - tbb=2021.11.0=h00ab1b0_1
229
+ - threadpoolctl=3.3.0=pyhc1e730c_0
230
+ - tk=8.6.13=noxft_h4845f30_101
231
+ - tornado=6.4=py39hd1e30aa_0
232
+ - tqdm=4.66.2=pyhd8ed1ab_0
233
+ - traitlets=5.14.2=pyhd8ed1ab_0
234
+ - typing-extensions=4.10.0=hd8ed1ab_0
235
+ - typing_extensions=4.10.0=pyha770c72_0
236
+ - tzcode=2024a=h3f72095_0
237
+ - tzdata=2024a=h0c530f3_0
238
+ - unicodedata2=15.1.0=py39hd1e30aa_0
239
+ - urllib3=2.2.1=pyhd8ed1ab_0
240
+ - wcwidth=0.2.13=pyhd8ed1ab_0
241
+ - wheel=0.42.0=pyhd8ed1ab_0
242
+ - x264=1!164.3095=h166bdaf_2
243
+ - x265=3.5=h924138e_3
244
+ - xorg-fixesproto=5.0=h7f98852_1002
245
+ - xorg-kbproto=1.0.7=h7f98852_1002
246
+ - xorg-libice=1.1.1=hd590300_0
247
+ - xorg-libsm=1.2.4=h7391055_0
248
+ - xorg-libx11=1.8.7=h8ee46fc_0
249
+ - xorg-libxau=1.0.11=hd590300_0
250
+ - xorg-libxdmcp=1.1.3=h7f98852_0
251
+ - xorg-libxext=1.3.4=h0b41bf4_2
252
+ - xorg-libxfixes=5.0.3=h7f98852_1004
253
+ - xorg-libxrender=0.9.11=hd590300_0
254
+ - xorg-renderproto=0.11.1=h7f98852_1002
255
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
256
+ - xorg-xproto=7.0.31=h7f98852_1007
257
+ - xz=5.2.6=h166bdaf_0
258
+ - yaml=0.2.5=h7f98852_2
259
+ - zeromq=4.3.5=h59595ed_1
260
+ - zipp=3.17.0=pyhd8ed1ab_0
261
+ - zlib=1.2.13=hd590300_5
262
+ - zstd=1.5.5=hfc55251_0
263
+ - pip:
264
+ - absl-py==2.1.0
265
+ - aiofiles==23.2.1
266
+ - aiohttp==3.9.3
267
+ - aiosignal==1.3.1
268
+ - altair==5.2.0
269
+ - antlr4-python3-runtime==4.9.3
270
+ - anyio==4.3.0
271
+ - async-timeout==4.0.3
272
+ - attrs==23.2.0
273
+ - av==11.0.0
274
+ - babel==2.14.0
275
+ - beautifulsoup4==4.12.3
276
+ - bibtexparser==2.0.0b7
277
+ - bleach==6.1.0
278
+ - blis==0.7.11
279
+ - catalogue==2.0.10
280
+ - clldutils==3.22.2
281
+ - cloudpickle==3.0.0
282
+ - cmake==3.28.3
283
+ - colorlog==6.8.2
284
+ - confection==0.1.4
285
+ - csvw==3.3.0
286
+ - cymem==2.0.8
287
+ - cython==0.29.37
288
+ - datasets==2.16.0
289
+ - defusedxml==0.7.1
290
+ - demucs==4.0.1
291
+ - dill==0.3.6
292
+ - dlinfo==1.2.1
293
+ - docopt==0.6.2
294
+ - dora-search==0.1.12
295
+ - einops==0.7.0
296
+ - encodec==0.1.1
297
+ - exceptiongroup==1.2.0
298
+ - fastapi==0.110.0
299
+ - fastjsonschema==2.19.1
300
+ - ffmpy==0.3.2
301
+ - filelock==3.13.1
302
+ - flashy==0.0.2
303
+ - frozenlist==1.4.1
304
+ - fsspec==2023.10.0
305
+ - gradio==3.50.2
306
+ - gradio-client==0.6.1
307
+ - grpcio==1.62.1
308
+ - h11==0.14.0
309
+ - httpcore==1.0.4
310
+ - httpx==0.27.0
311
+ - huggingface-hub==0.21.4
312
+ - hydra-colorlog==1.2.0
313
+ - hydra-core==1.3.2
314
+ - ipython==8.12.3
315
+ - isodate==0.6.1
316
+ - jinja2==3.1.3
317
+ - jsonschema==4.21.1
318
+ - jsonschema-specifications==2023.12.1
319
+ - julius==0.2.7
320
+ - jupyterlab-pygments==0.3.0
321
+ - lameenc==1.7.0
322
+ - langcodes==3.3.0
323
+ - language-tags==1.2.0
324
+ - lit==18.1.1
325
+ - llvmlite==0.42.0
326
+ - lxml==5.1.0
327
+ - markdown==3.5.2
328
+ - markupsafe==2.1.5
329
+ - mistune==3.0.2
330
+ - mpmath==1.3.0
331
+ - msgpack==1.0.8
332
+ - multidict==6.0.5
333
+ - multiprocess==0.70.14
334
+ - murmurhash==1.0.10
335
+ - nbclient==0.10.0
336
+ - nbconvert==7.16.3
337
+ - nbformat==5.10.3
338
+ - networkx==3.2.1
339
+ - num2words==0.5.13
340
+ - nvidia-cublas-cu11==11.10.3.66
341
+ - nvidia-cuda-cupti-cu11==11.7.101
342
+ - nvidia-cuda-nvrtc-cu11==11.7.99
343
+ - nvidia-cuda-runtime-cu11==11.7.99
344
+ - nvidia-cudnn-cu11==8.5.0.96
345
+ - nvidia-cufft-cu11==10.9.0.58
346
+ - nvidia-curand-cu11==10.2.10.91
347
+ - nvidia-cusolver-cu11==11.4.0.1
348
+ - nvidia-cusparse-cu11==11.7.4.91
349
+ - nvidia-nccl-cu11==2.14.3
350
+ - nvidia-nvtx-cu11==11.7.91
351
+ - omegaconf==2.3.0
352
+ - openunmix==1.2.1
353
+ - orjson==3.9.15
354
+ - pandocfilters==1.5.1
355
+ - pathlib-abc==0.1.1
356
+ - pathy==0.11.0
357
+ - pgvector==0.2.2
358
+ - phonemizer==3.2.1
359
+ - pipreqs==0.5.0
360
+ - praatio==6.2.0
361
+ - preshed==3.0.9
362
+ - protobuf==4.25.3
363
+ - pyarrow==15.0.2
364
+ - pyarrow-hotfix==0.6
365
+ - pydantic==1.10.14
366
+ - pydub==0.25.1
367
+ - pylatexenc==2.10
368
+ - pynini==2.1.6
369
+ - pypinyin==0.48.0
370
+ - python-dateutil==2.9.0.post0
371
+ - python-multipart==0.0.9
372
+ - rdflib==7.0.0
373
+ - referencing==0.33.0
374
+ - regex==2023.12.25
375
+ - responses==0.18.0
376
+ - retrying==1.3.4
377
+ - rfc3986==1.5.0
378
+ - rpds-py==0.18.0
379
+ - safetensors==0.4.2
380
+ - segments==2.2.1
381
+ - semantic-version==2.10.0
382
+ - sentencepiece==0.2.0
383
+ - smart-open==6.4.0
384
+ - sniffio==1.3.1
385
+ - soupsieve==2.5
386
+ - spacy==3.5.2
387
+ - spacy-legacy==3.0.12
388
+ - spacy-loggers==1.0.5
389
+ - srsly==2.4.8
390
+ - starlette==0.36.3
391
+ - submitit==1.5.1
392
+ - sympy==1.12
393
+ - tabulate==0.9.0
394
+ - tensorboard==2.16.2
395
+ - tensorboard-data-server==0.7.2
396
+ - thinc==8.1.12
397
+ - tinycss2==1.2.1
398
+ - tokenizers==0.15.2
399
+ - toolz==0.12.1
400
+ - torch==2.0.1
401
+ - torchaudio==2.0.2
402
+ - torchmetrics==0.11.1
403
+ - transformers==4.38.2
404
+ - treetable==0.2.5
405
+ - triton==2.0.0
406
+ - typer==0.7.0
407
+ - uritemplate==4.1.1
408
+ - uvicorn==0.28.0
409
+ - wasabi==1.1.2
410
+ - webencodings==0.5.1
411
+ - websockets==11.0.3
412
+ - werkzeug==3.0.1
413
+ - xformers==0.0.22
414
+ - xxhash==3.4.1
415
+ - yarg==0.1.9
416
+ - yarl==1.9.4
417
+ prefix: /home/pyp/miniconda3/envs/voicecraft
lib/voicecraft/inference_speech_editing.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
lib/voicecraft/inference_speech_editing_scale.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, pickle
2
+ import logging
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from data.tokenizer import (
9
+ AudioTokenizer,
10
+ TextTokenizer,
11
+ tokenize_audio,
12
+ tokenize_text
13
+ )
14
+
15
+ from models import voicecraft
16
+ import argparse, time, tqdm
17
+
18
+ # this script only works for the musicgen architecture
19
+ def get_args():
20
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21
+ parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
22
+ parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
23
+ parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
24
+ parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary")
25
+ parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary")
26
+ parser.add_argument("--seed", type=int, default=1)
27
+ parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
28
+ parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
29
+ parser.add_argument("--top_k", type=int, default=-1, help="sampling param")
30
+ parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
31
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
32
+ parser.add_argument("--output_dir", type=str, default=None)
33
+ parser.add_argument("--device", type=str, default="cuda")
34
+ parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
35
+ parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
36
+ parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
37
+ parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
38
+ return parser.parse_args()
39
+
40
+ @torch.no_grad()
41
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config):
42
+ # phonemize
43
+ text_tokens = [phn2num[phn] for phn in
44
+ tokenize_text(
45
+ text_tokenizer, text=target_text.strip()
46
+ ) if phn in phn2num
47
+ ]
48
+ text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
49
+ text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
50
+
51
+ encoded_frames = tokenize_audio(audio_tokenizer, audio_fn)
52
+ original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
53
+ assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
54
+ logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
55
+
56
+ # forward
57
+ stime = time.time()
58
+ encoded_frames = model.inference(
59
+ text_tokens.to(device),
60
+ text_tokens_lens.to(device),
61
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
62
+ mask_interval=mask_interval.unsqueeze(0).to(device),
63
+ top_k=decode_config['top_k'],
64
+ top_p=decode_config['top_p'],
65
+ temperature=decode_config['temperature'],
66
+ stop_repetition=decode_config['stop_repetition'],
67
+ kvcache=decode_config['kvcache'],
68
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'],
69
+ ) # output is [1,K,T]
70
+ logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
71
+ if type(encoded_frames) == tuple:
72
+ encoded_frames = encoded_frames[0]
73
+ logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")
74
+
75
+
76
+ # decode (both original and generated)
77
+ original_sample = audio_tokenizer.decode(
78
+ [(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T]
79
+ )
80
+ generated_sample = audio_tokenizer.decode(
81
+ [(encoded_frames, None)]
82
+ )
83
+
84
+ return original_sample, generated_sample
85
+
86
+ def get_model(exp_dir, device=None):
87
+ with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
88
+ model_args = pickle.load(f)
89
+
90
+ logging.info("load model weights...")
91
+ model = voicecraft.VoiceCraft(model_args)
92
+ ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
93
+ ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
94
+ phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
95
+ model.load_state_dict(ckpt)
96
+ del ckpt
97
+ logging.info("done loading weights...")
98
+ if device == None:
99
+ device = torch.device("cpu")
100
+ if torch.cuda.is_available():
101
+ device = torch.device("cuda:0")
102
+ model.to(device)
103
+ model.eval()
104
+ return model, model_args, phn2num
105
+
106
+
107
+ def get_mask_interval(ali_fn, word_span_ind, editType):
108
+ with open(ali_fn, "r") as rf:
109
+ data = [l.strip().split(",") for l in rf.readlines()]
110
+ data = data[1:]
111
+ tmp = word_span_ind.split(",")
112
+ s, e = int(tmp[0]), int(tmp[-1])
113
+ start = None
114
+ for j, item in enumerate(data):
115
+ if j == s and item[3] == "words":
116
+ if editType == 'insertion':
117
+ start = float(item[1])
118
+ else:
119
+ start = float(item[0])
120
+ if j == e and item[3] == "words":
121
+ if editType == 'insertion':
122
+ end = float(item[0])
123
+ else:
124
+ end = float(item[1])
125
+ assert start != None
126
+ break
127
+ return (start, end)
128
+
129
+ if __name__ == "__main__":
130
+ def seed_everything(seed):
131
+ os.environ['PYTHONHASHSEED'] = str(seed)
132
+ random.seed(seed)
133
+ np.random.seed(seed)
134
+ torch.manual_seed(seed)
135
+ torch.cuda.manual_seed(seed)
136
+ torch.backends.cudnn.benchmark = False
137
+ torch.backends.cudnn.deterministic = True
138
+ formatter = (
139
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
140
+ )
141
+ logging.basicConfig(format=formatter, level=logging.INFO)
142
+ args = get_args()
143
+ # args.device = 'cpu'
144
+ args.allowed_repeat_tokens = eval(args.allowed_repeat_tokens)
145
+ seed_everything(args.seed)
146
+
147
+ # load model
148
+ stime = time.time()
149
+ logging.info(f"loading model from {args.exp_dir}")
150
+ model, model_args, phn2num = get_model(args.exp_dir)
151
+ if not os.path.isfile(model_args.exp_dir):
152
+ model_args.exp_dir = args.exp_dir
153
+ logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
154
+
155
+ # setup text and audio tokenizer
156
+ text_tokenizer = TextTokenizer(backend="espeak")
157
+ audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
158
+
159
+ with open(args.manifest_fn, "r") as rf:
160
+ manifest = [l.strip().split("\t") for l in rf.readlines()]
161
+ manifest = manifest[1:]
162
+
163
+ # wav_fn txt_fn alingment_fn num_words word_span_ind
164
+ audio_fns = []
165
+ target_texts = []
166
+ mask_intervals = []
167
+ edit_types = []
168
+ new_spans = []
169
+ orig_spans = []
170
+ os.makedirs(args.output_dir, exist_ok=True)
171
+ if args.crop_concat:
172
+ mfa_temp = f"{args.output_dir}/mfa_temp"
173
+ os.makedirs(mfa_temp, exist_ok=True)
174
+ for item in manifest:
175
+ audio_fn = os.path.join(args.audio_root, item[0])
176
+ temp = torchaudio.info(audio_fn)
177
+ audio_dur = temp.num_frames/temp.sample_rate
178
+ audio_fns.append(audio_fn)
179
+ target_text = item[2].split("|")[-1]
180
+ edit_types.append(item[5].split("|"))
181
+ new_spans.append(item[4].split("|"))
182
+ orig_spans.append(item[3].split("|"))
183
+ target_texts.append(target_text) # the last transcript is the target
184
+ # mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin
185
+ mis = []
186
+ all_ind_intervals = item[3].split("|")
187
+ editTypes = item[5].split("|")
188
+ smaller_indx = []
189
+ alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv"))
190
+ if not os.path.isfile(alignment_fn):
191
+ alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/")
192
+ assert os.path.isfile(alignment_fn), alignment_fn
193
+ for ind_inter,editType in zip(all_ind_intervals, editTypes):
194
+ # print(ind_inter)
195
+ mi = get_mask_interval(alignment_fn, ind_inter, editType)
196
+ mi = (max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds
197
+ mis.append(mi)
198
+ smaller_indx.append(mi[0])
199
+ ind = np.argsort(smaller_indx)
200
+ mis = [mis[id] for id in ind]
201
+ mask_intervals.append(mis)
202
+
203
+
204
+
205
+ for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))):
206
+ orig_mask_interval = mask_interval
207
+ mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval]
208
+ # logging.info(f"i: {i}, mask_interval: {mask_interval}")
209
+ mask_interval = torch.LongTensor(mask_interval) # [M,2]
210
+ orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args))
211
+
212
+ # save segments for comparison
213
+ orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()
214
+ # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
215
+
216
+ save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav"
217
+
218
+ torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr)
219
+
220
+ save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav"
221
+ if not os.path.isfile(save_fn_orig):
222
+ orig_audio, orig_sr = torchaudio.load(audio_fn)
223
+ if orig_sr != args.codec_audio_sr:
224
+ orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio)
225
+ torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr)
226
+
lib/voicecraft/inference_tts.ipynb ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "VoiceCraft Inference Text To Speech Demo\n",
8
+ "===\n",
9
+ "This will install a ton of dependencies all over so consider using the provided docker container start-jupyter script to keep the cruft off your dev box.\n",
10
+ "\n",
11
+ "Run the next cells one at a time up until the *STOP* and follow those instructions before continuing. You only have to do this the first time to setup the container."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Only do the below if you are using docker"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "# install OS deps\n",
28
+ "!sudo apt-get update && sudo apt-get install -y \\\n",
29
+ " git-core \\\n",
30
+ " ffmpeg \\\n",
31
+ " espeak-ng"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "# Update and setup Conda voicecraft environment\n",
41
+ "!conda update -y -n base -c conda-forge conda\n",
42
+ "!conda create -y -n voicecraft python=3.9.16 && \\\n",
43
+ " conda init bash"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "# install conda and pip stuff in the activated conda above context\n",
53
+ "!echo -e \"Grab a cup a coffee and a slice of pizza...\\n\\n\"\n",
54
+ "\n",
55
+ "# make sure $HOME and $USER are setup so this will source the conda environment\n",
56
+ "!source ~/.bashrc && \\\n",
57
+ " conda activate voicecraft && \\\n",
58
+ " conda install -y -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 && \\\n",
59
+ " pip install torch==2.0.1 && \\\n",
60
+ " pip install tensorboard==2.16.2 && \\\n",
61
+ " pip install phonemizer==3.2.1 && \\\n",
62
+ " pip install torchaudio==2.0.2 && \\\n",
63
+ " pip install datasets==2.16.0 && \\\n",
64
+ " pip install torchmetrics==0.11.1\n",
65
+ "\n",
66
+ "# do this one last otherwise you'll get an error about torch compiler missing due to xformer mismatch\n",
67
+ "!source ~/.bashrc && \\\n",
68
+ " conda activate voicecraft && \\\n",
69
+ " pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# okay setup the conda environment such that jupyter notebook can find the kernel\n",
79
+ "!source ~/.bashrc && \\\n",
80
+ " conda activate voicecraft && \\\n",
81
+ " conda install -y -n voicecraft ipykernel --update-deps --force-reinstall\n",
82
+ "\n",
83
+ "# installs the Jupyter kernel into /home/myusername/.local/share/jupyter/kernels/voicecraft\n",
84
+ "!source ~/.bashrc && \\\n",
85
+ " conda activate voicecraft && \\\n",
86
+ " python3 -m ipykernel install --user --name=voicecraft"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {},
92
+ "source": [
93
+ "# STOP\n",
94
+ "You have to do this part manually using the mouse/keyboard and the tabs at the top.\n",
95
+ "\n",
96
+ "* Refresh your browser to make sure it picks up the new kernel.\n",
97
+ "* Kernel -> Change Kernel -> Select Kernel -> voicecraft\n",
98
+ "* Kernel -> Restart Kernel -> Yes\n",
99
+ "\n",
100
+ "Now you can run the rest of the notebook and get an audio sample output. It will automatically download more models and such. The next time you use this container, you can just start below here as the dependencies will remain available until you delete the docker container."
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "metadata": {},
106
+ "source": [
107
+ "### Only do the above if you are using docker"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "# import libs\n",
117
+ "# if this throws an error, something went wrong installing dependencies or changing the kernel above!\n",
118
+ "import os\n",
119
+ "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n",
120
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
121
+ "os.environ[\"USER\"] = \"YOUR_USERNAME\" # TODO change this to your username\n",
122
+ "\n",
123
+ "import torch\n",
124
+ "import torchaudio\n",
125
+ "import numpy as np\n",
126
+ "import random\n",
127
+ "\n",
128
+ "from data.tokenizer import (\n",
129
+ " AudioTokenizer,\n",
130
+ " TextTokenizer,\n",
131
+ ")\n"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "# install MFA models and dictionaries if you haven't done so already\n",
141
+ "!source ~/.bashrc && \\\n",
142
+ " conda activate voicecraft && \\\n",
143
+ " mfa model download dictionary english_us_arpa && \\\n",
144
+ " mfa model download acoustic english_us_arpa"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "# load model, encodec, and phn2num\n",
154
+ "# # load model, tokenizer, and other necessary files\n",
155
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
156
+ "from models import voicecraft\n",
157
+ "#import models.voicecraft as voicecraft\n",
158
+ "voicecraft_name=\"giga830M.pth\" # or giga330M.pth\n",
159
+ "ckpt_fn =f\"./pretrained_models/{voicecraft_name}\"\n",
160
+ "encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n",
161
+ "if not os.path.exists(ckpt_fn):\n",
162
+ " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\\?download\\=true\")\n",
163
+ " os.system(f\"mv {voicecraft_name}\\?download\\=true ./pretrained_models/{voicecraft_name}\")\n",
164
+ "if not os.path.exists(encodec_fn):\n",
165
+ " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n",
166
+ " os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n",
167
+ "\n",
168
+ "ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n",
169
+ "model = voicecraft.VoiceCraft(ckpt[\"config\"])\n",
170
+ "model.load_state_dict(ckpt[\"model\"])\n",
171
+ "model.to(device)\n",
172
+ "model.eval()\n",
173
+ "\n",
174
+ "phn2num = ckpt['phn2num']\n",
175
+ "\n",
176
+ "text_tokenizer = TextTokenizer(backend=\"espeak\")\n",
177
+ "audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "# Prepare your audio\n",
187
+ "# point to the original audio whose speech you want to clone\n",
188
+ "# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file\n",
189
+ "orig_audio = \"./demo/84_121550_000074_000000.wav\"\n",
190
+ "orig_transcript = \"But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,\"\n",
191
+ "\n",
192
+ "# move the audio and transcript to temp folder\n",
193
+ "temp_folder = \"./demo/temp\"\n",
194
+ "os.makedirs(temp_folder, exist_ok=True)\n",
195
+ "os.system(f\"cp {orig_audio} {temp_folder}\")\n",
196
+ "filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n",
197
+ "with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n",
198
+ " f.write(orig_transcript)\n",
199
+ "# run MFA to get the alignment\n",
200
+ "align_temp = f\"{temp_folder}/mfa_alignments\"\n",
201
+ "!source ~/.bashrc && \\\n",
202
+ " conda activate voicecraft && \\\n",
203
+ " mfa align -v --clean -j 1 --output_format csv {temp_folder} \\\n",
204
+ " english_us_arpa english_us_arpa {align_temp}\n",
205
+ "\n",
206
+ "# # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue\n",
207
+ "# !source ~/.bashrc && \\\n",
208
+ "# conda activate voicecraft && \\\n",
209
+ "# mfa align -v --clean -j 1 --output_format csv {temp_folder} \\\n",
210
+ "# english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\n",
211
+ "\n"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "# take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt\n",
221
+ "cut_off_sec = 3.01 # NOTE: according to forced-alignment file demo/temp/mfa_alignments/84_121550_000074_000000.csv, the word \"common\" stop as 3.01 sec, this should be different for different audio\n",
222
+ "target_transcript = \"But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well!\"\n",
223
+ "# NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec.\n",
224
+ "audio_fn = f\"{temp_folder}/{filename}.wav\"\n",
225
+ "info = torchaudio.info(audio_fn)\n",
226
+ "audio_dur = info.num_frames / info.sample_rate\n",
227
+ "\n",
228
+ "assert cut_off_sec < audio_dur, f\"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}\"\n",
229
+ "prompt_end_frame = int(cut_off_sec * info.sample_rate)\n",
230
+ "\n",
231
+ "# run the model to get the output\n",
232
+ "# hyperparameters for inference\n",
233
+ "codec_audio_sr = 16000\n",
234
+ "codec_sr = 50\n",
235
+ "top_k = 0\n",
236
+ "top_p = 0.8\n",
237
+ "temperature = 1\n",
238
+ "silence_tokens=[1388,1898,131]\n",
239
+ "kvcache = 1 # NOTE if OOM, change this to 0, or try the 330M model\n",
240
+ "\n",
241
+ "# NOTE adjust the below three arguments if the generation is not as good\n",
242
+ "stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1\n",
243
+ "sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.\n",
244
+ "seed = 1 # change seed if you are still unhappy with the result\n",
245
+ "\n",
246
+ "def seed_everything(seed):\n",
247
+ " os.environ['PYTHONHASHSEED'] = str(seed)\n",
248
+ " random.seed(seed)\n",
249
+ " np.random.seed(seed)\n",
250
+ " torch.manual_seed(seed)\n",
251
+ " torch.cuda.manual_seed(seed)\n",
252
+ " torch.backends.cudnn.benchmark = False\n",
253
+ " torch.backends.cudnn.deterministic = True\n",
254
+ "seed_everything(seed)\n",
255
+ "\n",
256
+ "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n",
257
+ "from inference_tts_scale import inference_one_sample\n",
258
+ "concated_audio, gen_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n",
259
+ " \n",
260
+ "# save segments for comparison\n",
261
+ "concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n",
262
+ "# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n",
263
+ "\n",
264
+ "\n",
265
+ "# display the audio\n",
266
+ "from IPython.display import Audio\n",
267
+ "print(\"concatenate prompt and generated:\")\n",
268
+ "display(Audio(concated_audio, rate=codec_audio_sr))\n",
269
+ "\n",
270
+ "print(\"generated:\")\n",
271
+ "display(Audio(gen_audio, rate=codec_audio_sr))\n",
272
+ "\n",
273
+ "# # save the audio\n",
274
+ "# # output_dir\n",
275
+ "# output_dir = \"/home/pyp/VoiceCraft/demo/generated_tts\"\n",
276
+ "# os.makedirs(output_dir, exist_ok=True)\n",
277
+ "# seg_save_fn_gen = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav\"\n",
278
+ "# seg_save_fn_concat = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav\" \n",
279
+ "\n",
280
+ "# torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)\n",
281
+ "# torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)\n",
282
+ "\n",
283
+ "# if you get error importing T5 in transformers\n",
284
+ "# try \n",
285
+ "# pip uninstall Pillow\n",
286
+ "# pip install Pillow\n",
287
+ "# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored"
288
+ ]
289
+ }
290
+ ],
291
+ "metadata": {
292
+ "kernelspec": {
293
+ "display_name": "voicecraft",
294
+ "language": "python",
295
+ "name": "python3"
296
+ },
297
+ "language_info": {
298
+ "codemirror_mode": {
299
+ "name": "ipython",
300
+ "version": 3
301
+ },
302
+ "file_extension": ".py",
303
+ "mimetype": "text/x-python",
304
+ "name": "python",
305
+ "nbconvert_exporter": "python",
306
+ "pygments_lexer": "ipython3",
307
+ "version": "3.9.18"
308
+ }
309
+ },
310
+ "nbformat": 4,
311
+ "nbformat_minor": 4
312
+ }
lib/voicecraft/inference_tts_scale.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, pickle
2
+ import logging
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from data.tokenizer import (
9
+ AudioTokenizer,
10
+ TextTokenizer,
11
+ tokenize_audio,
12
+ tokenize_text
13
+ )
14
+
15
+ from models import voicecraft
16
+ import argparse, time, tqdm
17
+
18
+
19
+ # this script only works for the musicgen architecture
20
+ def get_args():
21
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22
+ parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
23
+ parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
24
+ parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
25
+ parser.add_argument("--seed", type=int, default=1)
26
+ parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
27
+ parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
28
+ parser.add_argument("--top_k", type=int, default=0, help="sampling param")
29
+ parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
30
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
31
+ parser.add_argument("--output_dir", type=str, default=None)
32
+ parser.add_argument("--device", type=str, default="cuda")
33
+ parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
34
+ parser.add_argument("--crop_concat", type=int, default=0)
35
+ parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
36
+ parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
37
+ parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation")
38
+ parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
39
+ return parser.parse_args()
40
+
41
+
42
+ @torch.no_grad()
43
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
44
+ # phonemize
45
+ text_tokens = [phn2num[phn] for phn in
46
+ tokenize_text(
47
+ text_tokenizer, text=target_text.strip()
48
+ ) if phn in phn2num
49
+ ]
50
+ text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
51
+ text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
52
+
53
+ # encode audio
54
+ encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
55
+ original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
56
+ assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
57
+ logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
58
+
59
+ # forward
60
+ stime = time.time()
61
+ if decode_config['sample_batch_size'] <= 1:
62
+ logging.info(f"running inference with batch size 1")
63
+ concat_frames, gen_frames = model.inference_tts(
64
+ text_tokens.to(device),
65
+ text_tokens_lens.to(device),
66
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
67
+ top_k=decode_config['top_k'],
68
+ top_p=decode_config['top_p'],
69
+ temperature=decode_config['temperature'],
70
+ stop_repetition=decode_config['stop_repetition'],
71
+ kvcache=decode_config['kvcache'],
72
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
73
+ ) # output is [1,K,T]
74
+ else:
75
+ logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
76
+ concat_frames, gen_frames = model.inference_tts_batch(
77
+ text_tokens.to(device),
78
+ text_tokens_lens.to(device),
79
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
80
+ top_k=decode_config['top_k'],
81
+ top_p=decode_config['top_p'],
82
+ temperature=decode_config['temperature'],
83
+ stop_repetition=decode_config['stop_repetition'],
84
+ kvcache=decode_config['kvcache'],
85
+ batch_size = decode_config['sample_batch_size'],
86
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
87
+ ) # output is [1,K,T]
88
+ logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
89
+
90
+ logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
91
+
92
+ # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
93
+ # logging.info(f"{timestamp}: {codes.tolist()}")
94
+ # decode (both original and generated)
95
+ concat_sample = audio_tokenizer.decode(
96
+ [(concat_frames, None)] # [1,T,8] -> [1,8,T]
97
+ )
98
+ gen_sample = audio_tokenizer.decode(
99
+ [(gen_frames, None)]
100
+ )
101
+
102
+ # return
103
+ return concat_sample, gen_sample
104
+
105
+ def get_model(exp_dir, device=None):
106
+ with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
107
+ model_args = pickle.load(f)
108
+
109
+ logging.info("load model weights...")
110
+ model = voicecraft.VoiceCraft(model_args)
111
+ ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
112
+ ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
113
+ phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
114
+ model.load_state_dict(ckpt)
115
+ del ckpt
116
+ logging.info("done loading weights...")
117
+ if device == None:
118
+ device = torch.device("cpu")
119
+ if torch.cuda.is_available():
120
+ device = torch.device("cuda:0")
121
+ model.to(device)
122
+ model.eval()
123
+ return model, model_args, phn2num
124
+
125
+ if __name__ == "__main__":
126
+ def seed_everything(seed):
127
+ os.environ['PYTHONHASHSEED'] = str(seed)
128
+ random.seed(seed)
129
+ np.random.seed(seed)
130
+ torch.manual_seed(seed)
131
+ torch.cuda.manual_seed(seed)
132
+ torch.backends.cudnn.benchmark = False
133
+ torch.backends.cudnn.deterministic = True
134
+ formatter = (
135
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
136
+ )
137
+ logging.basicConfig(format=formatter, level=logging.INFO)
138
+ args = get_args()
139
+ # args.device='cpu'
140
+ seed_everything(args.seed)
141
+
142
+ os.makedirs(args.output_dir, exist_ok=True)
143
+ # load model
144
+
145
+ with open(args.manifest_fn, "r") as rf:
146
+ manifest = [l.strip().split("\t") for l in rf.readlines()]
147
+ manifest = manifest[1:]
148
+ manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest]
149
+
150
+ stime = time.time()
151
+ logging.info(f"loading model from {args.exp_dir}")
152
+ model, model_args, phn2num = get_model(args.exp_dir)
153
+ logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
154
+
155
+ # setup text and audio tokenizer
156
+ text_tokenizer = TextTokenizer(backend="espeak")
157
+ audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
158
+
159
+ audio_fns = []
160
+ texts = []
161
+ prompt_end_frames = []
162
+ new_audio_fns = []
163
+ text_to_syn = []
164
+
165
+ for item in manifest:
166
+ audio_fn = os.path.join(args.audio_root, item[0])
167
+ audio_fns.append(audio_fn)
168
+ temp = torchaudio.info(audio_fn)
169
+ prompt_end_frames.append(round(float(item[2])*temp.sample_rate))
170
+ texts.append(item[1])
171
+ new_audio_fns.append(item[-2])
172
+ all_text = item[1].split(" ")
173
+ start_ind = int(item[-1].split(",")[0])
174
+ text_to_syn.append(" ".join(all_text[start_ind:]))
175
+
176
+ for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))):
177
+ output_expected_sr = args.codec_audio_sr
178
+ concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame)
179
+
180
+ # save segments for comparison
181
+ concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
182
+ if output_expected_sr != args.codec_audio_sr:
183
+ gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio)
184
+ concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio)
185
+
186
+ seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
187
+ seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
188
+
189
+ torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr)
190
+ torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)
lib/voicecraft/main.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import pickle
4
+ import argparse
5
+ import logging
6
+ import torch.distributed as dist
7
+ from config import MyParser
8
+ from steps import trainer
9
+
10
+
11
+ if __name__ == "__main__":
12
+ formatter = (
13
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
14
+ )
15
+ logging.basicConfig(format=formatter, level=logging.INFO)
16
+
17
+ torch.cuda.empty_cache()
18
+ args = MyParser().parse_args()
19
+ logging.info(args)
20
+ exp_dir = Path(args.exp_dir)
21
+ exp_dir.mkdir(exist_ok=True, parents=True)
22
+ logging.info(f"exp_dir: {str(exp_dir)}")
23
+
24
+ if args.resume:
25
+ resume = args.resume
26
+ assert(bool(args.exp_dir))
27
+ with open("%s/args.pkl" % args.exp_dir, "rb") as f:
28
+ old_args = pickle.load(f)
29
+ new_args = vars(args)
30
+ old_args = vars(old_args)
31
+ for key in new_args:
32
+ if key not in old_args or old_args[key] != new_args[key]:
33
+ old_args[key] = new_args[key]
34
+ args = argparse.Namespace(**old_args)
35
+ args.resume = resume
36
+ else:
37
+ with open("%s/args.pkl" % args.exp_dir, "wb") as f:
38
+ pickle.dump(args, f)
39
+
40
+ dist.init_process_group(backend='nccl', init_method='env://')
41
+ rank = dist.get_rank()
42
+ world_size = dist.get_world_size()
43
+ torch.cuda.set_device(rank)
44
+ my_trainer = trainer.Trainer(args, world_size, rank)
45
+ my_trainer.train()
lib/voicecraft/models/codebooks_patterns.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ import logging
11
+ import typing as tp
12
+
13
+ from abc import ABC, abstractmethod
14
+ import torch
15
+
16
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
17
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
18
+
19
+
20
+ @dataclass
21
+ class Pattern:
22
+ """Base implementation of a pattern over a sequence with multiple codebooks.
23
+
24
+ The codebook pattern consists in a layout, defining for each sequence step
25
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
26
+ The first item of the pattern is always an empty list in order to properly insert a special token
27
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
28
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
29
+
30
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
31
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
32
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
33
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
34
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
35
+ is returned along with a mask indicating valid tokens.
36
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
37
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
38
+ to fill and specify invalid positions if needed.
39
+ See the dedicated methods for more details.
40
+ """
41
+ # Pattern layout, for each sequence step, we have a list of coordinates
42
+ # corresponding to the original codebook timestep and position.
43
+ # The first list is always an empty list in order to properly insert
44
+ # a special token to start with.
45
+ layout: PatternLayout
46
+ timesteps: int
47
+ n_q: int
48
+
49
+ def __post_init__(self):
50
+ assert len(self.layout) > 0
51
+ assert self.layout[0] == []
52
+ self._validate_layout()
53
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
54
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
55
+ # logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
56
+
57
+ def _validate_layout(self):
58
+ """Runs checks on the layout to ensure a valid pattern is defined.
59
+ A pattern is considered invalid if:
60
+ - Multiple timesteps for a same codebook are defined in the same sequence step
61
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
62
+ (this would mean that we have future timesteps before past timesteps).
63
+ """
64
+ q_timesteps = {q: 0 for q in range(self.n_q)}
65
+ for s, seq_coords in enumerate(self.layout):
66
+ if len(seq_coords) > 0:
67
+ qs = set()
68
+ for coord in seq_coords:
69
+ qs.add(coord.q)
70
+ last_q_timestep = q_timesteps[coord.q]
71
+ assert coord.t >= last_q_timestep, \
72
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
73
+ q_timesteps[coord.q] = coord.t
74
+ # each sequence step contains at max 1 coordinate per codebook
75
+ assert len(qs) == len(seq_coords), \
76
+ f"Multiple entries for a same codebook are found at step {s}"
77
+
78
+ @property
79
+ def num_sequence_steps(self):
80
+ return len(self.layout) - 1
81
+
82
+ @property
83
+ def max_delay(self):
84
+ max_t_in_seq_coords = 0
85
+ for seq_coords in self.layout[1:]:
86
+ for coords in seq_coords:
87
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
88
+ return max_t_in_seq_coords - self.timesteps
89
+
90
+ @property
91
+ def valid_layout(self):
92
+ valid_step = len(self.layout) - self.max_delay
93
+ return self.layout[:valid_step]
94
+
95
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
+ and the actual codebook coordinates.
99
+ """
100
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
+ if q is not None:
102
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
+ coords = []
104
+ for s, seq_codes in enumerate(self.layout):
105
+ for code in seq_codes:
106
+ if code.t == t and (q is None or code.q == q):
107
+ coords.append((s, code))
108
+ return coords
109
+
110
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
+
113
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
+
117
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
+ device: tp.Union[torch.device, str] = 'cpu'):
119
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
+
121
+ Args:
122
+ timesteps (int): Maximum number of timesteps steps to consider.
123
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
+ device (Union[torch.device, str]): Device for created tensors.
125
+ Returns:
126
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
+ """
129
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
+ # fill indexes with last sequence step value that will correspond to our special token
138
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
+ # which will correspond to the index: n_q * timesteps
140
+ indexes[:] = n_q * timesteps
141
+ # iterate over the pattern and fill scattered indexes and mask
142
+ for s, sequence_coords in enumerate(ref_layout):
143
+ for coords in sequence_coords:
144
+ if coords.t < timesteps:
145
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
146
+ mask[coords.q, s] = 1
147
+ indexes = torch.from_numpy(indexes).to(device)
148
+ mask = torch.from_numpy(mask).to(device)
149
+ return indexes, mask
150
+
151
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
+ """Build sequence corresponding to the pattern from the input tensor z.
153
+ The sequence is built using up to sequence_steps if specified, and non-pattern
154
+ coordinates are filled with the special token.
155
+
156
+ Args:
157
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
161
+ Returns:
162
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
+ """
167
+ B, K, T = z.shape
168
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
+ )
171
+ z = z.view(B, -1)
172
+ # we append the special token as the last index of our flattened z tensor
173
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
+ values = z[:, indexes.view(-1)]
175
+ values = values.view(B, K, indexes.shape[-1])
176
+ return values, indexes, mask
177
+
178
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
+ keep_only_valid_steps: bool = False,
180
+ is_model_output: bool = False,
181
+ device: tp.Union[torch.device, str] = 'cpu'):
182
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
+ from interleaving pattern.
184
+
185
+ Args:
186
+ sequence_steps (int): Sequence steps.
187
+ n_q (int): Number of codebooks.
188
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
190
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
+ device (Union[torch.device, str]): Device for created tensors.
192
+ Returns:
193
+ torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
194
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
+ """
196
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
+ timesteps = self.timesteps
199
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
+ assert sequence_steps <= len(ref_layout), \
201
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
+
203
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
+ if is_model_output:
205
+ ref_layout = ref_layout[1:]
206
+
207
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
+ # fill indexes with last sequence step value that will correspond to our special token
211
+ indexes[:] = n_q * sequence_steps
212
+ for s, sequence_codes in enumerate(ref_layout):
213
+ if s < sequence_steps:
214
+ for code in sequence_codes:
215
+ if code.t < timesteps:
216
+ indexes[code.q, code.t] = s + code.q * sequence_steps
217
+ mask[code.q, code.t] = 1
218
+ indexes = torch.from_numpy(indexes).to(device)
219
+ mask = torch.from_numpy(mask).to(device)
220
+ return indexes, mask
221
+
222
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
+ are filled with the special token.
226
+
227
+ Args:
228
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
+ Returns:
231
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
232
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ B, K, S = s.shape
237
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
+ )
240
+ s = s.view(B, -1)
241
+ # we append the special token as the last index of our flattened z tensor
242
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
+ values = s[:, indexes.view(-1)]
244
+ values = values.view(B, K, indexes.shape[-1])
245
+ return values, indexes, mask
246
+
247
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
+ """Revert model logits obtained on a sequence built from the pattern
249
+ back to a tensor matching the original sequence.
250
+
251
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
+ 1. It is designed to work with the extra cardinality dimension
253
+ 2. We return the logits for the first sequence item that matches the special_token and
254
+ which matching target in the original sequence is the first item of the sequence,
255
+ while we skip the last logits as there is no matching target
256
+ """
257
+ B, card, K, S = logits.shape
258
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
+ )
261
+ logits = logits.reshape(B, card, -1)
262
+ # we append the special token as the last index of our flattened z tensor
263
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
+ values = logits[:, :, indexes.view(-1)]
265
+ values = values.view(B, card, K, indexes.shape[-1])
266
+ return values, indexes, mask
267
+
268
+
269
+ class CodebooksPatternProvider(ABC):
270
+ """Abstraction around providing pattern for interleaving codebooks.
271
+
272
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
273
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
274
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
+ can be used to construct a new sequence from the original codes respecting the specified
277
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
+ being a tuple with the original timestep and codebook to build the new sequence.
279
+ Note that all patterns must start with an empty list that is then used to insert a first
280
+ sequence step of special tokens in the newly generated sequence.
281
+
282
+ Args:
283
+ n_q (int): number of codebooks.
284
+ cached (bool): if True, patterns for a given length are cached. In general
285
+ that should be true for efficiency reason to avoid synchronization points.
286
+ """
287
+ def __init__(self, n_q: int, cached: bool = True):
288
+ assert n_q > 0
289
+ self.n_q = n_q
290
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
+
292
+ @abstractmethod
293
+ def get_pattern(self, timesteps: int) -> Pattern:
294
+ """Builds pattern with specific interleaving between codebooks.
295
+
296
+ Args:
297
+ timesteps (int): Total numer of timesteps.
298
+ """
299
+ raise NotImplementedError()
300
+
301
+
302
+ class DelayedPatternProvider(CodebooksPatternProvider):
303
+ """Provider for delayed pattern across delayed codebooks.
304
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
+ from different timesteps.
306
+
307
+ Example:
308
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
+ [[1, 2, 3, 4],
310
+ [1, 2, 3, 4],
311
+ [1, 2, 3, 4]]
312
+ The resulting sequence obtained from the returned pattern is:
313
+ [[S, 1, 2, 3, 4],
314
+ [S, S, 1, 2, 3],
315
+ [S, S, S, 1, 2]]
316
+ (with S being a special token)
317
+
318
+ Args:
319
+ n_q (int): Number of codebooks.
320
+ delays (Optional[List[int]]): Delay for each of the codebooks.
321
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
+ flatten_first (int): Flatten the first N timesteps.
323
+ empty_initial (int): Prepend with N empty list of coordinates.
324
+ """
325
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
+ flatten_first: int = 0, empty_initial: int = 0):
327
+ super().__init__(n_q)
328
+ if delays is None:
329
+ delays = list(range(n_q))
330
+ self.delays = delays
331
+ self.flatten_first = flatten_first
332
+ self.empty_initial = empty_initial
333
+ assert len(self.delays) == self.n_q
334
+ assert sorted(self.delays) == self.delays
335
+
336
+ def get_pattern(self, timesteps: int) -> Pattern:
337
+ out: PatternLayout = [[]]
338
+ max_delay = max(self.delays)
339
+ if self.empty_initial:
340
+ out += [[] for _ in range(self.empty_initial)]
341
+ if self.flatten_first:
342
+ for t in range(min(timesteps, self.flatten_first)):
343
+ for q in range(self.n_q):
344
+ out.append([LayoutCoord(t, q)])
345
+ for t in range(self.flatten_first, timesteps + max_delay):
346
+ v = []
347
+ for q, delay in enumerate(self.delays):
348
+ t_for_q = t - delay
349
+ if t_for_q >= self.flatten_first:
350
+ v.append(LayoutCoord(t_for_q, q))
351
+ out.append(v)
352
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
353
+
354
+
355
+ class ParallelPatternProvider(DelayedPatternProvider):
356
+ """Provider for parallel pattern across codebooks.
357
+ This pattern provider is a special case of the delayed pattern with actually no delay,
358
+ hence delays=repeat(0, n_q).
359
+
360
+ Args:
361
+ n_q (int): Number of codebooks.
362
+ """
363
+ def __init__(self, n_q: int):
364
+ super().__init__(n_q, [0] * n_q)
365
+
366
+
367
+ class UnrolledPatternProvider(CodebooksPatternProvider):
368
+ """Provider for unrolling codebooks pattern.
369
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
370
+ while also specifying a given delay between the flattened codebooks representation, allowing to
371
+ unroll the codebooks in the sequence.
372
+
373
+ Example:
374
+ 1. Flattening of the codebooks.
375
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
376
+ taking n_q = 3 and timesteps = 4:
377
+ [[1, 2, 3, 4],
378
+ [1, 2, 3, 4],
379
+ [1, 2, 3, 4]]
380
+ will result into:
381
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
382
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
383
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
384
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
385
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
386
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
387
+ [[1, 2, 3, 4],
388
+ [1, 2, 3, 4],
389
+ [1, 2, 3, 4]]
390
+ will result into:
391
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
392
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
393
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
394
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
395
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
396
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
397
+ and delays = [0, 3, 3]:
398
+ [[1, 2, 3, 4],
399
+ [1, 2, 3, 4],
400
+ [1, 2, 3, 4]]
401
+ will result into:
402
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
403
+ [S, S, S, 1, S, 2, S, 3, S, 4],
404
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
405
+
406
+ Args:
407
+ n_q (int): Number of codebooks.
408
+ flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
409
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
410
+ have n_q extra steps for each timestep.
411
+ delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
412
+ no delay is added and therefore will default to [0] * ``n_q``.
413
+ Note that two codebooks that will be flattened to the same inner step
414
+ should have the same delay, otherwise the pattern is considered as invalid.
415
+ """
416
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
417
+
418
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
419
+ delays: tp.Optional[tp.List[int]] = None):
420
+ super().__init__(n_q)
421
+ if flattening is None:
422
+ flattening = list(range(n_q))
423
+ if delays is None:
424
+ delays = [0] * n_q
425
+ assert len(flattening) == n_q
426
+ assert len(delays) == n_q
427
+ assert sorted(flattening) == flattening
428
+ assert sorted(delays) == delays
429
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
430
+ self.max_delay = max(delays)
431
+
432
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
433
+ """Build a flattened codebooks representation as a dictionary of inner step
434
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
435
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
436
+ """
437
+ flattened_codebooks: dict = {}
438
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
439
+ if inner_step not in flattened_codebooks:
440
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
441
+ else:
442
+ flat_codebook = flattened_codebooks[inner_step]
443
+ assert flat_codebook.delay == delay, (
444
+ "Delay and flattening between codebooks is inconsistent: ",
445
+ "two codebooks flattened to the same position should have the same delay."
446
+ )
447
+ flat_codebook.codebooks.append(q)
448
+ flattened_codebooks[inner_step] = flat_codebook
449
+ return flattened_codebooks
450
+
451
+ @property
452
+ def _num_inner_steps(self):
453
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
454
+ """
455
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
456
+
457
+ def num_virtual_steps(self, timesteps: int) -> int:
458
+ return timesteps * self._num_inner_steps + 1
459
+
460
+ def get_pattern(self, timesteps: int) -> Pattern:
461
+ """Builds pattern for delay across codebooks.
462
+
463
+ Args:
464
+ timesteps (int): Total numer of timesteps.
465
+ """
466
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
467
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
468
+ indexed_out: list = [(-1, [])]
469
+ max_timesteps = timesteps + self.max_delay
470
+ for t in range(max_timesteps):
471
+ # for each timestep, we unroll the flattened codebooks,
472
+ # emitting the sequence step with the corresponding delay
473
+ for step in range(self._num_inner_steps):
474
+ if step in self._flattened_codebooks:
475
+ # we have codebooks at this virtual step to emit
476
+ step_codebooks = self._flattened_codebooks[step]
477
+ t_for_q = t + step_codebooks.delay
478
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
479
+ if t_for_q < max_timesteps and t < max_timesteps:
480
+ indexed_out.append((t_for_q, coords))
481
+ else:
482
+ # there is no codebook in this virtual step so we emit an empty list
483
+ indexed_out.append((t, []))
484
+ out = [coords for _, coords in sorted(indexed_out)]
485
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
486
+
487
+
488
+ class VALLEPattern(CodebooksPatternProvider):
489
+ """Almost VALL-E style pattern. We futher allow some delays for the
490
+ codebooks other than the first one.
491
+
492
+ Args:
493
+ n_q (int): Number of codebooks.
494
+ delays (Optional[List[int]]): Delay for each of the codebooks.
495
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
496
+ """
497
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
498
+ super().__init__(n_q)
499
+ if delays is None:
500
+ delays = [0] * (n_q - 1)
501
+ self.delays = delays
502
+ assert len(self.delays) == self.n_q - 1
503
+ assert sorted(self.delays) == self.delays
504
+
505
+ def get_pattern(self, timesteps: int) -> Pattern:
506
+ out: PatternLayout = [[]]
507
+ for t in range(timesteps):
508
+ out.append([LayoutCoord(t, 0)])
509
+ max_delay = max(self.delays)
510
+ for t in range(timesteps + max_delay):
511
+ v = []
512
+ for q, delay in enumerate(self.delays):
513
+ t_for_q = t - delay
514
+ if t_for_q >= 0:
515
+ v.append(LayoutCoord(t_for_q, q + 1))
516
+ out.append(v)
517
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
518
+
519
+
520
+ class MusicLMPattern(CodebooksPatternProvider):
521
+ """Almost MusicLM style pattern. This is equivalent to full flattening
522
+ but in a different order.
523
+
524
+ Args:
525
+ n_q (int): Number of codebooks.
526
+ group_by (int): Number of codebooks to group together.
527
+ """
528
+ def __init__(self, n_q: int, group_by: int = 2):
529
+ super().__init__(n_q)
530
+ self.group_by = group_by
531
+
532
+ def get_pattern(self, timesteps: int) -> Pattern:
533
+ out: PatternLayout = [[]]
534
+ for offset in range(0, self.n_q, self.group_by):
535
+ for t in range(timesteps):
536
+ for q in range(offset, offset + self.group_by):
537
+ out.append([LayoutCoord(t, q)])
538
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
lib/voicecraft/models/modules/__init__.py ADDED
File without changes
lib/voicecraft/models/modules/activation.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+ import logging
12
+ from typing import Callable, List, Optional, Tuple, Union
13
+ from typing import TYPE_CHECKING
14
+ if TYPE_CHECKING:
15
+ from torch.types import _dtype as DType
16
+ else:
17
+ # The JIT doesn't understand Union, nor torch.dtype here
18
+ DType = int
19
+
20
+ def _canonical_mask(
21
+ mask: Optional[Tensor],
22
+ mask_name: str,
23
+ other_type: Optional[DType],
24
+ other_name: str,
25
+ target_type: DType,
26
+ check_other: bool = True,
27
+ ) -> Optional[Tensor]:
28
+
29
+ if mask is not None:
30
+ _mask_dtype = mask.dtype
31
+ _mask_is_float = torch.is_floating_point(mask)
32
+ if _mask_dtype != torch.bool and not _mask_is_float:
33
+ raise AssertionError(
34
+ f"only bool and floating types of {mask_name} are supported")
35
+ if check_other and other_type is not None:
36
+ if _mask_dtype != other_type:
37
+ warnings.warn(
38
+ f"Support for mismatched {mask_name} and {other_name} "
39
+ "is deprecated. Use same type for both instead."
40
+ )
41
+ if not _mask_is_float:
42
+ mask = (
43
+ torch.zeros_like(mask, dtype=target_type)
44
+ .masked_fill_(mask, float("-inf"))
45
+ )
46
+ return mask
47
+
48
+ def _in_projection_packed(
49
+ q: Tensor,
50
+ k: Tensor,
51
+ v: Tensor,
52
+ w: Tensor,
53
+ b: Optional[Tensor] = None,
54
+ ) -> List[Tensor]:
55
+ r"""
56
+ Performs the in-projection step of the attention operation, using packed weights.
57
+ Output is a triple containing projection tensors for query, key and value.
58
+
59
+ Args:
60
+ q, k, v: query, key and value tensors to be projected. For self-attention,
61
+ these are typically the same tensor; for encoder-decoder attention,
62
+ k and v are typically the same tensor. (We take advantage of these
63
+ identities for performance if they are present.) Regardless, q, k and v
64
+ must share a common embedding dimension; otherwise their shapes may vary.
65
+ w: projection weights for q, k and v, packed into a single tensor. Weights
66
+ are packed along dimension 0, in q, k, v order.
67
+ b: optional projection biases for q, k and v, packed into a single tensor
68
+ in q, k, v order.
69
+
70
+ Shape:
71
+ Inputs:
72
+ - q: :math:`(..., E)` where E is the embedding dimension
73
+ - k: :math:`(..., E)` where E is the embedding dimension
74
+ - v: :math:`(..., E)` where E is the embedding dimension
75
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
76
+ - b: :math:`E * 3` where E is the embedding dimension
77
+
78
+ Output:
79
+ - in output list :math:`[q', k', v']`, each output tensor will have the
80
+ same shape as the corresponding input tensor.
81
+ """
82
+ E = q.size(-1)
83
+ if k is v:
84
+ if q is k:
85
+ # self-attention
86
+ proj = F.linear(q, w, b)
87
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
88
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
89
+ return proj[0], proj[1], proj[2]
90
+ else:
91
+ # encoder-decoder attention
92
+ w_q, w_kv = w.split([E, E * 2])
93
+ if b is None:
94
+ b_q = b_kv = None
95
+ else:
96
+ b_q, b_kv = b.split([E, E * 2])
97
+ q_proj = F.linear(q, w_q, b_q)
98
+ kv_proj = F.linear(k, w_kv, b_kv)
99
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
100
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
101
+ return (q_proj, kv_proj[0], kv_proj[1])
102
+ else:
103
+ w_q, w_k, w_v = w.chunk(3)
104
+ if b is None:
105
+ b_q = b_k = b_v = None
106
+ else:
107
+ b_q, b_k, b_v = b.chunk(3)
108
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
109
+
110
+ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
111
+ if input is None:
112
+ return None
113
+ elif isinstance(input, torch.Tensor):
114
+ return input.dtype
115
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
116
+ class MultiheadAttention(Module):
117
+ r"""Allows the model to jointly attend to information
118
+ from different representation subspaces as described in the paper:
119
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
120
+
121
+ Multi-Head Attention is defined as:
122
+
123
+ .. math::
124
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
125
+
126
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
127
+
128
+ ``forward()`` will use a special optimized implementation if all of the following
129
+ conditions are met:
130
+
131
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
132
+ restriction will be loosened in the future.)
133
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
134
+ - training is disabled (using ``.eval()``)
135
+ - dropout is 0
136
+ - ``add_bias_kv`` is ``False``
137
+ - ``add_zero_attn`` is ``False``
138
+ - ``batch_first`` is ``True`` and the input is batched
139
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
140
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
141
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
142
+ nor ``attn_mask`` is passed
143
+
144
+ If the optimized implementation is in use, a
145
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
146
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
147
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
148
+ will be returned, and an additional speedup proportional to the fraction of the input
149
+ that is padding can be expected.
150
+
151
+ Args:
152
+ embed_dim: Total dimension of the model.
153
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
154
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
155
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
156
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
157
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
158
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
159
+ Default: ``False``.
160
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
161
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
162
+ batch_first: If ``True``, then the input and output tensors are provided
163
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
164
+
165
+ Examples::
166
+
167
+ >>> # xdoctest: +SKIP
168
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
169
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
170
+
171
+ """
172
+ __constants__ = ["batch_first"]
173
+ bias_k: Optional[torch.Tensor]
174
+ bias_v: Optional[torch.Tensor]
175
+
176
+ def __init__(
177
+ self,
178
+ embed_dim,
179
+ num_heads,
180
+ dropout=0.0,
181
+ bias=True,
182
+ add_bias_kv=False,
183
+ add_zero_attn=False,
184
+ kdim=None,
185
+ vdim=None,
186
+ batch_first=False,
187
+ linear1_cls=Linear,
188
+ linear2_cls=Linear,
189
+ device=None,
190
+ dtype=None,
191
+ ) -> None:
192
+ factory_kwargs = {"device": device, "dtype": dtype}
193
+ super(MultiheadAttention, self).__init__()
194
+ self.embed_dim = embed_dim
195
+ self.kdim = kdim if kdim is not None else embed_dim
196
+ self.vdim = vdim if vdim is not None else embed_dim
197
+ self._qkv_same_embed_dim = (
198
+ self.kdim == embed_dim and self.vdim == embed_dim
199
+ )
200
+
201
+ self.num_heads = num_heads
202
+ self.dropout = dropout
203
+ self.batch_first = batch_first
204
+ self.head_dim = embed_dim // num_heads
205
+ assert (
206
+ self.head_dim * num_heads == self.embed_dim
207
+ ), "embed_dim must be divisible by num_heads"
208
+
209
+ if add_bias_kv:
210
+ self.bias_k = Parameter(
211
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
212
+ )
213
+ self.bias_v = Parameter(
214
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
215
+ )
216
+ else:
217
+ self.bias_k = self.bias_v = None
218
+
219
+ if linear1_cls == Linear:
220
+ if not self._qkv_same_embed_dim:
221
+ self.q_proj_weight = Parameter(
222
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
223
+ )
224
+ self.k_proj_weight = Parameter(
225
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
226
+ )
227
+ self.v_proj_weight = Parameter(
228
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
229
+ )
230
+ self.register_parameter("in_proj_weight", None)
231
+ else:
232
+ # go down this route with voicecraft
233
+ self.in_proj_weight = Parameter(
234
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
235
+ )
236
+ self.register_parameter("q_proj_weight", None)
237
+ self.register_parameter("k_proj_weight", None)
238
+ self.register_parameter("v_proj_weight", None)
239
+
240
+ if bias: # True by default
241
+ self.in_proj_bias = Parameter(
242
+ torch.empty(3 * embed_dim, **factory_kwargs)
243
+ )
244
+ else:
245
+ self.register_parameter("in_proj_bias", None)
246
+ self.out_proj = NonDynamicallyQuantizableLinear(
247
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
248
+ )
249
+
250
+ self._reset_parameters()
251
+ else:
252
+ if not self._qkv_same_embed_dim:
253
+ raise NotImplementedError
254
+ else:
255
+ self.in_proj_linear = linear1_cls(
256
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
257
+ )
258
+ self.in_proj_weight = self.in_proj_linear.weight
259
+
260
+ self.register_parameter("q_proj_weight", None)
261
+ self.register_parameter("k_proj_weight", None)
262
+ self.register_parameter("v_proj_weight", None)
263
+
264
+ if bias:
265
+ self.in_proj_bias = self.in_proj_linear.bias
266
+ else:
267
+ self.register_parameter("in_proj_bias", None)
268
+
269
+ self.out_proj = linear2_cls(
270
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
271
+ )
272
+
273
+ if self.bias_k is not None:
274
+ xavier_normal_(self.bias_k)
275
+ if self.bias_v is not None:
276
+ xavier_normal_(self.bias_v)
277
+
278
+ self.add_zero_attn = add_zero_attn
279
+
280
+ def _reset_parameters(self):
281
+ if self._qkv_same_embed_dim:
282
+ xavier_uniform_(self.in_proj_weight)
283
+ else:
284
+ xavier_uniform_(self.q_proj_weight)
285
+ xavier_uniform_(self.k_proj_weight)
286
+ xavier_uniform_(self.v_proj_weight)
287
+
288
+ if self.in_proj_bias is not None:
289
+ constant_(self.in_proj_bias, 0.0)
290
+ constant_(self.out_proj.bias, 0.0)
291
+
292
+ if self.bias_k is not None:
293
+ xavier_normal_(self.bias_k)
294
+ if self.bias_v is not None:
295
+ xavier_normal_(self.bias_v)
296
+
297
+ def __setstate__(self, state):
298
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
299
+ if "_qkv_same_embed_dim" not in state:
300
+ state["_qkv_same_embed_dim"] = True
301
+
302
+ super(MultiheadAttention, self).__setstate__(state)
303
+
304
+ def forward(
305
+ self,
306
+ query: Tensor,
307
+ key: Tensor,
308
+ value: Tensor,
309
+ key_padding_mask: Optional[Tensor] = None,
310
+ need_weights: bool = True,
311
+ attn_mask: Optional[Tensor] = None,
312
+ average_attn_weights: bool = True,
313
+ past: Optional[Tensor] = None,
314
+ ) -> Tuple[Tensor, Optional[Tensor]]:
315
+ r"""
316
+ Args:
317
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
318
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
319
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
320
+ Queries are compared against key-value pairs to produce the output.
321
+ See "Attention Is All You Need" for more details.
322
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
323
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
324
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
325
+ See "Attention Is All You Need" for more details.
326
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
327
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
328
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
329
+ See "Attention Is All You Need" for more details.
330
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
331
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
332
+ Binary and byte masks are supported.
333
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
334
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
335
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
336
+ Default: ``True``.
337
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
338
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
339
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
340
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
341
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
342
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
343
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
344
+ the attention weight.
345
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
346
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
347
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
348
+
349
+ Outputs:
350
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
351
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
352
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
353
+ embedding dimension ``embed_dim``.
354
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
355
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
356
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
357
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
358
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
359
+
360
+ .. note::
361
+ `batch_first` argument is ignored for unbatched inputs.
362
+ """
363
+ is_batched = query.dim() == 3
364
+ if key_padding_mask is not None:
365
+ _kpm_dtype = key_padding_mask.dtype
366
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
367
+ key_padding_mask
368
+ ):
369
+ raise AssertionError(
370
+ "only bool and floating types of key_padding_mask are supported"
371
+ )
372
+ why_not_fast_path = ""
373
+ if not is_batched:
374
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
375
+ elif query is not key or key is not value:
376
+ # When lifting this restriction, don't forget to either
377
+ # enforce that the dtypes all match or test cases where
378
+ # they don't!
379
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
380
+ elif (
381
+ self.in_proj_bias is not None
382
+ and query.dtype != self.in_proj_bias.dtype
383
+ ):
384
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
385
+ elif (
386
+ self.in_proj_weight is not None
387
+ and query.dtype != self.in_proj_weight.dtype
388
+ ):
389
+ # this case will fail anyway, but at least they'll get a useful error message.
390
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
391
+ elif self.training:
392
+ why_not_fast_path = "training is enabled"
393
+ elif not self.batch_first:
394
+ why_not_fast_path = "batch_first was not True"
395
+ elif self.bias_k is not None:
396
+ why_not_fast_path = "self.bias_k was not None"
397
+ elif self.bias_v is not None:
398
+ why_not_fast_path = "self.bias_v was not None"
399
+ elif self.dropout:
400
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
401
+ elif self.add_zero_attn:
402
+ why_not_fast_path = "add_zero_attn was enabled"
403
+ elif not self._qkv_same_embed_dim:
404
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
405
+ elif attn_mask is not None:
406
+ why_not_fast_path = "attn_mask was not None"
407
+ elif query.is_nested and key_padding_mask is not None:
408
+ why_not_fast_path = (
409
+ "key_padding_mask is not supported with NestedTensor input"
410
+ )
411
+ elif self.num_heads % 2 == 1:
412
+ why_not_fast_path = "num_heads is odd"
413
+ elif torch.is_autocast_enabled():
414
+ why_not_fast_path = "autocast is enabled"
415
+
416
+ if not why_not_fast_path:
417
+ tensor_args = (
418
+ query,
419
+ key,
420
+ value,
421
+ self.in_proj_weight,
422
+ self.in_proj_bias,
423
+ self.out_proj.weight,
424
+ self.out_proj.bias,
425
+ )
426
+ # We have to use list comprehensions below because TorchScript does not support
427
+ # generator expressions.
428
+ if torch.overrides.has_torch_function(tensor_args):
429
+ why_not_fast_path = "some Tensor argument has_torch_function"
430
+ elif not all(
431
+ [
432
+ (x is None or x.is_cuda or "cpu" in str(x.device))
433
+ for x in tensor_args
434
+ ]
435
+ ):
436
+ why_not_fast_path = (
437
+ "some Tensor argument is neither CUDA nor CPU"
438
+ )
439
+ elif torch.is_grad_enabled() and any(
440
+ [x is not None and x.requires_grad for x in tensor_args]
441
+ ):
442
+ why_not_fast_path = (
443
+ "grad is enabled and at least one of query or the "
444
+ "input/output projection weights or biases requires_grad"
445
+ )
446
+ if not why_not_fast_path:
447
+ return torch._native_multi_head_attention(
448
+ query,
449
+ key,
450
+ value,
451
+ self.embed_dim,
452
+ self.num_heads,
453
+ self.in_proj_weight,
454
+ self.in_proj_bias,
455
+ self.out_proj.weight,
456
+ self.out_proj.bias,
457
+ key_padding_mask
458
+ if key_padding_mask is not None
459
+ else attn_mask,
460
+ need_weights,
461
+ average_attn_weights,
462
+ 1
463
+ if key_padding_mask is not None
464
+ else 0
465
+ if attn_mask is not None
466
+ else None,
467
+ )
468
+
469
+ any_nested = query.is_nested or key.is_nested or value.is_nested
470
+ assert not any_nested, (
471
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
472
+ + f"The fast path was not hit because {why_not_fast_path}"
473
+ )
474
+
475
+ if self.batch_first and is_batched:
476
+ # make sure that the transpose op does not affect the "is" property
477
+ if key is value:
478
+ if query is key:
479
+ query = key = value = query.transpose(1, 0)
480
+ else:
481
+ query, key = [x.transpose(1, 0) for x in (query, key)]
482
+ value = key
483
+ else:
484
+ query, key, value = [
485
+ x.transpose(1, 0) for x in (query, key, value)
486
+ ]
487
+
488
+ if not self._qkv_same_embed_dim:
489
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
490
+ query,
491
+ key,
492
+ value,
493
+ self.embed_dim,
494
+ self.num_heads,
495
+ self.in_proj_weight,
496
+ self.in_proj_bias,
497
+ self.bias_k,
498
+ self.bias_v,
499
+ self.add_zero_attn,
500
+ self.dropout,
501
+ self.out_proj.weight,
502
+ self.out_proj.bias,
503
+ training=self.training,
504
+ key_padding_mask=key_padding_mask,
505
+ need_weights=need_weights,
506
+ attn_mask=attn_mask,
507
+ use_separate_proj_weight=True,
508
+ q_proj_weight=self.q_proj_weight,
509
+ k_proj_weight=self.k_proj_weight,
510
+ v_proj_weight=self.v_proj_weight,
511
+ average_attn_weights=average_attn_weights,
512
+ )
513
+ else:
514
+ # re-write the self.attention here, to get k, v cache
515
+ tgt_len, bsz, embed_dim = query.shape
516
+ src_len, _, _ = key.shape
517
+ num_heads = self.num_heads
518
+ key_padding_mask = _canonical_mask(
519
+ mask=key_padding_mask,
520
+ mask_name="key_padding_mask",
521
+ other_type=_none_or_dtype(attn_mask),
522
+ other_name="attn_mask",
523
+ target_type=query.dtype
524
+ )
525
+ attn_mask = _canonical_mask(
526
+ mask=attn_mask,
527
+ mask_name="attn_mask",
528
+ other_type=None,
529
+ other_name="",
530
+ target_type=query.dtype,
531
+ check_other=False,
532
+ )
533
+ head_dim = self.embed_dim // self.num_heads
534
+ assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
535
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
536
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
537
+ # k_present, v_present = k, v
538
+
539
+ #
540
+ # reshape q, k, v for multihead attention and make em batch first
541
+ #
542
+
543
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
544
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
545
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
546
+ src_len = k.size(1)
547
+ if past is not None and past.ndim > 2:
548
+ expected_src_len = src_len + past[0].shape[-2]
549
+ else:
550
+ expected_src_len = src_len
551
+
552
+
553
+ # ensure attn_mask's dim is 3
554
+ if attn_mask.dim() == 2:
555
+ correct_2d_size = (tgt_len, expected_src_len)
556
+ if attn_mask.shape != correct_2d_size:
557
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
558
+ attn_mask = attn_mask.unsqueeze(0)
559
+ elif attn_mask.dim() == 3:
560
+ correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
561
+ if attn_mask.shape != correct_3d_size:
562
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
563
+ else:
564
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
565
+
566
+ if key_padding_mask is not None:
567
+ assert key_padding_mask.shape == (bsz, expected_src_len), \
568
+ f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
569
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
570
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
571
+ if attn_mask is None:
572
+ attn_mask = key_padding_mask
573
+ else:
574
+ attn_mask = attn_mask + key_padding_mask
575
+
576
+ if not self.training:
577
+ dropout_p = 0.0
578
+ else:
579
+ dropout_p = self.dropout
580
+
581
+ if need_weights:
582
+ raise NotImplementedError("need_weights not implemented for voicecraft")
583
+ # B, Nt, E = q.shape
584
+ # q_scaled = q / math.sqrt(E)
585
+
586
+ # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
587
+
588
+ # if attn_mask is not None:
589
+ # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
590
+ # else:
591
+ # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
592
+ # attn_output_weights = softmax(attn_output_weights, dim=-1)
593
+ # if dropout_p > 0.0:
594
+ # attn_output_weights = dropout(attn_output_weights, p=dropout_p)
595
+
596
+ # attn_output = torch.bmm(attn_output_weights, v)
597
+
598
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
599
+ # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
600
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
601
+
602
+ # # optionally average attention weights over heads
603
+ # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
604
+ # if average_attn_weights:
605
+ # attn_output_weights = attn_output_weights.mean(dim=1)
606
+
607
+ # if not is_batched:
608
+ # # squeeze the output if input was unbatched
609
+ # attn_output = attn_output.squeeze(1)
610
+ # attn_output_weights = attn_output_weights.squeeze(0)
611
+ # return attn_output, attn_output_weights
612
+ else:
613
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
614
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
615
+ # in order to match the input for SDPA of (N, num_heads, L, S)
616
+ if attn_mask is not None:
617
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
618
+ attn_mask = attn_mask.unsqueeze(0)
619
+ else:
620
+ attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
621
+
622
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
623
+ k = k.view(bsz, num_heads, src_len, head_dim)
624
+ v = v.view(bsz, num_heads, src_len, head_dim)
625
+ # logging.info(f"shape of past: {past.shape}")
626
+ if past is not None:
627
+ present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
628
+ if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
629
+ pk, pv = past
630
+ k = torch.cat([pk, k], dim=-2)
631
+ v = torch.cat([pv, v], dim=-2)
632
+ else:
633
+ present = None
634
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
635
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
636
+
637
+ attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
638
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
639
+ if not is_batched:
640
+ # squeeze the output if input was unbatched
641
+ attn_output = attn_output.squeeze(1)
642
+ # if self.training:
643
+ # return attn_output, None
644
+ # else:
645
+ # return (attn_output, present), None
646
+
647
+ # harded coded, the code do not support returning attn weigths yet
648
+ attn_output_weights=None
649
+ if self.batch_first and is_batched:
650
+ return attn_output.transpose(1, 0), present
651
+ else:
652
+ return attn_output, present
653
+
lib/voicecraft/models/modules/embedding.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ class TokenEmbedding(nn.Module):
23
+ def __init__(
24
+ self,
25
+ dim_model: int,
26
+ vocab_size: int,
27
+ dropout: float = 0.0,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.vocab_size = vocab_size
32
+ self.dim_model = dim_model
33
+
34
+ self.dropout = torch.nn.Dropout(p=dropout)
35
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
36
+
37
+ @property
38
+ def weight(self) -> torch.Tensor:
39
+ return self.word_embeddings.weight
40
+
41
+ def embedding(self, index: int) -> torch.Tensor:
42
+ return self.word_embeddings.weight[index : index + 1]
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ X = self.word_embeddings(x)
46
+ X = self.dropout(X)
47
+
48
+ return X
49
+
50
+
51
+ class SinePositionalEmbedding(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dim_model: int,
55
+ dropout: float = 0.0,
56
+ scale: bool = False,
57
+ alpha: bool = False,
58
+ ):
59
+ super().__init__()
60
+ self.dim_model = dim_model
61
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
62
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
63
+ self.dropout = torch.nn.Dropout(p=dropout)
64
+
65
+ self.reverse = False
66
+ self.pe = None
67
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
68
+
69
+ def extend_pe(self, x):
70
+ """Reset the positional encodings."""
71
+ if self.pe is not None:
72
+ if self.pe.size(1) >= x.size(1):
73
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
74
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
75
+ return
76
+ pe = torch.zeros(x.size(1), self.dim_model)
77
+ if self.reverse:
78
+ position = torch.arange(
79
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
80
+ ).unsqueeze(1)
81
+ else:
82
+ position = torch.arange(
83
+ 0, x.size(1), dtype=torch.float32
84
+ ).unsqueeze(1)
85
+ div_term = torch.exp(
86
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
87
+ * -(math.log(10000.0) / self.dim_model)
88
+ )
89
+ pe[:, 0::2] = torch.sin(position * div_term)
90
+ pe[:, 1::2] = torch.cos(position * div_term)
91
+ pe = pe.unsqueeze(0)
92
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ self.extend_pe(x)
96
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
97
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
98
+ return self.dropout(output)
lib/voicecraft/models/modules/sampling.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def top_k_top_p_filtering(
5
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
6
+ ):
7
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
8
+ Args:
9
+ logits: logits distribution shape (batch size, vocabulary size)
10
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
11
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
12
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
13
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
14
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
15
+ """
16
+ if top_k > 0:
17
+ top_k = min(
18
+ max(top_k, min_tokens_to_keep), logits.size(-1)
19
+ ) # Safety check
20
+ # Remove all tokens with a probability less than the last token of the top-k
21
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
+ logits[indices_to_remove] = filter_value
23
+
24
+ if top_p < 1.0:
25
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
26
+ cumulative_probs = torch.cumsum(
27
+ F.softmax(sorted_logits, dim=-1), dim=-1
28
+ )
29
+
30
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
31
+ sorted_indices_to_remove = cumulative_probs > top_p
32
+ if min_tokens_to_keep > 1:
33
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
34
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
35
+ # Shift the indices to the right to keep also the first token above the threshold
36
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
37
+ ..., :-1
38
+ ].clone()
39
+ sorted_indices_to_remove[..., 0] = 0
40
+
41
+ # scatter sorted tensors to original indexing
42
+ indices_to_remove = sorted_indices_to_remove.scatter(
43
+ 1, sorted_indices, sorted_indices_to_remove
44
+ )
45
+ logits[indices_to_remove] = filter_value
46
+ return logits
47
+
48
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
49
+ # temperature: (`optional`) float
50
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
51
+ # top_k: (`optional`) int
52
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
53
+ # top_p: (`optional`) float
54
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
55
+
56
+ # Temperature (higher temperature => more likely to sample low probability tokens)
57
+ if temperature != 1.0:
58
+ logits = logits / temperature
59
+ # Top-p/top-k filtering
60
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
61
+ # Sample
62
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
63
+ return token
lib/voicecraft/models/modules/scaling.py ADDED
@@ -0,0 +1,1406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
2
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import collections
20
+ import logging
21
+ import random
22
+ import math
23
+ from functools import reduce
24
+ from itertools import repeat
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch import Tensor
31
+ from torch.nn import Embedding as ScaledEmbedding
32
+
33
+ # from valle.utils import Transpose
34
+
35
+ class Transpose(nn.Identity):
36
+ """(N, T, D) -> (N, D, T)"""
37
+
38
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
39
+ return input.transpose(1, 2)
40
+
41
+ class ActivationBalancerFunction(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ x: Tensor,
46
+ scale_factor: Tensor,
47
+ sign_factor: Optional[Tensor],
48
+ channel_dim: int,
49
+ ) -> Tensor:
50
+ if channel_dim < 0:
51
+ channel_dim += x.ndim
52
+ ctx.channel_dim = channel_dim
53
+ xgt0 = x > 0
54
+ if sign_factor is None:
55
+ ctx.save_for_backward(xgt0, scale_factor)
56
+ else:
57
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
58
+ return x
59
+
60
+ @staticmethod
61
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
62
+ if len(ctx.saved_tensors) == 3:
63
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
64
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
+ scale_factor = scale_factor.unsqueeze(-1)
66
+ sign_factor = sign_factor.unsqueeze(-1)
67
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
68
+ else:
69
+ xgt0, scale_factor = ctx.saved_tensors
70
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
71
+ scale_factor = scale_factor.unsqueeze(-1)
72
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
73
+ neg_delta_grad = x_grad.abs() * factor
74
+ return (
75
+ x_grad - neg_delta_grad,
76
+ None,
77
+ None,
78
+ None,
79
+ )
80
+
81
+
82
+ def _compute_scale_factor(
83
+ x: Tensor,
84
+ channel_dim: int,
85
+ min_abs: float,
86
+ max_abs: float,
87
+ gain_factor: float,
88
+ max_factor: float,
89
+ ) -> Tensor:
90
+ if channel_dim < 0:
91
+ channel_dim += x.ndim
92
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
93
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
94
+
95
+ if min_abs == 0.0:
96
+ below_threshold = 0.0
97
+ else:
98
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
99
+ # x_abs)_mean , min_abs.
100
+ below_threshold = (
101
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
102
+ ).clamp(min=0, max=max_factor)
103
+
104
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
105
+ min=0, max=max_factor
106
+ )
107
+
108
+ return below_threshold - above_threshold
109
+
110
+
111
+ def _compute_sign_factor(
112
+ x: Tensor,
113
+ channel_dim: int,
114
+ min_positive: float,
115
+ max_positive: float,
116
+ gain_factor: float,
117
+ max_factor: float,
118
+ ) -> Tensor:
119
+ if channel_dim < 0:
120
+ channel_dim += x.ndim
121
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
122
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
123
+ if min_positive == 0.0:
124
+ factor1 = 0.0
125
+ else:
126
+ # 0 if proportion_positive >= min_positive, else can be
127
+ # as large as max_factor.
128
+ factor1 = (
129
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
130
+ ).clamp_(min=0, max=max_factor)
131
+
132
+ if max_positive == 1.0:
133
+ factor2 = 0.0
134
+ else:
135
+ # 0 if self.proportion_positive <= max_positive, else can be
136
+ # as large as -max_factor.
137
+ factor2 = (
138
+ (proportion_positive - max_positive)
139
+ * (gain_factor / (1.0 - max_positive))
140
+ ).clamp_(min=0, max=max_factor)
141
+ sign_factor = factor1 - factor2
142
+ # require min_positive != 0 or max_positive != 1:
143
+ assert not isinstance(sign_factor, float)
144
+ return sign_factor
145
+
146
+
147
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
148
+ """
149
+ This object is used in class ActivationBalancer when the user specified
150
+ min_positive=0, max_positive=1, so there are no constraints on the signs
151
+ of the activations and only the absolute value has a constraint.
152
+ """
153
+
154
+ @staticmethod
155
+ def forward(
156
+ ctx,
157
+ x: Tensor,
158
+ sign_factor: Tensor,
159
+ scale_factor: Tensor,
160
+ channel_dim: int,
161
+ ) -> Tensor:
162
+ if channel_dim < 0:
163
+ channel_dim += x.ndim
164
+ ctx.channel_dim = channel_dim
165
+ xgt0 = x > 0
166
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
167
+ return x
168
+
169
+ @staticmethod
170
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
171
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
172
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
173
+ sign_factor = sign_factor.unsqueeze(-1)
174
+ scale_factor = scale_factor.unsqueeze(-1)
175
+
176
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
177
+ neg_delta_grad = x_grad.abs() * factor
178
+ return (
179
+ x_grad - neg_delta_grad,
180
+ None,
181
+ None,
182
+ None,
183
+ )
184
+
185
+
186
+ class RandomClampFunction(torch.autograd.Function):
187
+ @staticmethod
188
+ def forward(
189
+ ctx,
190
+ x: Tensor,
191
+ min: Optional[float],
192
+ max: Optional[float],
193
+ prob: float,
194
+ reflect: float,
195
+ ) -> Tensor:
196
+ x_clamped = torch.clamp(x, min=min, max=max)
197
+ mask = torch.rand_like(x) < prob
198
+ ans = torch.where(mask, x_clamped, x)
199
+ if x.requires_grad:
200
+ ctx.save_for_backward(ans == x)
201
+ ctx.reflect = reflect
202
+ if reflect != 0.0:
203
+ ans = ans * (1.0 + reflect) - (x * reflect)
204
+ return ans
205
+
206
+ @staticmethod
207
+ def backward(
208
+ ctx, ans_grad: Tensor
209
+ ) -> Tuple[Tensor, None, None, None, None]:
210
+ (is_same,) = ctx.saved_tensors
211
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
212
+ reflect = ctx.reflect
213
+ if reflect != 0.0:
214
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
215
+ return x_grad, None, None, None, None
216
+
217
+
218
+ def random_clamp(
219
+ x: Tensor,
220
+ min: Optional[float] = None,
221
+ max: Optional[float] = None,
222
+ prob: float = 0.5,
223
+ reflect: float = 0.0,
224
+ ):
225
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
226
+
227
+
228
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
229
+ """
230
+ A randomized way of casting a floating point value to half precision.
231
+ """
232
+ if x.dtype == torch.float16:
233
+ return x
234
+ x_abs = x.abs()
235
+ is_too_small = x_abs < min_abs
236
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
237
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
238
+ # for those elements].
239
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
240
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
241
+
242
+
243
+ class RandomGradFunction(torch.autograd.Function):
244
+ """
245
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
246
+ randomized approach that preserves expectations (intended to reduce roundoff).
247
+ """
248
+
249
+ @staticmethod
250
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
251
+ ctx.min_abs = min_abs
252
+ return x
253
+
254
+ @staticmethod
255
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
256
+ if ans_grad.dtype == torch.float16:
257
+ return (
258
+ random_cast_to_half(
259
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
260
+ ),
261
+ None,
262
+ )
263
+ else:
264
+ return ans_grad, None
265
+
266
+
267
+ class RandomGrad(torch.nn.Module):
268
+ """
269
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
270
+ accuracy of training when using amp (automatic mixed precision)
271
+ """
272
+
273
+ def __init__(self, min_abs: float = 5.0e-06):
274
+ super(RandomGrad, self).__init__()
275
+ self.min_abs = min_abs
276
+
277
+ def forward(self, x: Tensor):
278
+ if (
279
+ torch.jit.is_scripting()
280
+ or not self.training
281
+ or torch.jit.is_tracing()
282
+ ):
283
+ return x
284
+ else:
285
+ return RandomGradFunction.apply(x, self.min_abs)
286
+
287
+
288
+ class SoftmaxFunction(torch.autograd.Function):
289
+ """
290
+ Tries to handle half-precision derivatives in a randomized way that should
291
+ be more accurate for training than the default behavior.
292
+ """
293
+
294
+ @staticmethod
295
+ def forward(ctx, x: Tensor, dim: int):
296
+ ans = x.softmax(dim=dim)
297
+ # if x dtype is float16, x.softmax() returns a float32 because
298
+ # (presumably) that op does not support float16, and autocast
299
+ # is enabled.
300
+ if torch.is_autocast_enabled():
301
+ ans = ans.to(torch.float16)
302
+ ctx.save_for_backward(ans)
303
+ ctx.x_dtype = x.dtype
304
+ ctx.dim = dim
305
+ return ans
306
+
307
+ @staticmethod
308
+ def backward(ctx, ans_grad: Tensor):
309
+ (ans,) = ctx.saved_tensors
310
+ with torch.cuda.amp.autocast(enabled=False):
311
+ ans_grad = ans_grad.to(torch.float32)
312
+ ans = ans.to(torch.float32)
313
+ x_grad = ans_grad * ans
314
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
315
+ return x_grad, None
316
+
317
+
318
+ def softmax(x: Tensor, dim: int):
319
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
320
+ return x.softmax(dim)
321
+
322
+ return SoftmaxFunction.apply(x, dim)
323
+
324
+
325
+ class MaxEigLimiterFunction(torch.autograd.Function):
326
+ @staticmethod
327
+ def forward(
328
+ ctx,
329
+ x: Tensor,
330
+ coeffs: Tensor,
331
+ direction: Tensor,
332
+ channel_dim: int,
333
+ grad_scale: float,
334
+ ) -> Tensor:
335
+ ctx.channel_dim = channel_dim
336
+ ctx.grad_scale = grad_scale
337
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
338
+ return x
339
+
340
+ @staticmethod
341
+ def backward(ctx, x_grad, *args):
342
+ with torch.enable_grad():
343
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
344
+ x_orig.requires_grad = True
345
+ num_channels = x_orig.shape[ctx.channel_dim]
346
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
347
+ new_direction.requires_grad = False
348
+ x = x - x.mean(dim=0)
349
+ x_var = (x ** 2).mean()
350
+ x_residual = x - coeffs * new_direction
351
+ x_residual_var = (x_residual ** 2).mean()
352
+ # `variance_proportion` is the proportion of the variance accounted for
353
+ # by the top eigen-direction. This is to be minimized.
354
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
355
+ variance_proportion.backward()
356
+ x_orig_grad = x_orig.grad
357
+ x_extra_grad = (
358
+ x_orig.grad
359
+ * ctx.grad_scale
360
+ * x_grad.norm()
361
+ / (x_orig_grad.norm() + 1.0e-20)
362
+ )
363
+ return x_grad + x_extra_grad.detach(), None, None, None, None
364
+
365
+
366
+ class BasicNorm(torch.nn.Module):
367
+ """
368
+ This is intended to be a simpler, and hopefully cheaper, replacement for
369
+ LayerNorm. The observation this is based on, is that Transformer-type
370
+ networks, especially with pre-norm, sometimes seem to set one of the
371
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
372
+ the LayerNorm because the output magnitude is then not strongly dependent
373
+ on the other (useful) features. Presumably the weight and bias of the
374
+ LayerNorm are required to allow it to do this.
375
+
376
+ So the idea is to introduce this large constant value as an explicit
377
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
378
+ doesn't have to do this trick. We make the "eps" learnable.
379
+
380
+ Args:
381
+ num_channels: the number of channels, e.g. 512.
382
+ channel_dim: the axis/dimension corresponding to the channel,
383
+ interprted as an offset from the input's ndim if negative.
384
+ shis is NOT the num_channels; it should typically be one of
385
+ {-2, -1, 0, 1, 2, 3}.
386
+ eps: the initial "epsilon" that we add as ballast in:
387
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
388
+ Note: our epsilon is actually large, but we keep the name
389
+ to indicate the connection with conventional LayerNorm.
390
+ learn_eps: if true, we learn epsilon; if false, we keep it
391
+ at the initial value.
392
+ eps_min: float
393
+ eps_max: float
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ num_channels: int,
399
+ channel_dim: int = -1, # CAUTION: see documentation.
400
+ eps: float = 0.25,
401
+ learn_eps: bool = True,
402
+ eps_min: float = -3.0,
403
+ eps_max: float = 3.0,
404
+ ) -> None:
405
+ super(BasicNorm, self).__init__()
406
+ self.num_channels = num_channels
407
+ self.channel_dim = channel_dim
408
+ if learn_eps:
409
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
410
+ else:
411
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
412
+ self.eps_min = eps_min
413
+ self.eps_max = eps_max
414
+
415
+ def forward(self, x: Tensor) -> Tensor:
416
+ assert x.shape[self.channel_dim] == self.num_channels
417
+ eps = self.eps
418
+ if self.training and random.random() < 0.25:
419
+ # with probability 0.25, in training mode, clamp eps between the min
420
+ # and max; this will encourage it to learn parameters within the
421
+ # allowed range by making parameters that are outside the allowed
422
+ # range noisy.
423
+
424
+ # gradients to allow the parameter to get back into the allowed region if it happens to exit it.
425
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
426
+ scales = (
427
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
428
+ ) ** -0.5
429
+ return x * scales
430
+
431
+
432
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
433
+ """
434
+ Behaves like a constructor of a modified version of nn.Linear
435
+ that gives an easy way to set the default initial parameter scale.
436
+
437
+ Args:
438
+ Accepts the standard args and kwargs that nn.Linear accepts
439
+ e.g. in_features, out_features, bias=False.
440
+
441
+ initial_scale: you can override this if you want to increase
442
+ or decrease the initial magnitude of the module's output
443
+ (affects the initialization of weight_scale and bias_scale).
444
+ Another option, if you want to do something like this, is
445
+ to re-initialize the parameters.
446
+ """
447
+ ans = nn.Linear(*args, **kwargs)
448
+ with torch.no_grad():
449
+ ans.weight[:] *= initial_scale
450
+ if ans.bias is not None:
451
+ torch.nn.init.uniform_(
452
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
453
+ )
454
+ return ans
455
+
456
+
457
+ def ScaledConv1d(
458
+ *args,
459
+ initial_scale: float = 1.0,
460
+ kernel_size: int = 3,
461
+ padding: str = "same",
462
+ **kwargs,
463
+ ) -> nn.Conv1d:
464
+ """
465
+ Behaves like a constructor of a modified version of nn.Conv1d
466
+ that gives an easy way to set the default initial parameter scale.
467
+
468
+ Args:
469
+ Accepts the standard args and kwargs that nn.Linear accepts
470
+ e.g. in_features, out_features, bias=False.
471
+
472
+ initial_scale: you can override this if you want to increase
473
+ or decrease the initial magnitude of the module's output
474
+ (affects the initialization of weight_scale and bias_scale).
475
+ Another option, if you want to do something like this, is
476
+ to re-initialize the parameters.
477
+ """
478
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
479
+ with torch.no_grad():
480
+ ans.weight[:] *= initial_scale
481
+ if ans.bias is not None:
482
+ torch.nn.init.uniform_(
483
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
484
+ )
485
+ return ans
486
+
487
+
488
+ def TransposeScaledConv1d(
489
+ *args,
490
+ initial_scale: float = 1.0,
491
+ kernel_size: int = 3,
492
+ padding: str = "same",
493
+ **kwargs,
494
+ ) -> nn.Sequential:
495
+ """
496
+ Transpose -> ScaledConv1d
497
+ """
498
+ return nn.Sequential(
499
+ Transpose(),
500
+ ScaledConv1d(
501
+ *args,
502
+ initial_scale=initial_scale,
503
+ kernel_size=kernel_size,
504
+ padding=padding,
505
+ **kwargs,
506
+ ),
507
+ )
508
+
509
+
510
+ def ScaledConv1dTranspose(
511
+ *args,
512
+ initial_scale: float = 1.0,
513
+ kernel_size: int = 3,
514
+ padding: str = "same",
515
+ **kwargs,
516
+ ) -> nn.Sequential:
517
+ """
518
+ Transpose -> ScaledConv1d
519
+ """
520
+ return nn.Sequential(
521
+ ScaledConv1d(
522
+ *args,
523
+ initial_scale=initial_scale,
524
+ kernel_size=kernel_size,
525
+ padding=padding,
526
+ **kwargs,
527
+ ),
528
+ Transpose(),
529
+ )
530
+
531
+
532
+ def TransposeConv1d(
533
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
534
+ ) -> nn.Sequential:
535
+ """
536
+ Transpose -> Conv1d
537
+ """
538
+ return nn.Sequential(
539
+ Transpose(),
540
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
541
+ )
542
+
543
+
544
+ def Conv1dTranspose(
545
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
546
+ ) -> nn.Sequential:
547
+ """
548
+ ScaledConv1d -> Transpose
549
+ """
550
+ return nn.Sequential(
551
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
552
+ Transpose(),
553
+ )
554
+
555
+
556
+ class SRLinear(nn.Linear):
557
+ """https://arxiv.org/abs/2303.06296
558
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
559
+ """
560
+
561
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
562
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
563
+ self.register_buffer(
564
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
565
+ )
566
+ with torch.no_grad():
567
+ sigma = self.get_sigma()
568
+ self.register_buffer("spectral_norm", sigma)
569
+ self.sigma = nn.Parameter(torch.ones(1))
570
+
571
+ def get_sigma(self):
572
+ with torch.no_grad():
573
+ u = self.u
574
+ v = self.weight.mv(u)
575
+ v = nn.functional.normalize(v, dim=0)
576
+ u = self.weight.T.mv(v)
577
+ u = nn.functional.normalize(u, dim=0)
578
+ self.u.data.copy_(u)
579
+ return torch.einsum("c,cd,d->", v, self.weight, u)
580
+
581
+ def get_weight(self):
582
+ sigma = self.get_sigma()
583
+ if self.training:
584
+ self.spectral_norm.data.copy_(sigma)
585
+ weight = (self.sigma / sigma) * self.weight
586
+ return weight
587
+
588
+ def forward(self, x):
589
+ return nn.functional.linear(x, self.get_weight(), self.bias)
590
+
591
+
592
+ class SRConv1d(SRLinear):
593
+ def __init__(
594
+ self,
595
+ in_features,
596
+ out_features,
597
+ kernel_size,
598
+ stride: int = 1,
599
+ padding: str = "same",
600
+ bias: bool = True,
601
+ **kwargs,
602
+ ):
603
+ in_features = in_features * kernel_size
604
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
605
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
606
+ self.kernel_size = kernel_size
607
+ self.stride = stride
608
+ self.padding = padding
609
+
610
+ def forward(self, x):
611
+ in_features = self.in_features // self.kernel_size
612
+ weight = self.get_weight().view(
613
+ self.out_features, in_features, self.kernel_size
614
+ )
615
+ return nn.functional.conv1d(
616
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
617
+ )
618
+
619
+
620
+ def TransposeSRConv1d(
621
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
622
+ ) -> nn.Sequential:
623
+ """
624
+ Transpose -> SRConv1d
625
+ """
626
+ return nn.Sequential(
627
+ Transpose(),
628
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
629
+ )
630
+
631
+
632
+ def SRConv1dTranspose(
633
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
634
+ ) -> nn.Sequential:
635
+ """
636
+ SRConv1d -> Transpose
637
+ """
638
+ return nn.Sequential(
639
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
640
+ Transpose(),
641
+ )
642
+
643
+
644
+ class ActivationBalancer(torch.nn.Module):
645
+ """
646
+ Modifies the backpropped derivatives of a function to try to encourage, for
647
+ each channel, that it is positive at least a proportion `threshold` of the
648
+ time. It does this by multiplying negative derivative values by up to
649
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
650
+ interpolated from 1 at the threshold to those extremal values when none
651
+ of the inputs are positive.
652
+
653
+ Args:
654
+ num_channels: the number of channels
655
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
656
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
657
+ min_positive: the minimum, per channel, of the proportion of the time
658
+ that (x > 0), below which we start to modify the derivatives.
659
+ max_positive: the maximum, per channel, of the proportion of the time
660
+ that (x > 0), above which we start to modify the derivatives.
661
+ max_factor: the maximum factor by which we modify the derivatives for
662
+ either the sign constraint or the magnitude constraint;
663
+ e.g. with max_factor=0.02, the the derivatives would be multiplied by
664
+ values in the range [0.98..1.02].
665
+ sign_gain_factor: determines the 'gain' with which we increase the
666
+ change in gradient once the constraints on min_positive and max_positive
667
+ are violated.
668
+ scale_gain_factor: determines the 'gain' with which we increase the
669
+ change in gradient once the constraints on min_abs and max_abs
670
+ are violated.
671
+ min_abs: the minimum average-absolute-value difference from the mean
672
+ value per channel, which we allow, before we start to modify
673
+ the derivatives to prevent this.
674
+ max_abs: the maximum average-absolute-value difference from the mean
675
+ value per channel, which we allow, before we start to modify
676
+ the derivatives to prevent this.
677
+ min_prob: determines the minimum probability with which we modify the
678
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
679
+ on each forward(). This is done randomly to prevent all layers
680
+ from doing it at the same time. Early in training we may use
681
+ higher probabilities than this; it will decay to this value.
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ num_channels: int,
687
+ channel_dim: int,
688
+ min_positive: float = 0.05,
689
+ max_positive: float = 0.95,
690
+ max_factor: float = 0.04,
691
+ sign_gain_factor: float = 0.01,
692
+ scale_gain_factor: float = 0.02,
693
+ min_abs: float = 0.2,
694
+ max_abs: float = 100.0,
695
+ min_prob: float = 0.1,
696
+ ):
697
+ super(ActivationBalancer, self).__init__()
698
+ self.num_channels = num_channels
699
+ self.channel_dim = channel_dim
700
+ self.min_positive = min_positive
701
+ self.max_positive = max_positive
702
+ self.max_factor = max_factor
703
+ self.min_abs = min_abs
704
+ self.max_abs = max_abs
705
+ self.min_prob = min_prob
706
+ self.sign_gain_factor = sign_gain_factor
707
+ self.scale_gain_factor = scale_gain_factor
708
+
709
+ # count measures how many times the forward() function has been called.
710
+ # We occasionally sync this to a tensor called `count`, that exists to
711
+ # make sure it is synced to disk when we load and save the model.
712
+ self.cpu_count = 0
713
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
714
+
715
+ def forward(self, x: Tensor) -> Tensor:
716
+ if (
717
+ torch.jit.is_scripting()
718
+ or not x.requires_grad
719
+ or torch.jit.is_tracing()
720
+ ):
721
+ return _no_op(x)
722
+
723
+ count = self.cpu_count
724
+ self.cpu_count += 1
725
+
726
+ if random.random() < 0.01:
727
+ # Occasionally sync self.cpu_count with self.count.
728
+ # count affects the decay of 'prob'. don't do this on every iter,
729
+ # because syncing with the GPU is slow.
730
+ self.cpu_count = max(self.cpu_count, self.count.item())
731
+ self.count.fill_(self.cpu_count)
732
+
733
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
734
+ # a floor at min_prob (==0.1, by default)
735
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
736
+
737
+ if random.random() < prob:
738
+ sign_gain_factor = 0.5
739
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
740
+ sign_factor = _compute_sign_factor(
741
+ x,
742
+ self.channel_dim,
743
+ self.min_positive,
744
+ self.max_positive,
745
+ gain_factor=self.sign_gain_factor / prob,
746
+ max_factor=self.max_factor,
747
+ )
748
+ else:
749
+ sign_factor = None
750
+
751
+ scale_factor = _compute_scale_factor(
752
+ x.detach(),
753
+ self.channel_dim,
754
+ min_abs=self.min_abs,
755
+ max_abs=self.max_abs,
756
+ gain_factor=self.scale_gain_factor / prob,
757
+ max_factor=self.max_factor,
758
+ )
759
+ return ActivationBalancerFunction.apply(
760
+ x,
761
+ scale_factor,
762
+ sign_factor,
763
+ self.channel_dim,
764
+ )
765
+ else:
766
+ return _no_op(x)
767
+
768
+
769
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
770
+ """
771
+ Returns x unmodified, but in backprop will put a penalty for the excess of
772
+ the absolute values of elements of x over the limit "limit". E.g. if
773
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
774
+
775
+ Caution: the value of this penalty will be affected by grad scaling used
776
+ in automatic mixed precision training. For this reasons we use this,
777
+ it shouldn't really matter, or may even be helpful; we just use this
778
+ to disallow really implausible values of scores to be given to softmax.
779
+ """
780
+ x_sign = x.sign()
781
+ over_limit = (x.abs() - limit) > 0
782
+ # The following is a memory efficient way to penalize the absolute values of
783
+ # x that's over the limit. (The memory efficiency comes when you think
784
+ # about which items torch needs to cache for the autograd, and which ones it
785
+ # can throw away). The numerical value of aux_loss as computed here will
786
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
787
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
788
+ # limit).relu().
789
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
790
+ # note: we don't do sum() here on aux)_loss, but it's as if we had done
791
+ # sum() due to how with_loss() works.
792
+ x = with_loss(x, aux_loss)
793
+ # you must use x for something, or this will be ineffective.
794
+ return x
795
+
796
+
797
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
798
+ if x.ndim == 2:
799
+ return x.diag()
800
+ else:
801
+ (batch, dim, dim) = x.shape
802
+ x = x.reshape(batch, dim * dim)
803
+ x = x[:, :: dim + 1]
804
+ assert x.shape == (batch, dim)
805
+ return x
806
+
807
+
808
+ def _whitening_metric(x: Tensor, num_groups: int):
809
+ """
810
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
811
+ of the centered feature covariance are the same within each group's covariance matrix
812
+ and also between groups.
813
+ Args:
814
+ x: a Tensor of shape (*, num_channels)
815
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
816
+ Returns:
817
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
818
+ greater than 1.0 otherwise.
819
+ """
820
+ assert x.dtype != torch.float16
821
+ x = x.reshape(-1, x.shape[-1])
822
+ (num_frames, num_channels) = x.shape
823
+ assert num_channels % num_groups == 0
824
+ channels_per_group = num_channels // num_groups
825
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
826
+ # x now has shape (num_groups, num_frames, channels_per_group)
827
+ # subtract the mean so we use the centered, not uncentered, covariance.
828
+ # My experience has been that when we "mess with the gradients" like this,
829
+ # it's better not do anything that tries to move the mean around, because
830
+ # that can easily cause instability.
831
+ x = x - x.mean(dim=1, keepdim=True)
832
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
833
+ x_covar = torch.matmul(x.transpose(1, 2), x)
834
+ x_covar_mean_diag = _diag(x_covar).mean()
835
+ # the following expression is what we'd get if we took the matrix product
836
+ # of each covariance and measured the mean of its trace, i.e.
837
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
838
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
839
+ num_groups * channels_per_group
840
+ )
841
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
842
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
843
+ return metric
844
+
845
+
846
+ class WhiteningPenaltyFunction(torch.autograd.Function):
847
+ @staticmethod
848
+ def forward(
849
+ ctx,
850
+ x: Tensor,
851
+ num_groups: int,
852
+ whitening_limit: float,
853
+ grad_scale: float,
854
+ ) -> Tensor:
855
+ ctx.save_for_backward(x)
856
+ ctx.num_groups = num_groups
857
+ ctx.whitening_limit = whitening_limit
858
+ ctx.grad_scale = grad_scale
859
+ return x
860
+
861
+ @staticmethod
862
+ def backward(ctx, x_grad: Tensor):
863
+ (x_orig,) = ctx.saved_tensors
864
+ with torch.enable_grad():
865
+ with torch.cuda.amp.autocast(enabled=False):
866
+ x_detached = x_orig.to(torch.float32).detach()
867
+ x_detached.requires_grad = True
868
+
869
+ metric = _whitening_metric(x_detached, ctx.num_groups)
870
+
871
+ if random.random() < 0.005 or __name__ == "__main__":
872
+ logging.info(
873
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
874
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
875
+ )
876
+
877
+ (metric - ctx.whitening_limit).relu().backward()
878
+ penalty_grad = x_detached.grad
879
+ scale = ctx.grad_scale * (
880
+ x_grad.to(torch.float32).norm()
881
+ / (penalty_grad.norm() + 1.0e-20)
882
+ )
883
+ penalty_grad = penalty_grad * scale
884
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
885
+
886
+
887
+ class Whiten(nn.Module):
888
+ def __init__(
889
+ self,
890
+ num_groups: int,
891
+ whitening_limit: float,
892
+ prob: Union[float, Tuple[float, float]],
893
+ grad_scale: float,
894
+ ):
895
+ """
896
+ Args:
897
+ num_groups: the number of groups to divide the channel dim into before
898
+ whitening. We will attempt to make the feature covariance
899
+ within each group, after mean subtraction, as "white" as possible,
900
+ while having the same trace across all groups.
901
+ whitening_limit: a value greater than 1.0, that dictates how much
902
+ freedom we have to violate the constraints. 1.0 would mean perfectly
903
+ white, with exactly the same trace across groups; larger values
904
+ give more freedom. E.g. 2.0.
905
+ prob: the probability with which we apply the gradient modification
906
+ (also affects the grad scale). May be supplied as a float,
907
+ or as a pair (min_prob, max_prob)
908
+
909
+ grad_scale: determines the scale on the gradient term from this object,
910
+ relative to the rest of the gradient on the attention weights.
911
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
912
+ """
913
+ super(Whiten, self).__init__()
914
+ assert num_groups >= 1
915
+ assert whitening_limit >= 1
916
+ assert grad_scale >= 0
917
+ self.num_groups = num_groups
918
+ self.whitening_limit = whitening_limit
919
+ if isinstance(prob, float):
920
+ assert 0 < prob <= 1
921
+ self.prob = prob
922
+ else:
923
+ (self.min_prob, self.max_prob) = prob
924
+ assert 0 < self.min_prob < self.max_prob <= 1
925
+ self.prob = self.max_prob
926
+
927
+ self.grad_scale = grad_scale
928
+
929
+ def forward(self, x: Tensor) -> Tensor:
930
+ """
931
+ In the forward pass, this function just returns the input unmodified.
932
+ In the backward pass, it will modify the gradients to ensure that the
933
+ distribution in each group has close to (lambda times I) as the covariance
934
+ after mean subtraction, with the same lambda across groups.
935
+ For whitening_limit > 1, there will be more freedom to violate this
936
+ constraint.
937
+
938
+ Args:
939
+ x: the input of shape (*, num_channels)
940
+
941
+ Returns:
942
+ x, unmodified. You should make sure
943
+ you use the returned value, or the graph will be freed
944
+ and nothing will happen in backprop.
945
+ """
946
+ if (
947
+ not x.requires_grad
948
+ or random.random() > self.prob
949
+ or self.grad_scale == 0
950
+ ):
951
+ return _no_op(x)
952
+ else:
953
+ if hasattr(self, "min_prob") and random.random() < 0.25:
954
+ # occasionally switch between min_prob and max_prob, based on whether
955
+ # we are above or below the threshold.
956
+ if (
957
+ _whitening_metric(x.to(torch.float32), self.num_groups)
958
+ > self.whitening_limit
959
+ ):
960
+ # there would be a change to the grad.
961
+ self.prob = self.max_prob
962
+ else:
963
+ self.prob = self.min_prob
964
+
965
+ return WhiteningPenaltyFunction.apply(
966
+ x, self.num_groups, self.whitening_limit, self.grad_scale
967
+ )
968
+
969
+
970
+ class WithLoss(torch.autograd.Function):
971
+ @staticmethod
972
+ def forward(ctx, x: Tensor, y: Tensor):
973
+ ctx.y_shape = y.shape
974
+ return x
975
+
976
+ @staticmethod
977
+ def backward(ctx, ans_grad: Tensor):
978
+ return ans_grad, torch.ones(
979
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
980
+ )
981
+
982
+
983
+ def with_loss(x, y):
984
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
985
+ return x
986
+ # returns x but adds y.sum() to the loss function.
987
+ return WithLoss.apply(x, y)
988
+
989
+
990
+ def _no_op(x: Tensor) -> Tensor:
991
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
992
+ return x
993
+ else:
994
+ # a no-op function that will have a node in the autograd graph,
995
+ # to avoid certain bugs relating to backward hooks
996
+ return x.chunk(1, dim=-1)[0]
997
+
998
+
999
+ class Identity(torch.nn.Module):
1000
+ def __init__(self):
1001
+ super(Identity, self).__init__()
1002
+
1003
+ def forward(self, x):
1004
+ return _no_op(x)
1005
+
1006
+
1007
+ class MaxEig(torch.nn.Module):
1008
+ """
1009
+ Modifies the backpropped derivatives of a function to try to discourage
1010
+ that any given direction in activation space accounts for more than
1011
+ a specified proportion of the covariance (e.g. 0.2).
1012
+
1013
+
1014
+ Args:
1015
+ num_channels: the number of channels
1016
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1017
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1018
+ max_var_per_eig: the maximum proportion of the variance of the
1019
+ features/channels, after mean subtraction, that can come from
1020
+ any given eigenvalue.
1021
+ min_prob: the minimum probability with which we apply this during any invocation
1022
+ of forward(), assuming last time we applied the constraint it was
1023
+ not active; supplied for speed.
1024
+ scale: determines the scale with which we modify the gradients, relative
1025
+ to the existing / unmodified gradients
1026
+ """
1027
+
1028
+ def __init__(
1029
+ self,
1030
+ num_channels: int,
1031
+ channel_dim: int,
1032
+ max_var_per_eig: float = 0.2,
1033
+ min_prob: float = 0.01,
1034
+ scale: float = 0.01,
1035
+ ):
1036
+ super(MaxEig, self).__init__()
1037
+ self.num_channels = num_channels
1038
+ self.channel_dim = channel_dim
1039
+ self.scale = scale
1040
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1041
+ self.max_var_per_eig = max_var_per_eig
1042
+
1043
+ # we figure out the dominant direction using the power method: starting with
1044
+ # a random vector, keep multiplying by the covariance and renormalizing.
1045
+ with torch.no_grad():
1046
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1047
+ # random parameters unchanged for comparison
1048
+ direction = torch.arange(num_channels).to(torch.float)
1049
+ direction = direction / direction.norm()
1050
+ self.register_buffer("max_eig_direction", direction)
1051
+
1052
+ self.min_prob = min_prob
1053
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1054
+ # We'll regress this towards prob, each tiem we try to apply it and it is not
1055
+ # active.
1056
+ self.cur_prob = 1.0
1057
+
1058
+ def forward(self, x: Tensor) -> Tensor:
1059
+ if (
1060
+ torch.jit.is_scripting()
1061
+ or self.max_var_per_eig <= 0
1062
+ or random.random() > self.cur_prob
1063
+ or torch.jit.is_tracing()
1064
+ ):
1065
+ return _no_op(x)
1066
+
1067
+ with torch.cuda.amp.autocast(enabled=False):
1068
+ eps = 1.0e-20
1069
+ orig_x = x
1070
+ x = x.to(torch.float32)
1071
+ with torch.no_grad():
1072
+ x = x.transpose(self.channel_dim, -1).reshape(
1073
+ -1, self.num_channels
1074
+ )
1075
+ x = x - x.mean(dim=0)
1076
+ new_direction, coeffs = self._find_direction_coeffs(
1077
+ x, self.max_eig_direction
1078
+ )
1079
+ x_var = (x ** 2).mean()
1080
+ x_residual = x - coeffs * new_direction
1081
+ x_residual_var = (x_residual ** 2).mean()
1082
+
1083
+ # `variance_proportion` is the proportion of the variance accounted for
1084
+ # by the top eigen-direction.
1085
+ variance_proportion = (x_var - x_residual_var) / (
1086
+ x_var + 1.0e-20
1087
+ )
1088
+
1089
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1090
+ self._set_direction(
1091
+ 0.1 * self.max_eig_direction + new_direction
1092
+ )
1093
+
1094
+ if random.random() < 0.01 or __name__ == "__main__":
1095
+ logging.info(
1096
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1097
+ )
1098
+
1099
+ if variance_proportion >= self.max_var_per_eig:
1100
+ # The constraint is active. Note, we should quite rarely
1101
+ # reach here, only near the beginning of training if we are
1102
+ # starting to diverge, should this constraint be active.
1103
+ cur_prob = self.cur_prob
1104
+ self.cur_prob = (
1105
+ 1.0 # next time, do the update with probability 1.0.
1106
+ )
1107
+ return MaxEigLimiterFunction.apply(
1108
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1109
+ )
1110
+ else:
1111
+ # let self.cur_prob exponentially approach self.min_prob, as
1112
+ # long as the constraint is inactive.
1113
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1114
+ return orig_x
1115
+
1116
+ def _set_direction(self, direction: Tensor):
1117
+ """
1118
+ Sets self.max_eig_direction to a normalized version of `direction`
1119
+ """
1120
+ direction = direction.detach()
1121
+ direction = direction / direction.norm()
1122
+ direction_sum = direction.sum().item()
1123
+ if direction_sum - direction_sum == 0: # no inf/nan
1124
+ self.max_eig_direction[:] = direction
1125
+ else:
1126
+ logging.info(
1127
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1128
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1129
+ )
1130
+
1131
+ def _find_direction_coeffs(
1132
+ self, x: Tensor, prev_direction: Tensor
1133
+ ) -> Tuple[Tensor, Tensor, Tensor]:
1134
+ """
1135
+ Figure out (an approximation to) the proportion of the variance of a set of
1136
+ feature vectors that can be attributed to the top eigen-direction.
1137
+ Args:
1138
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1139
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1140
+ of the top eigen-direction, or a random direction if this is the first
1141
+ iteration. Does not have to be normalized, but should be nonzero.
1142
+
1143
+ Returns: (cur_direction, coeffs), where:
1144
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1145
+ estimate of the top eigen-direction.
1146
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1147
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1148
+ """
1149
+ (num_frames, num_channels) = x.shape
1150
+ assert num_channels > 1 and num_frames > 1
1151
+ assert prev_direction.shape == (num_channels,)
1152
+ # `coeffs` are the coefficients of `prev_direction` in x.
1153
+ # actually represent the coeffs up to a constant positive factor.
1154
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1155
+ cur_direction = (x * coeffs).sum(dim=0) / (
1156
+ (coeffs ** 2).sum() + 1.0e-20
1157
+ )
1158
+ return cur_direction, coeffs
1159
+
1160
+
1161
+ class DoubleSwishFunction(torch.autograd.Function):
1162
+ """
1163
+ double_swish(x) = x * torch.sigmoid(x-1)
1164
+ This is a definition, originally motivated by its close numerical
1165
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1166
+
1167
+ Memory-efficient derivative computation:
1168
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1169
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1170
+ Now, s'(x) = s(x) * (1-s(x)).
1171
+ double_swish'(x) = x * s'(x) + s(x).
1172
+ = x * s(x) * (1-s(x)) + s(x).
1173
+ = double_swish(x) * (1-s(x)) + s(x)
1174
+ ... so we just need to remember s(x) but not x itself.
1175
+ """
1176
+
1177
+ @staticmethod
1178
+ def forward(ctx, x: Tensor) -> Tensor:
1179
+ requires_grad = x.requires_grad
1180
+ x_dtype = x.dtype
1181
+ if x.dtype == torch.float16:
1182
+ x = x.to(torch.float32)
1183
+
1184
+ s = torch.sigmoid(x - 1.0)
1185
+ y = x * s
1186
+
1187
+ if requires_grad:
1188
+ deriv = y * (1 - s) + s
1189
+ # notes on derivative of x * sigmoid(x - 1):
1190
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1191
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
1192
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1193
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1194
+ # floors), should be expectation-preserving.
1195
+ floor = -0.043637
1196
+ ceil = 1.2
1197
+ d_scaled = (deriv - floor) * (
1198
+ 255.0 / (ceil - floor)
1199
+ ) + torch.rand_like(deriv)
1200
+ if __name__ == "__main__":
1201
+ # for self-testing only.
1202
+ assert d_scaled.min() >= 0.0
1203
+ assert d_scaled.max() < 256.0
1204
+ d_int = d_scaled.to(torch.uint8)
1205
+ ctx.save_for_backward(d_int)
1206
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1207
+ y = y.to(torch.float16)
1208
+ return y
1209
+
1210
+ @staticmethod
1211
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1212
+ (d,) = ctx.saved_tensors
1213
+ # the same constants as used in forward pass.
1214
+ floor = -0.043637
1215
+ ceil = 1.2
1216
+ d = d * ((ceil - floor) / 255.0) + floor
1217
+ return y_grad * d
1218
+
1219
+
1220
+ class DoubleSwish(torch.nn.Module):
1221
+ def forward(self, x: Tensor) -> Tensor:
1222
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1223
+ that we approximate closely with x * sigmoid(x-1).
1224
+ """
1225
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1226
+ return x * torch.sigmoid(x - 1.0)
1227
+ return DoubleSwishFunction.apply(x)
1228
+
1229
+
1230
+ def BalancedDoubleSwish(
1231
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1232
+ ) -> nn.Sequential:
1233
+ """
1234
+ ActivationBalancer -> DoubleSwish
1235
+ """
1236
+ balancer = ActivationBalancer(
1237
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1238
+ )
1239
+ return nn.Sequential(
1240
+ balancer,
1241
+ DoubleSwish(),
1242
+ )
1243
+
1244
+
1245
+ def _test_max_eig():
1246
+ for proportion in [0.1, 0.5, 10.0]:
1247
+ logging.info(f"proportion = {proportion}")
1248
+ x = torch.randn(100, 128)
1249
+ direction = torch.randn(128)
1250
+ coeffs = torch.randn(100, 1)
1251
+ x += proportion * direction * coeffs
1252
+
1253
+ x.requires_grad = True
1254
+
1255
+ num_channels = 128
1256
+ m = MaxEig(
1257
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1258
+ ) # grad_scale
1259
+
1260
+ for _ in range(4):
1261
+ y = m(x)
1262
+
1263
+ y_grad = torch.randn_like(x)
1264
+ y.backward(gradient=y_grad)
1265
+
1266
+ if proportion < 0.2:
1267
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1268
+ elif proportion > 1.0:
1269
+ assert not torch.allclose(x.grad, y_grad)
1270
+
1271
+
1272
+ def _test_whiten():
1273
+ for proportion in [0.1, 0.5, 10.0]:
1274
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1275
+ x = torch.randn(100, 128)
1276
+ direction = torch.randn(128)
1277
+ coeffs = torch.randn(100, 1)
1278
+ x += proportion * direction * coeffs
1279
+
1280
+ x.requires_grad = True
1281
+
1282
+ num_channels = 128
1283
+ m = Whiten(
1284
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1285
+ ) # grad_scale
1286
+
1287
+ for _ in range(4):
1288
+ y = m(x)
1289
+
1290
+ y_grad = torch.randn_like(x)
1291
+ y.backward(gradient=y_grad)
1292
+
1293
+ if proportion < 0.2:
1294
+ assert torch.allclose(x.grad, y_grad)
1295
+ elif proportion > 1.0:
1296
+ assert not torch.allclose(x.grad, y_grad)
1297
+
1298
+
1299
+ def _test_activation_balancer_sign():
1300
+ probs = torch.arange(0, 1, 0.01)
1301
+ N = 1000
1302
+ x = 1.0 * (
1303
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1304
+ )
1305
+ x = x.detach()
1306
+ x.requires_grad = True
1307
+ m = ActivationBalancer(
1308
+ probs.numel(),
1309
+ channel_dim=0,
1310
+ min_positive=0.05,
1311
+ max_positive=0.95,
1312
+ max_factor=0.2,
1313
+ min_abs=0.0,
1314
+ )
1315
+
1316
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1317
+
1318
+ y = m(x)
1319
+ y.backward(gradient=y_grad)
1320
+ print("_test_activation_balancer_sign: x = ", x)
1321
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1322
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1323
+
1324
+
1325
+ def _test_activation_balancer_magnitude():
1326
+ magnitudes = torch.arange(0, 1, 0.01)
1327
+ N = 1000
1328
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1329
+ -1
1330
+ )
1331
+ x = x.detach()
1332
+ x.requires_grad = True
1333
+ m = ActivationBalancer(
1334
+ magnitudes.numel(),
1335
+ channel_dim=0,
1336
+ min_positive=0.0,
1337
+ max_positive=1.0,
1338
+ max_factor=0.2,
1339
+ min_abs=0.2,
1340
+ max_abs=0.8,
1341
+ min_prob=1.0,
1342
+ )
1343
+
1344
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1345
+
1346
+ y = m(x)
1347
+ y.backward(gradient=y_grad)
1348
+ print("_test_activation_balancer_magnitude: x = ", x)
1349
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1350
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1351
+
1352
+
1353
+ def _test_basic_norm():
1354
+ num_channels = 128
1355
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1356
+
1357
+ x = torch.randn(500, num_channels)
1358
+
1359
+ y = m(x)
1360
+
1361
+ assert y.shape == x.shape
1362
+ x_rms = (x ** 2).mean().sqrt()
1363
+ y_rms = (y ** 2).mean().sqrt()
1364
+ print("x rms = ", x_rms)
1365
+ print("y rms = ", y_rms)
1366
+ assert y_rms < x_rms
1367
+ assert y_rms > 0.5 * x_rms
1368
+
1369
+
1370
+ def _test_double_swish_deriv():
1371
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1372
+ x.requires_grad = True
1373
+ m = DoubleSwish()
1374
+
1375
+ tol = (1.2 - (-0.043637)) / 255.0
1376
+ torch.autograd.gradcheck(m, x, atol=tol)
1377
+
1378
+ # for self-test.
1379
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1380
+ x.requires_grad = True
1381
+ y = m(x)
1382
+
1383
+
1384
+ def _test_softmax():
1385
+ a = torch.randn(2, 10, dtype=torch.float64)
1386
+ b = a.clone()
1387
+ a.requires_grad = True
1388
+ b.requires_grad = True
1389
+ a.softmax(dim=1)[:, 0].sum().backward()
1390
+ print("a grad = ", a.grad)
1391
+ softmax(b, dim=1)[:, 0].sum().backward()
1392
+ print("b grad = ", b.grad)
1393
+ assert torch.allclose(a.grad, b.grad)
1394
+
1395
+
1396
+ if __name__ == "__main__":
1397
+ logging.getLogger().setLevel(logging.INFO)
1398
+ torch.set_num_threads(1)
1399
+ torch.set_num_interop_threads(1)
1400
+ _test_softmax()
1401
+ _test_whiten()
1402
+ _test_max_eig()
1403
+ _test_activation_balancer_sign()
1404
+ _test_activation_balancer_magnitude()
1405
+ _test_basic_norm()
1406
+ _test_double_swish_deriv()
lib/voicecraft/models/modules/transformer.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any, Callable, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+ from torch.nn import functional as F
10
+
11
+ from .activation import MultiheadAttention
12
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
13
+ from .scaling import BasicNorm as _BasicNorm
14
+
15
+ _shape_t = Union[int, List[int], torch.Size]
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
20
+ normalized_shape: Tuple[int, ...]
21
+ eps: float
22
+ elementwise_affine: bool
23
+
24
+ def __init__(
25
+ self,
26
+ normalized_shape: _shape_t,
27
+ eps: float = 1e-5,
28
+ elementwise_affine: bool = True,
29
+ device=None,
30
+ dtype=None,
31
+ ) -> None:
32
+ factory_kwargs = {"device": device, "dtype": dtype}
33
+ super(LayerNorm, self).__init__()
34
+ if isinstance(normalized_shape, numbers.Integral):
35
+ # mypy error: incompatible types in assignment
36
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
37
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
38
+ self.eps = eps
39
+ self.elementwise_affine = elementwise_affine
40
+ if self.elementwise_affine:
41
+ self.weight = nn.Parameter(
42
+ torch.empty(self.normalized_shape, **factory_kwargs)
43
+ )
44
+ self.bias = nn.Parameter(
45
+ torch.empty(self.normalized_shape, **factory_kwargs)
46
+ )
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(
74
+ input, self.normalized_shape, self.weight, self.bias, self.eps
75
+ )
76
+
77
+ def extra_repr(self) -> str:
78
+ return (
79
+ "{normalized_shape}, eps={eps}, "
80
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
81
+ )
82
+
83
+
84
+ class AdaptiveLayerNorm(nn.Module):
85
+ r"""Adaptive Layer Normalization"""
86
+
87
+ def __init__(self, d_model, norm) -> None:
88
+ super(AdaptiveLayerNorm, self).__init__()
89
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
90
+ self.norm = norm
91
+ self.d_model = d_model
92
+ self.eps = self.norm.eps
93
+
94
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
95
+ if isinstance(input, tuple):
96
+ input, embedding = input
97
+ weight, bias = torch.split(
98
+ self.project_layer(embedding),
99
+ split_size_or_sections=self.d_model,
100
+ dim=-1,
101
+ )
102
+ return (weight * self.norm(input) + bias, embedding)
103
+
104
+ weight, bias = torch.split(
105
+ self.project_layer(embedding),
106
+ split_size_or_sections=self.d_model,
107
+ dim=-1,
108
+ )
109
+ return weight * self.norm(input) + bias
110
+
111
+
112
+ class BasicNorm(_BasicNorm):
113
+ def __init__(
114
+ self,
115
+ d_model: int,
116
+ eps: float = 1e-5,
117
+ device=None,
118
+ dtype=None,
119
+ ):
120
+ super(BasicNorm, self).__init__(d_model, eps=eps)
121
+
122
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
123
+ if isinstance(input, tuple):
124
+ input, embedding = input
125
+ return (
126
+ super(BasicNorm, self).forward(input),
127
+ embedding,
128
+ )
129
+
130
+ assert embedding is None
131
+ return super(BasicNorm, self).forward(input)
132
+
133
+
134
+ class BalancedBasicNorm(nn.Module):
135
+ def __init__(
136
+ self,
137
+ d_model: int,
138
+ eps: float = 1e-5,
139
+ device=None,
140
+ dtype=None,
141
+ ):
142
+ super(BalancedBasicNorm, self).__init__()
143
+ self.balancer = ActivationBalancer(
144
+ d_model,
145
+ channel_dim=-1,
146
+ min_positive=0.45,
147
+ max_positive=0.55,
148
+ max_abs=6.0,
149
+ )
150
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
151
+
152
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
153
+ if isinstance(input, tuple):
154
+ input, embedding = input
155
+ return self.norm((self.balancer(input), embedding))
156
+
157
+ assert embedding is None
158
+ return self.norm(self.balancer(input))
159
+
160
+
161
+ class IdentityNorm(nn.Module):
162
+ def __init__(
163
+ self,
164
+ d_model: int,
165
+ eps: float = 1e-5,
166
+ device=None,
167
+ dtype=None,
168
+ ) -> None:
169
+ super(IdentityNorm, self).__init__()
170
+
171
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
172
+ if isinstance(input, tuple):
173
+ return input
174
+
175
+ assert embedding is None
176
+ return input
177
+
178
+
179
+ class TransformerEncoderLayer(nn.Module):
180
+ __constants__ = ["batch_first", "norm_first"]
181
+
182
+ def __init__(
183
+ self,
184
+ d_model: int,
185
+ nhead: int,
186
+ dim_feedforward: int = 2048,
187
+ dropout: float = 0.1,
188
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
189
+ batch_first: bool = False,
190
+ norm_first: bool = False,
191
+ device=None,
192
+ dtype=None,
193
+ linear1_self_attention_cls: nn.Module = nn.Linear,
194
+ linear2_self_attention_cls: nn.Module = nn.Linear,
195
+ linear1_feedforward_cls: nn.Module = nn.Linear,
196
+ linear2_feedforward_cls: nn.Module = nn.Linear,
197
+ layer_norm_cls: nn.Module = LayerNorm,
198
+ layer_norm_eps: float = 1e-5,
199
+ adaptive_layer_norm=False,
200
+ ) -> None:
201
+ factory_kwargs = {"device": device, "dtype": dtype}
202
+ super(TransformerEncoderLayer, self).__init__()
203
+ self.self_attn = MultiheadAttention(
204
+ d_model,
205
+ nhead,
206
+ dropout=dropout,
207
+ batch_first=batch_first,
208
+ linear1_cls=linear1_self_attention_cls,
209
+ linear2_cls=linear2_self_attention_cls,
210
+ **factory_kwargs,
211
+ )
212
+
213
+ # Implementation of Feedforward model
214
+ self.linear1 = linear1_feedforward_cls(
215
+ d_model, dim_feedforward, **factory_kwargs
216
+ )
217
+ self.dropout = nn.Dropout(dropout)
218
+ self.linear2 = linear2_feedforward_cls(
219
+ dim_feedforward, d_model, **factory_kwargs
220
+ )
221
+
222
+ self.norm_first = norm_first
223
+ self.dropout1 = nn.Dropout(dropout)
224
+ self.dropout2 = nn.Dropout(dropout)
225
+
226
+ # Legacy string support for activation function.
227
+ if isinstance(activation, str):
228
+ activation = _get_activation_fn(activation)
229
+ elif isinstance(activation, partial):
230
+ activation = activation(d_model)
231
+ elif activation == BalancedDoubleSwish:
232
+ activation = BalancedDoubleSwish(d_model)
233
+
234
+ # # We can't test self.activation in forward() in TorchScript,
235
+ # # so stash some information about it instead.
236
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
237
+ # self.activation_relu_or_gelu = 1
238
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
239
+ # self.activation_relu_or_gelu = 2
240
+ # else:
241
+ # self.activation_relu_or_gelu = 0
242
+ self.activation = activation
243
+
244
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
245
+ if layer_norm_cls == IdentityNorm:
246
+ norm2 = BalancedBasicNorm(
247
+ d_model, eps=layer_norm_eps, **factory_kwargs
248
+ )
249
+ else:
250
+ norm2 = layer_norm_cls(
251
+ d_model, eps=layer_norm_eps, **factory_kwargs
252
+ )
253
+
254
+ if adaptive_layer_norm:
255
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
256
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
257
+ else:
258
+ self.norm1 = norm1
259
+ self.norm2 = norm2
260
+
261
+ def __setstate__(self, state):
262
+ super(TransformerEncoderLayer, self).__setstate__(state)
263
+ if not hasattr(self, "activation"):
264
+ self.activation = F.relu
265
+
266
+ def forward(
267
+ self,
268
+ src: Tensor,
269
+ src_mask: Optional[Tensor] = None,
270
+ src_key_padding_mask: Optional[Tensor] = None,
271
+ need_weights: Optional[bool] = False,
272
+ past: Optional[Tensor] = None,
273
+ ) -> Tensor:
274
+ r"""Pass the input through the encoder layer.
275
+
276
+ Args:
277
+ src: the sequence to the encoder layer (required).
278
+ src_mask: the mask for the src sequence (optional).
279
+ src_key_padding_mask: the mask for the src keys per batch (optional).
280
+
281
+ Shape:
282
+ see the docs in Transformer class.
283
+ """
284
+ x, stage_embedding = src, None
285
+ is_src_tuple = False
286
+ if isinstance(src, tuple):
287
+ x, stage_embedding = src
288
+ is_src_tuple = True
289
+
290
+ if src_key_padding_mask is not None:
291
+ _skpm_dtype = src_key_padding_mask.dtype
292
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
293
+ src_key_padding_mask
294
+ ):
295
+ raise AssertionError(
296
+ "only bool and floating types of key_padding_mask are supported"
297
+ )
298
+ if need_weights:
299
+ if self.norm_first:
300
+ out, attn = self._sa_block_attn(
301
+ self.norm1(x, stage_embedding),
302
+ src_mask,
303
+ src_key_padding_mask,
304
+ past
305
+ )
306
+ out, present = out # present is the kvcache of the present timestep
307
+ x = x + out
308
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
309
+ else:
310
+ out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
311
+ out, present = out # present is the kvcache of the present timestep
312
+ x = self.norm1(
313
+ x + out,
314
+ stage_embedding,
315
+ )
316
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
317
+ assert not is_src_tuple
318
+ # return (x, stage_embedding)
319
+ return (x, attn)
320
+ else:
321
+ if self.norm_first:
322
+ out = self._sa_block(
323
+ self.norm1(x, stage_embedding),
324
+ src_mask,
325
+ src_key_padding_mask, past
326
+ )
327
+ out, present = out # present is the kvcache of the present timestep
328
+ x = x + out
329
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
330
+ else:
331
+ out = self._sa_block(x, src_mask, src_key_padding_mask)
332
+ out, present = out # present is the kvcache of the present timestep
333
+ x = self.norm1(
334
+ x + out,
335
+ stage_embedding, past
336
+ )
337
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
338
+
339
+ if is_src_tuple:
340
+ x = (x, stage_embedding)
341
+ if present != None:
342
+ x = [x, present]
343
+ return x
344
+
345
+ # self-attention block
346
+ def _sa_block(
347
+ self,
348
+ x: Tensor,
349
+ attn_mask: Optional[Tensor],
350
+ key_padding_mask: Optional[Tensor],
351
+ past: Optional[Tensor] = None,
352
+ ) -> Tensor:
353
+ x = self.self_attn(
354
+ x,
355
+ x,
356
+ x,
357
+ attn_mask=attn_mask,
358
+ key_padding_mask=key_padding_mask,
359
+ need_weights=False,
360
+ past=past
361
+ )
362
+ x, present = x
363
+ return self.dropout1(x), present
364
+
365
+ # self-attention block, also return attention weights
366
+ def _sa_block_attn(
367
+ self,
368
+ x: Tensor,
369
+ attn_mask: Optional[Tensor],
370
+ key_padding_mask: Optional[Tensor],
371
+ past: Optional[Tensor] = None,
372
+ ) -> Tensor:
373
+ x, attn = self.self_attn(
374
+ x,
375
+ x,
376
+ x,
377
+ attn_mask=attn_mask,
378
+ key_padding_mask=key_padding_mask,
379
+ need_weights=True,
380
+ past=past
381
+ )
382
+ x, present = x
383
+ return (self.dropout1(x), present), attn
384
+
385
+ # feed forward block
386
+ def _ff_block(self, x: Tensor) -> Tensor:
387
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
388
+ return self.dropout2(x)
389
+
390
+
391
+ class TransformerEncoder(nn.Module):
392
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
393
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
394
+
395
+ Args:
396
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
397
+ num_layers: the number of sub-encoder-layers in the encoder (required).
398
+ norm: the layer normalization component (optional).
399
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
400
+ (and convert back on output). This will improve the overall performance of
401
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
402
+
403
+ Examples::
404
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
405
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
406
+ >>> src = torch.rand(10, 32, 512)
407
+ >>> out = transformer_encoder(src)
408
+ """
409
+ __constants__ = ["norm"]
410
+
411
+ def __init__(self, encoder_layer, num_layers, norm=None):
412
+ super(TransformerEncoder, self).__init__()
413
+ self.layers = _get_clones(encoder_layer, num_layers)
414
+ self.num_layers = num_layers
415
+ self.norm = norm
416
+
417
+ def forward(
418
+ self,
419
+ src: Tensor,
420
+ mask: Optional[Tensor] = None,
421
+ src_key_padding_mask: Optional[Tensor] = None,
422
+ return_layer_states: bool = False,
423
+ need_weights:Optional[bool] = False,
424
+ past: Optional[Tensor] = None,
425
+ ) -> Tensor:
426
+ r"""Pass the input through the encoder layers in turn.
427
+
428
+ Args:
429
+ src: the sequence to the encoder (required).
430
+ mask: the mask for the src sequence (optional).
431
+ src_key_padding_mask: the mask for the src keys per batch (optional).
432
+ return_layer_states: return layers' state (optional).
433
+
434
+ Shape:
435
+ see the docs in Transformer class.
436
+ """
437
+ if return_layer_states:
438
+ assert not need_weights
439
+ layer_states = [] # layers' output
440
+ output = src
441
+ for mod in self.layers:
442
+ output = mod(
443
+ output,
444
+ src_mask=mask,
445
+ src_key_padding_mask=src_key_padding_mask,
446
+ past=past
447
+ )
448
+ layer_states.append(output[0])
449
+
450
+ if self.norm is not None:
451
+ output = self.norm(output)
452
+
453
+ return layer_states, output
454
+ if need_weights:
455
+ assert not return_layer_states
456
+ layer_attn = [] # layers' output
457
+ output = src
458
+ for mod in self.layers:
459
+ output = mod(
460
+ output,
461
+ src_mask=mask,
462
+ src_key_padding_mask=src_key_padding_mask,
463
+ need_weights=True,
464
+ past=past
465
+ )
466
+ layer_attn.append(output[1])
467
+
468
+ if self.norm is not None:
469
+ output = self.norm(output)
470
+
471
+ return layer_attn, output
472
+
473
+ output = src
474
+ all_present = []
475
+ for n_layer, mod in enumerate(self.layers):
476
+ output = mod(
477
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
478
+ )
479
+ if isinstance(output, list):
480
+ output, present = output
481
+ all_present.append(present)
482
+
483
+ if self.norm is not None:
484
+ output = self.norm(output)
485
+ if all_present != []:
486
+ all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
487
+ output = [output, all_present]
488
+ return output
489
+
490
+
491
+ class TransformerDecoderLayer(nn.Module):
492
+ __constants__ = ["batch_first", "norm_first"]
493
+
494
+ def __init__(
495
+ self,
496
+ d_model: int,
497
+ nhead: int,
498
+ dim_feedforward: int = 2048,
499
+ dropout: float = 0.1,
500
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
501
+ linear1_self_attention_cls: nn.Module = nn.Linear,
502
+ linear2_self_attention_cls: nn.Module = nn.Linear,
503
+ linear1_feedforward_cls: nn.Module = nn.Linear,
504
+ linear2_feedforward_cls: nn.Module = nn.Linear,
505
+ batch_first: bool = False,
506
+ norm_first: bool = False,
507
+ device=None,
508
+ dtype=None,
509
+ layer_norm_cls: nn.Module = LayerNorm,
510
+ layer_norm_eps: float = 1e-5,
511
+ adaptive_layer_norm=False,
512
+ ) -> None:
513
+ factory_kwargs = {"device": device, "dtype": dtype}
514
+ super(TransformerDecoderLayer, self).__init__()
515
+ self.self_attn = MultiheadAttention(
516
+ d_model,
517
+ nhead,
518
+ dropout=dropout,
519
+ batch_first=batch_first,
520
+ linear1_cls=linear1_self_attention_cls,
521
+ linear2_cls=linear2_self_attention_cls,
522
+ **factory_kwargs,
523
+ )
524
+ self.multihead_attn = MultiheadAttention(
525
+ d_model,
526
+ nhead,
527
+ dropout=dropout,
528
+ batch_first=batch_first,
529
+ linear1_cls=linear1_self_attention_cls,
530
+ linear2_cls=linear2_self_attention_cls,
531
+ **factory_kwargs,
532
+ )
533
+ # Implementation of Feedforward model
534
+ self.linear1 = linear1_feedforward_cls(
535
+ d_model, dim_feedforward, **factory_kwargs
536
+ )
537
+ self.dropout = nn.Dropout(dropout)
538
+ self.linear2 = linear2_feedforward_cls(
539
+ dim_feedforward, d_model, **factory_kwargs
540
+ )
541
+
542
+ self.norm_first = norm_first
543
+ self.dropout1 = nn.Dropout(dropout)
544
+ self.dropout2 = nn.Dropout(dropout)
545
+ self.dropout3 = nn.Dropout(dropout)
546
+
547
+ # Legacy string support for activation function.
548
+ if isinstance(activation, str):
549
+ self.activation = _get_activation_fn(activation)
550
+ elif isinstance(activation, partial):
551
+ self.activation = activation(d_model)
552
+ elif activation == BalancedDoubleSwish:
553
+ self.activation = BalancedDoubleSwish(d_model)
554
+ else:
555
+ self.activation = activation
556
+
557
+ if adaptive_layer_norm:
558
+ norm1 = layer_norm_cls(
559
+ d_model, eps=layer_norm_eps, **factory_kwargs
560
+ )
561
+ norm2 = layer_norm_cls(
562
+ d_model, eps=layer_norm_eps, **factory_kwargs
563
+ )
564
+ norm3 = layer_norm_cls(
565
+ d_model, eps=layer_norm_eps, **factory_kwargs
566
+ )
567
+
568
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
569
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
570
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
571
+ else:
572
+ self.norm1 = layer_norm_cls(
573
+ d_model, eps=layer_norm_eps, **factory_kwargs
574
+ )
575
+ self.norm2 = layer_norm_cls(
576
+ d_model, eps=layer_norm_eps, **factory_kwargs
577
+ )
578
+ if layer_norm_cls == IdentityNorm:
579
+ self.norm3 = BalancedBasicNorm(
580
+ d_model, eps=layer_norm_eps, **factory_kwargs
581
+ )
582
+ else:
583
+ self.norm3 = layer_norm_cls(
584
+ d_model, eps=layer_norm_eps, **factory_kwargs
585
+ )
586
+
587
+ def forward(
588
+ self,
589
+ tgt: Tensor,
590
+ memory: Tensor,
591
+ tgt_mask: Optional[Tensor] = None,
592
+ memory_mask: Optional[Tensor] = None,
593
+ tgt_key_padding_mask: Optional[Tensor] = None,
594
+ memory_key_padding_mask: Optional[Tensor] = None,
595
+ ) -> Tensor:
596
+ r"""Pass the inputs (and mask) through the decoder layer.
597
+
598
+ Args:
599
+ tgt: the sequence to the decoder layer (required).
600
+ memory: the sequence from the last layer of the encoder (required).
601
+ tgt_mask: the mask for the tgt sequence (optional).
602
+ memory_mask: the mask for the memory sequence (optional).
603
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
604
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
605
+
606
+ Shape:
607
+ see the docs in Transformer class.
608
+ """
609
+ tgt_is_tuple = False
610
+ if isinstance(tgt, tuple):
611
+ x, stage_embedding = tgt
612
+ tgt_is_tuple = True
613
+ else:
614
+ x, stage_embedding = tgt, None
615
+
616
+ if self.norm_first:
617
+ x = x + self._sa_block(
618
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
619
+ )
620
+ x = x + self._mha_block(
621
+ self.norm2(x, stage_embedding),
622
+ memory,
623
+ memory_mask,
624
+ memory_key_padding_mask,
625
+ )
626
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
627
+ else:
628
+ x = self.norm1(
629
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
630
+ stage_embedding,
631
+ )
632
+ x = self.norm2(
633
+ x
634
+ + self._mha_block(
635
+ x, memory, memory_mask, memory_key_padding_mask
636
+ ),
637
+ stage_embedding,
638
+ )
639
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
640
+
641
+ if tgt_is_tuple:
642
+ return (x, stage_embedding)
643
+ return x
644
+
645
+ # self-attention block
646
+ def _sa_block(
647
+ self,
648
+ x: Tensor,
649
+ attn_mask: Optional[Tensor],
650
+ key_padding_mask: Optional[Tensor],
651
+ ) -> Tensor:
652
+ x = self.self_attn(
653
+ x,
654
+ x,
655
+ x,
656
+ attn_mask=attn_mask,
657
+ key_padding_mask=key_padding_mask,
658
+ need_weights=False,
659
+ )[0]
660
+ return self.dropout1(x)
661
+
662
+ # multihead attention block
663
+ def _mha_block(
664
+ self,
665
+ x: Tensor,
666
+ mem: Tensor,
667
+ attn_mask: Optional[Tensor],
668
+ key_padding_mask: Optional[Tensor],
669
+ ) -> Tensor:
670
+ x = self.multihead_attn(
671
+ x,
672
+ mem,
673
+ mem,
674
+ attn_mask=attn_mask,
675
+ key_padding_mask=key_padding_mask,
676
+ need_weights=False,
677
+ )[0]
678
+ return self.dropout2(x)
679
+
680
+ # feed forward block
681
+ def _ff_block(self, x: Tensor) -> Tensor:
682
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
683
+ return self.dropout3(x)
684
+
685
+
686
+ def _get_clones(module, N):
687
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
688
+
689
+
690
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
691
+ if activation == "relu":
692
+ return F.relu
693
+ elif activation == "gelu":
694
+ return F.gelu
695
+
696
+ raise RuntimeError(
697
+ "activation should be relu/gelu, not {}".format(activation)
698
+ )
lib/voicecraft/models/modules/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng
2
+ import torch
3
+
4
+
5
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
6
+ """
7
+ Args:
8
+ lengths:
9
+ A 1-D tensor containing sentence lengths.
10
+ max_len:
11
+ The length of masks.
12
+ Returns:
13
+ Return a 2-D bool tensor, where masked positions
14
+ are filled with `True` and non-masked positions are
15
+ filled with `False`.
16
+
17
+ >>> lengths = torch.tensor([1, 3, 2, 5])
18
+ >>> make_pad_mask(lengths)
19
+ tensor([[False, True, True, True, True],
20
+ [False, False, False, True, True],
21
+ [False, False, True, True, True],
22
+ [False, False, False, False, False]])
23
+ """
24
+ assert lengths.ndim == 1, lengths.ndim
25
+ max_len = max(max_len, lengths.max())
26
+ n = lengths.size(0)
27
+ seq_range = torch.arange(0, max_len, device=lengths.device)
28
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
29
+
30
+ return expaned_lengths >= lengths.unsqueeze(-1)
31
+
32
+ def generate_partial_autoregressive_mask(sz, start, end):
33
+ mask = torch.zeros(sz, sz).bool()
34
+ mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
35
+ mask[:start, start:end] = True
36
+ mask[end:, start:end] = True
37
+ return mask
lib/voicecraft/models/voicecraft.py ADDED
@@ -0,0 +1,1406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import logging
5
+ import argparse, copy
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchmetrics.classification import MulticlassAccuracy
10
+
11
+ from .modules.utils import make_pad_mask
12
+
13
+ from .modules.embedding import SinePositionalEmbedding, TokenEmbedding
14
+ from .modules.transformer import (
15
+ LayerNorm,
16
+ TransformerEncoder,
17
+ TransformerEncoderLayer,
18
+ )
19
+ from .codebooks_patterns import DelayedPatternProvider
20
+
21
+ def top_k_top_p_filtering(
22
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
23
+ ):
24
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
25
+ Args:
26
+ logits: logits distribution shape (batch size, vocabulary size)
27
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
28
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
29
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
30
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
31
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
32
+ """
33
+ if top_k > 0:
34
+ top_k = min(
35
+ max(top_k, min_tokens_to_keep), logits.size(-1)
36
+ ) # Safety check
37
+ # Remove all tokens with a probability less than the last token of the top-k
38
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
39
+ logits[indices_to_remove] = filter_value
40
+
41
+ if top_p < 1.0:
42
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
43
+ cumulative_probs = torch.cumsum(
44
+ F.softmax(sorted_logits, dim=-1), dim=-1
45
+ )
46
+
47
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
48
+ sorted_indices_to_remove = cumulative_probs > top_p
49
+ if min_tokens_to_keep > 1:
50
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
51
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
52
+ # Shift the indices to the right to keep also the first token above the threshold
53
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
54
+ ..., :-1
55
+ ].clone()
56
+ sorted_indices_to_remove[..., 0] = 0
57
+
58
+ # scatter sorted tensors to original indexing
59
+ indices_to_remove = sorted_indices_to_remove.scatter(
60
+ 1, sorted_indices, sorted_indices_to_remove
61
+ )
62
+ logits[indices_to_remove] = filter_value
63
+ return logits
64
+
65
+
66
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
67
+ # temperature: (`optional`) float
68
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
69
+ # top_k: (`optional`) int
70
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
71
+ # top_p: (`optional`) float
72
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
73
+
74
+ # Temperature (higher temperature => more likely to sample low probability tokens)
75
+ if temperature != 1.0:
76
+ logits = logits / temperature
77
+ # Top-p/top-k filtering
78
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
79
+ # Sample
80
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
81
+ return token
82
+
83
+
84
+
85
+ class VoiceCraft(nn.Module):
86
+ def __init__(self, args):
87
+ super().__init__()
88
+ self.args = copy.copy(args)
89
+ self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
90
+ if not getattr(self.args, "special_first", False):
91
+ self.args.special_first = 0
92
+ if not getattr(self.args, "n_special", False):
93
+ self.args.n_special = 3
94
+ self.args.eos = getattr(self.args, "eos", -1)
95
+ self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1]
96
+ if self.args.eos > 0:
97
+ assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
98
+ self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
99
+ if type(self.args.audio_vocab_size) == str:
100
+ self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
101
+
102
+ self.n_text_tokens = self.args.text_vocab_size + 1
103
+ assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"
104
+
105
+ self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
106
+ assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
107
+ assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
108
+ assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token
109
+
110
+ self.text_embedding = TokenEmbedding(
111
+ dim_model=self.args.d_model,
112
+ vocab_size=self.n_text_tokens,
113
+ dropout=self.args.text_embedding_dropout
114
+ )
115
+
116
+ self.audio_embedding = nn.ModuleList(
117
+ [
118
+ TokenEmbedding(
119
+ dim_model=self.args.audio_embedding_dim,
120
+ vocab_size=self.n_audio_tokens[k],
121
+ dropout=self.args.audio_embedding_dropout
122
+ ) for k in range(self.args.n_codebooks)
123
+ ]
124
+ )
125
+ self.mask_embedding = nn.Parameter(torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True)
126
+ self.text_positional_embedding = SinePositionalEmbedding(
127
+ self.args.d_model,
128
+ dropout=self.args.text_positional_embedding_dropout,
129
+ scale=False,
130
+ alpha=True, # learnable scaler, scale the volume of positional embedding
131
+ )
132
+ self.audio_positional_embedding = SinePositionalEmbedding(
133
+ self.args.d_model,
134
+ dropout=self.args.audio_positional_embedding_dropout,
135
+ scale=False,
136
+ alpha=True, # learnable scaler, scale the volume of positional embedding
137
+ )
138
+
139
+ dec_layer = TransformerEncoderLayer(
140
+ self.args.d_model,
141
+ self.args.nhead,
142
+ dim_feedforward=self.args.d_model * 4,
143
+ dropout=self.args.trm_dropout,
144
+ batch_first=True,
145
+ norm_first=True,
146
+ layer_norm_cls=LayerNorm
147
+ )
148
+ self.decoder = TransformerEncoder(
149
+ dec_layer,
150
+ num_layers=self.args.num_decoder_layers,
151
+ norm=LayerNorm(self.args.d_model),
152
+ )
153
+
154
+ self.predict_layer = nn.ModuleList(
155
+ [
156
+ nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
157
+ ]
158
+ )
159
+
160
+ self.accuracy_metrics = nn.ModuleList(
161
+ [MulticlassAccuracy(
162
+ self.n_audio_tokens[k],
163
+ top_k=10,
164
+ average="micro",
165
+ multidim_average="global",
166
+ ignore_index=None,
167
+ ) for k in range(self.args.n_codebooks)]
168
+ )
169
+
170
+
171
+ def prepare_mask_intervals(self, y_lens):
172
+ mask_intervals = []
173
+ non_mask_intervals = []
174
+
175
+ for i, y_len in enumerate(y_lens):
176
+ if self.args.mask_sample_dist == "uniform":
177
+ n_spans = random.choice(range(1, self.args.max_n_spans+1))
178
+ elif "poisson" in self.args.mask_sample_dist.lower():
179
+ param = float(self.args.mask_sample_dist[len("poisson"):])
180
+ poisson_sample = torch.poisson(torch.tensor([param]))
181
+ n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
182
+
183
+ starts = random.sample(range(1, y_len-1-self.args.mask_len_min), n_spans)
184
+ starts = sorted(starts)
185
+
186
+ for j in range(len(starts)-1, 0, -1):
187
+ if starts[j] - starts[j-1] < self.args.min_gap:
188
+ del starts[j] # If elements are too close, delete the later one
189
+ assert len(starts) > 0, f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"
190
+
191
+ temp_starts = starts + [y_len]
192
+ gaps = [temp_starts[j+1] - temp_starts[j] for j in range(len(temp_starts)-1)]
193
+
194
+ ends = []
195
+
196
+ for j, (start, gap) in enumerate(zip(starts, gaps)):
197
+ mask_len = random.randint(self.args.mask_len_min, self.args.mask_len_max)
198
+ # if mask_len > gap * self.args.max_mask_portion: # make sure the masks are not overlapping with each other
199
+ if mask_len > gap - 1: # make sure the masks are not overlapping with each other
200
+ # temp_mask_start = int(0.6*gap*self.args.max_mask_portion)
201
+ # temp_mask_end = int(gap*self.args.max_mask_portion)
202
+ temp_mask_start = 1
203
+ temp_mask_end = gap - 1
204
+ mask_len = random.randint(temp_mask_start, temp_mask_end)
205
+ ends.append(start + mask_len)
206
+
207
+ mask_intervals.append([(s,e) for s,e in zip(starts, ends)])
208
+ non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])])
209
+
210
+ return mask_intervals, non_mask_intervals
211
+
212
+ def rearrange(self, y, non_mask_intervals, mask_intervals):
213
+ reduced_eog = getattr(self.args, "reduced_eog", 0)
214
+ rearranged_y = []
215
+ for i in range(len(y)):
216
+ if self.args.eos > 0:
217
+ assert reduced_eog
218
+ cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eos], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
219
+ else:
220
+ if reduced_eog:
221
+ cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eog], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
222
+ else:
223
+ cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment)
224
+ rearranged_y.append(cur_y)
225
+ return rearranged_y
226
+
227
+ def shift(self, rearranged_y):
228
+ shifted_y = []
229
+ patterns = []
230
+ for i in range(len(rearranged_y)):
231
+ cur_patterns = [self.pattern.get_pattern(cur_y.shape[1]) for cur_y in rearranged_y[i]]
232
+ out = [cur_pattern.build_pattern_sequence(z=cur_y.unsqueeze(0).contiguous(), special_token=self.args.empty_token, keep_only_valid_steps=False) for cur_pattern, cur_y in zip(cur_patterns, rearranged_y[i])]
233
+ shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask
234
+ patterns.append(cur_patterns)
235
+ return shifted_y, patterns
236
+
237
+ def insert_mask(self, shifted_y):
238
+ inserted_y = []
239
+ mask_position = []
240
+ mask_value = []
241
+ for i in range(len(shifted_y)):
242
+ num_masks = (len(shifted_y[i]) - 1) // 2
243
+ assert num_masks == (len(shifted_y[i]) - 1) / 2, len(shifted_y[i])
244
+ emb_inds = list(range(self.args.max_n_spans))
245
+ if self.args.shuffle_mask_embedding:
246
+ random.shuffle(emb_inds)
247
+ emb_inds_use = emb_inds[:num_masks]
248
+ emb_inds_use = emb_inds_use + emb_inds_use
249
+ mask_value.append(emb_inds_use)
250
+ cur_inserted_y = []
251
+ cur_mask_position = []
252
+ for j in range(len(shifted_y[i])-1):
253
+ cur_inserted_y.append(shifted_y[i][j])
254
+ cur_mask_position.append(sum([item.shape[1] for item in cur_inserted_y])) # each item is of shape [K S], so take shape[1]
255
+ cur_inserted_y.append(self.eog) # insert mask token of shape [K, 1], BUT we are actually using the eog token as a place holder here, as the real mask will be inserted in embed_y function
256
+
257
+ cur_inserted_y.append(shifted_y[i][-1])
258
+
259
+ inserted_y.append(cur_inserted_y)
260
+ mask_position.append(cur_mask_position)
261
+ return inserted_y, mask_position, mask_value
262
+
263
+ def cat_y(self, inserted_y, mask_position, y_lens):
264
+ reduced_eog = getattr(self.args, "reduced_eog", 0)
265
+ cated_y = []
266
+ new_y_lens = []
267
+ for i in range(len(inserted_y)):
268
+ cur_cated_y = torch.cat(inserted_y[i], dim=1) #[K S]
269
+ cur_cated_y = cur_cated_y.transpose(1,0) # [S K]
270
+ cur_cated_y_len = cur_cated_y.shape[0]
271
+ if reduced_eog:
272
+ assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i])/2 + 1) ({len(mask_position[i])/2 + 1})={y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1)}"
273
+ else:
274
+ assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i]) + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i]) + 1) ({len(mask_position[i]) + 1})" # the last term represent the inserted eog token, originally it's inserted at the end of every token, but this is wrong
275
+ new_y_lens.append(cur_cated_y_len)
276
+ cated_y.append(cur_cated_y)
277
+
278
+ cated_y = torch.nn.utils.rnn.pad_sequence(cated_y, batch_first=False, padding_value=self.args.audio_pad_token)
279
+ assert cated_y.shape == torch.Size([max(new_y_lens),len(inserted_y), self.args.n_codebooks]), f"cated_y.shape: {cated_y.shape}, but it should be {torch.Size([max(new_y_lens,len(inserted_y), self.args.n_codebooks)])}"
280
+ cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B]
281
+ assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
282
+ return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device)
283
+
284
+ def embed_y(self, cated_y, mask_position, mask_value):
285
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
286
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
287
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
288
+ embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
289
+ embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
290
+ for i in range(len(embedded_y)):
291
+ if len(mask_position[i]) > 0:
292
+ embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
293
+ return embedded_y
294
+
295
+ def prepare_input_target(self, y, y_lens):
296
+ # rearrange y
297
+ # assume y shape: [B T K], K is n_codebooks
298
+ assert y.shape[1] == self.args.n_codebooks, y.shape
299
+ # sample mask_intervals
300
+ mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens)
301
+
302
+ # need to have EOG in each section (SOG will be generated by the pattern class)
303
+ # but mask can be inserted later after we have shifted the input
304
+ # y could be rearranged in this way:
305
+ # [
306
+ # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
307
+ # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
308
+ # ...
309
+ # ]
310
+ # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
311
+ # NOTE #non_masked_part = #masked_part + 1
312
+ # NOTE *these are also the targets*
313
+ # added eog at the end of each segment (masked segment and unmasked segment)
314
+ rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
315
+ targets = rearranged_y # each element in each sample is of shape [K T]
316
+ assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape
317
+
318
+ # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
319
+ # [[5, 1, 2, 3, 4, 5, 5],
320
+ # [5, 5, 1, 2, 3, 4, 5],
321
+ # [5, 5, 5, 1, 2, 3, 4]]
322
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S]
323
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[0]
324
+
325
+
326
+ # then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it
327
+ # we also need to record the position of the inserted mask
328
+ inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
329
+ assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
330
+ assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
331
+
332
+ # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
333
+ cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
334
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
335
+
336
+
337
+ # embed remember to separately embed the mask tokens
338
+ embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD
339
+ assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape
340
+
341
+ # positional embedding
342
+ y_input = self.audio_positional_embedding(embedded_y)
343
+
344
+ # make attention mask and padding mask
345
+ y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
346
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
347
+ return y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns
348
+
349
+ def remove_mask(self, logits, mask_position, new_y_lens):
350
+ # logits: [B K S card]
351
+ logits_use = []
352
+ for i in range(len(logits)):
353
+ non_mask_positions = [-1] + mask_position[i] + [new_y_lens[i]]
354
+ non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)]
355
+ cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals]
356
+ logits_use.append(cur_logits_use)
357
+
358
+ return logits_use
359
+
360
+ def revert_pattern(self, patterns, logits_use):
361
+ logits_final = []
362
+ logit_masks = []
363
+ for i in range(len(logits_use)):
364
+ cur_logits = [
365
+ item.unsqueeze(0).permute(0, 3, 1, 2).contiguous() for item in logits_use[i]
366
+ ] # each item is of shape [1 K S card] [1 card K S]
367
+ cur_logits_final = [
368
+ cur_pattern.revert_pattern_logits(
369
+ item, 0, keep_only_valid_steps=False
370
+ )
371
+ for cur_pattern, item in zip(patterns[i], cur_logits)
372
+ ] # if input output order doesn't match, this step will give an error
373
+ cur_logits_final_ret = [item[0].permute(0,2,3,1).squeeze(0) for item in cur_logits_final] # each element is of shape [K,T,card]
374
+ logits_final.append(cur_logits_final_ret)
375
+ logit_masks.append([item[2] for item in cur_logits_final])
376
+
377
+ return logits_final, logit_masks
378
+
379
+ def dec_forward(
380
+ self,
381
+ x_input,
382
+ x_lens,
383
+ x_attention_mask,
384
+ x_padding_mask,
385
+ y_input,
386
+ new_y_lens,
387
+ y_attention_mask,
388
+ y_padding_mask,
389
+ past=None,
390
+ last_3_tokens=False
391
+ ):
392
+ x_attn_mask = F.pad(
393
+ x_attention_mask,
394
+ (0, new_y_lens.max()),
395
+ value=True,
396
+ ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper
397
+ y_attn_mask = F.pad(
398
+ y_attention_mask,
399
+ (x_lens.max(), 0), # y is padded at the front
400
+ value=False,
401
+ ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
402
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
403
+
404
+ # merge key padding and attention masks
405
+ bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
406
+ xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
407
+ _xy_padding_mask = (
408
+ xy_padding_mask.view(bsz, 1, 1, src_len)
409
+ .expand(-1, self.args.nhead, -1, -1)
410
+ .reshape(bsz * self.args.nhead, 1, src_len)
411
+ )
412
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
413
+
414
+ new_attn_mask = torch.zeros_like(xy_attn_mask)
415
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
416
+ xy_attn_mask = new_attn_mask
417
+
418
+ xy_input = torch.cat([x_input, y_input], dim=1)
419
+
420
+ if past == None: # do not use kvcache
421
+ out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
422
+ return out[:, x_lens.max():], None
423
+ else: # use kvcache
424
+ if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
425
+ if last_3_tokens:
426
+ xy_input = xy_input[:, -3:]
427
+ xy_attn_mask = xy_attn_mask[:, -3:]
428
+ else:
429
+ xy_input = xy_input[:, -1:]
430
+ xy_attn_mask = xy_attn_mask[:, -1:]
431
+
432
+ out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
433
+ if isinstance(out, tuple): # get rid of stage_embedding
434
+ out = out[0]
435
+
436
+ if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
437
+ return out[:, x_lens.max():], present
438
+ else: # used kvcache
439
+ return out, present
440
+
441
+ def forward(self, batch):
442
+ """
443
+ Args:
444
+ x:
445
+ A 2-D tensor of shape (N, S).
446
+ x_lens:
447
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
448
+ before padding.
449
+ y:
450
+ A 3-D tensor of shape (N, K, T).
451
+ where K is the number of codebooks
452
+ y_lens:
453
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
454
+ before padding.
455
+ """
456
+ x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
457
+ x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
458
+ y = y[:, :y_lens.max()]
459
+ assert x.ndim == 2, x.shape
460
+ assert x_lens.ndim == 1, x_lens.shape
461
+ assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
462
+ assert y_lens.ndim == 1, y_lens.shape
463
+ # makes attention mask and padding mask for x
464
+ x_padding_mask = make_pad_mask(x_lens).to(x.device)
465
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
466
+ x_input = self.text_embedding(x)
467
+ x_input = self.text_positional_embedding(x_input)
468
+ y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens)
469
+ y_out = self.dec_forward(
470
+ x_input,
471
+ x_lens,
472
+ x_attention_mask,
473
+ x_padding_mask,
474
+ y_input,
475
+ new_y_lens,
476
+ y_attention_mask,
477
+ y_padding_mask
478
+ )
479
+ y_out = y_out[0] # no kv-caching during training
480
+ assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
481
+
482
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
483
+ # take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
484
+ assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
485
+
486
+ logits_use = self.remove_mask(logits, mask_position, new_y_lens)
487
+
488
+ # revert the pattern shift for each logits section in each sample
489
+ logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
490
+ assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
491
+ # testing
492
+ sample_to_test = 0
493
+ assert len(logits_final[sample_to_test]) == len(targets[sample_to_test]), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}"
494
+ temp = sum([logits_final[sample_to_test][i].shape[:-1] != targets[sample_to_test][i].shape for i in range(len(targets[sample_to_test]))])
495
+ assert temp == 0, f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}"
496
+
497
+ logit_masked = sum([(item==False).any() for cur_mask in logit_masks for item in cur_mask])
498
+ assert logit_masked == 0, logit_masks
499
+
500
+ logits = torch.cat([torch.cat(item, dim=1) for item in logits_final], dim=1) # [K, T1+T2+T3+..., card]
501
+ targets = torch.cat([torch.cat(item, dim=1) for item in targets], dim=1) # [K, T1+T2+T3+...]
502
+ assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}"
503
+ loss = []
504
+ ntokens = []
505
+ top10acc = []
506
+ for k, (logit, target) in enumerate(zip(logits, targets)):
507
+ loss.append(F.cross_entropy(logit, target, reduction='mean'))
508
+ top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
509
+ ntokens.append(len(logit))
510
+
511
+ all_ntokens = sum(ntokens)
512
+ if self.args.codebook_weight != None:
513
+ codebook_weight = eval(self.args.codebook_weight)
514
+ else:
515
+ codebook_weight = [1.] * self.args.n_codebooks
516
+ loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
517
+ top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
518
+ top10acc = sum(top10acc_by_codebook)
519
+ ntokens = torch.tensor(all_ntokens).to(logits.device)
520
+
521
+ return {
522
+ "loss": loss,
523
+ "top10acc": top10acc,
524
+ "top10acc_by_codebook": top10acc_by_codebook,
525
+ "effective_ntoken": ntokens,
526
+ }
527
+
528
+ def inference(
529
+ self,
530
+ x: torch.Tensor,
531
+ x_lens: torch.Tensor,
532
+ y: torch.Tensor,
533
+ mask_interval: list[torch.Tensor],
534
+ top_k: int=-100,
535
+ top_p: float=1.0,
536
+ temperature: float=1.0,
537
+ stop_repetition: int=-1,
538
+ kvcache: int=1,
539
+ silence_tokens: list[int]=[1388,1898,131],
540
+ ) -> torch.Tensor:
541
+ """
542
+ Args:
543
+ x:
544
+ A 2-D tensor of shape (1, L).
545
+ x_lens:
546
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
547
+ before padding.
548
+ y:
549
+ A 3-D tensor of shape (1, T, K).
550
+ mask_interval:
551
+ a list of tensors of shape (M, 2). contains M mask_start and mask_end. list length is actually 1, because we only support single sample inference for now
552
+ top_k: (`optional`) int
553
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
554
+ top_p: (`optional`) float
555
+ For Neucleus sampling
556
+ temperature: (`optional`) float
557
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
558
+ eog_coef: (`optional`) float
559
+ if 0, no change to eog token logits, otherwise, will adjust eog token logit based on the difference between acoustic token and phn token length
560
+ stop_repetition (`optional`) int
561
+ if not -1, will set the logits of a token that repeated this many times to be -100000, to avoid generating it again. This only apply to tokens from the first codebook
562
+ allowed_repeat_tokens (`optional`) list of ints
563
+ by inspecting the validation set, get a few tokens that indeed repeat a significant amount of time, and exclude those tokens from prevent repetition
564
+ ultimate_stop_repetition (`optional`) int
565
+ no matter that token it is, stop repetition once after this number
566
+ """
567
+ assert x.ndim == 2, x.shape
568
+ assert x_lens.ndim == 1, x_lens.shape
569
+ assert y.ndim == 3, y.shape
570
+ if self.args.special_first:
571
+ y = y + int(self.args.n_special)
572
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
573
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
574
+ assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval
575
+
576
+ # make x attention mask and x_input
577
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
578
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
579
+ x_input = self.text_embedding(x)
580
+ x_input = self.text_positional_embedding(x_input)
581
+
582
+ # make initial y_input
583
+
584
+ # make mask_interval and non_mask_interval
585
+ y_len = y.shape[2]
586
+ y_lens = torch.LongTensor([y_len]).to(y.device)
587
+ mask_interval = mask_interval[0]
588
+ starts = [item[0].item() for item in mask_interval] + [y_len]
589
+ ends = [0] + [item[1].item() for item in mask_interval]
590
+ mask_intervals = [[
591
+ (item[0].item(), item[1].item()) for item in mask_interval
592
+ ]] # a werid name change, mask_interval is input, now is mask_intervals, with one more dimension
593
+ non_mask_intervals = [[
594
+ (ns, ne) for ns, ne in zip(ends, starts)
595
+ ]]
596
+
597
+ # rearrange y
598
+ # will add have EOG in each section (SOG will be generated by the pattern class)
599
+ # but mask can be inserted later after we have shifted the input
600
+ # y could be rearranged in this way:
601
+ # [
602
+ # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
603
+ # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
604
+ # ...
605
+ # ]
606
+ # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
607
+ # NOTE #non_masked_part = #masked_part + 1
608
+ rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
609
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
610
+
611
+ # shift each element of y
612
+ # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
613
+ # [
614
+ # [empty, 1, 2, 3, eog, empty, empty, empty],
615
+ # [empty, empty, 1, 2, 3, eog, empty, empty],
616
+ # [empty, empty, empty, 1, 2, 3, eog, empty],
617
+ # [empty, empty, empty, empty, 1, 2, 3, eog]
618
+ # ]
619
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
620
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
621
+
622
+ # insert mask token at the intersction of each tensor, but *actually inserted eog as place holder*
623
+ # the position of inserted mask is also recorded
624
+ # and the mask_value, the index of the mask emb is recorded
625
+ inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
626
+ assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
627
+ assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
628
+
629
+ # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
630
+ cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
631
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
632
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
633
+
634
+ ### NOTE this is different from forward, as we will remove the masked tokens
635
+ ### say there are two masked region
636
+ ### the cated_y should be like
637
+ ### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty]
638
+ ### which means we need to take the part after the last empty out
639
+ num_mask = len(mask_position[0])//2
640
+ assert num_mask == len(mask_position[0])/2, mask_position
641
+ cated_y = cated_y[:, :mask_position[0][num_mask]+2] # of shape [K,T,B]
642
+ # logging.info(f"mask_position[0][num_mask]+2: {mask_position[0][num_mask]+2}")
643
+ more_mask_value = mask_value[0][num_mask+1:] # NOTE this will be used in the generation loop for reference for inserting mask embedding
644
+ new_y_lens[0] = mask_position[0][num_mask]+2
645
+ mask_position[0] = mask_position[0][:num_mask+1]
646
+ assert mask_position[0][num_mask]+2 == cated_y.shape[1], f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}"
647
+
648
+ # embed: remember to separately embed the mask tokens
649
+ embedded_y = self.embed_y(cated_y, mask_position, [mask_value[0][:num_mask+1]]) #BTD
650
+ # assert embedded_y.shape == torch.Size((y.shape[0], max(new_y_lens), self.args.d_model)), embedded_y.shape
651
+
652
+ # positional embedding
653
+ y_input = self.audio_positional_embedding(embedded_y)
654
+
655
+ # make attention mask and padding mask
656
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
657
+ # y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
658
+
659
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
660
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
661
+
662
+
663
+ codebook_eog = [False] * self.args.n_codebooks
664
+ generated = [] # doesn't contain any empty_token, contains eog
665
+ cur_generated = []
666
+ # say 0 is empty, 4 is eog
667
+ # tensor([[ 1, 2, 3, 4, 0, 0],
668
+ # [ 0, 1, 2, 3, 4, 0],
669
+ # [ 0, 0, 1, 2, 3, 4]])
670
+ num_gen = []
671
+ cur_num_gen = 0
672
+ ##################### silence repetition handling #####################
673
+ ##################### silence repetition handling #####################
674
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
675
+ consec_silence_count = 0
676
+ prev_token = None
677
+ ##################### silence repetition handling #####################
678
+ ##################### silence repetition handling #####################
679
+ # prepare the cache placeholder
680
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
681
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
682
+ # handle multi-span kv-cache
683
+ new_masked_span = False
684
+
685
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
686
+ if n_eog == 0:
687
+ logits_adjust = logits
688
+ for jj in range(1,self.args.n_codebooks):
689
+ logits_adjust[jj][self.args.eog] = -10000
690
+ logits_adjust[jj][self.args.empty_token] = -10000
691
+ ##################### silence repetition handling #####################
692
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
693
+ if logits_adjust[0, prev_token] < 0:
694
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
695
+ else:
696
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
697
+ ##################### silence repetition handling #####################
698
+ if type(logits_adjust) == list:
699
+ samples_list= []
700
+ for logit in logits_adjust:
701
+ # print(logit)
702
+ # print(logit.shape)
703
+ cur_sample = topk_sampling(
704
+ logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
705
+ ) # [1, 1]
706
+ samples_list.append(cur_sample)
707
+ samples = torch.cat(samples_list, dim=0) # [K, 1]
708
+ else:
709
+ samples = topk_sampling(
710
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
711
+ ) # [K, 1]
712
+ assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
713
+ if cur_num_gen < self.args.n_codebooks-1:
714
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
715
+ samples[-jj, 0] = self.args.empty_token
716
+
717
+ if (
718
+ samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10
719
+ ): # last one means y is already too long, shouldn't happen, but put it here
720
+ samples[0,0] = self.args.eog
721
+ codebook_eog[0] = True
722
+ ##################### silence repetition handling #####################
723
+ ##################### silence repetition handling #####################
724
+ if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
725
+ consec_silence_count += 1
726
+ else:
727
+ consec_silence_count = 0
728
+ prev_token = samples[0,0]
729
+ ##################### silence repetition handling #####################
730
+ ##################### silence repetition handling #####################
731
+ return samples, codebook_eog, prev_token, consec_silence_count
732
+ else:
733
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
734
+ logits_adjust = logits
735
+ for jj in range(n_eog+1,self.args.n_codebooks):
736
+ logits_adjust[jj][self.args.eog] = -10000
737
+ logits_adjust[jj][self.args.empty_token] = -10000
738
+ if type(logits_adjust) == list:
739
+ samples_list= []
740
+ for logit in logits_adjust:
741
+ cur_sample = topk_sampling(
742
+ logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
743
+ ) # [1, 1]
744
+ samples_list.append(cur_sample)
745
+ samples = torch.cat(samples_list, dim=0) # [K, 1]
746
+ else:
747
+ samples = topk_sampling(
748
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
749
+ ) # [K, 1]
750
+ for jj in range(n_eog):
751
+ samples[jj, 0] = self.args.empty_token
752
+ samples[n_eog, 0] = self.args.eog
753
+ codebook_eog[n_eog] = True
754
+ return samples, codebook_eog, prev_token, consec_silence_count
755
+
756
+ while True:
757
+ y_out, present = self.dec_forward(
758
+ x_input,
759
+ x_lens,
760
+ x_attention_mask,
761
+ x_padding_mask,
762
+ y_input,
763
+ new_y_lens,
764
+ y_attention_mask,
765
+ y_padding_mask,
766
+ past=past,
767
+ last_3_tokens = new_masked_span
768
+ )
769
+ if new_masked_span:
770
+ new_masked_span = False
771
+
772
+ if past != None:
773
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
774
+
775
+ y_out = y_out[:, -1:] # only take the last one
776
+
777
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
778
+ logits = logits.squeeze(0).squeeze(1) # [K card]
779
+ assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
780
+
781
+ n_eog = sum(codebook_eog)
782
+ assert n_eog < self.args.n_codebooks
783
+ if self.args.eos > 0: # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing
784
+ for jj in range(self.args.n_codebooks):
785
+ logits[jj][self.args.eos] = -10000.
786
+ # need to use a helper function to hand different n_eog cases
787
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
788
+ cur_num_gen += 1
789
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
790
+ # get samples_emb
791
+ samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
792
+ samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
793
+
794
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
795
+ # re-init
796
+ codebook_eog = [False] * self.args.n_codebooks
797
+ num_gen.append(cur_num_gen)
798
+ cur_num_gen = 0
799
+ generated.append(cur_generated)
800
+ cur_generated = []
801
+
802
+ # if the current mask span is the last span, then all done
803
+ # else
804
+ # append the next mask token and the four empty tokens to start the next generation
805
+ if len(more_mask_value) > 0:
806
+ next_mask_ind = more_mask_value.pop(0)
807
+ mask_emb = self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0) # [1,1,D]
808
+ assert mask_emb.shape == torch.Size((1,1,self.args.d_model)), mask_emb.shape
809
+ empty_token = torch.LongTensor([self.args.empty_token]).to(y.device)
810
+ empty_emb = torch.stack([
811
+ self.audio_embedding[k](empty_token) for k in range(self.args.n_codebooks)], dim=0
812
+ ).sum(dim=0, keepdim=True) # [1,1,D]
813
+ assert empty_emb.shape == torch.Size((1,1,self.args.d_model)), empty_emb.shape
814
+ extra_emb = torch.cat([mask_emb, empty_emb], dim=1) # [1,2,D]
815
+ samples_emb = torch.cat([samples_emb, extra_emb], dim=1) # [1,3,D] # prev_last_token, mask_token, empty token
816
+ assert samples_emb.shape == torch.Size((1,3,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
817
+ ##################### silence repetition handling #####################
818
+ ##################### silence repetition handling #####################
819
+ consec_silence_count = 0
820
+ prev_token = None
821
+ ##################### silence repetition handling #####################
822
+ ##################### silence repetition handling #####################
823
+
824
+ # handling kv-caching for multi-span editing
825
+ new_masked_span = True
826
+ else:
827
+ break
828
+ else:
829
+ assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
830
+
831
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
832
+ # positional embedding
833
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
834
+ # make attention mask and padding mask
835
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
836
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
837
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
838
+
839
+ assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}"
840
+
841
+ # # combine non_masked_span with generated spans
842
+ # first need to shift the generated part back
843
+ flatten_gen = []
844
+ for l, orig_span in enumerate(generated):
845
+ span = torch.stack(orig_span, dim=0) # [T K]
846
+ span = span.transpose(1,0) # [K, T]
847
+ assert span.shape[0] == self.args.n_codebooks, span.shape
848
+ unshifted_span = []
849
+ for j, s in enumerate(span):
850
+ start_from = j
851
+ end_at = - (self.args.n_codebooks - start_from)
852
+ unshifted_span.append(s[start_from:end_at])
853
+ unshifted_span = torch.stack(unshifted_span, dim=0)
854
+
855
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
856
+ flatten_gen.append(unshifted_span)
857
+ # logging.info(f"unshfited_span: {unshifted_span.shape}")
858
+ # raise
859
+ assert len(non_mask_intervals[0]) - 1 == len(flatten_gen), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}"
860
+ res = []
861
+ for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen):
862
+ res.append(y[0, :, orig_interval[0]:orig_interval[1]])
863
+ res.append(gen)
864
+ res.append(y[0, :, non_mask_intervals[0][-1][0]:non_mask_intervals[0][-1][1]])
865
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T]
866
+
867
+ expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen])
868
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
869
+
870
+ if self.args.special_first:
871
+ res = res - int(self.args.n_special)
872
+
873
+ return res
874
+
875
+ def inference_tts(
876
+ self,
877
+ x: torch.Tensor,
878
+ x_lens: torch.Tensor,
879
+ y: torch.Tensor,
880
+ top_k: int=-100,
881
+ top_p: float=1.0,
882
+ temperature: float=1.0,
883
+ stop_repetition: int=3,
884
+ kvcache: int=1,
885
+ silence_tokens: list[int]=[1388,1898,131],
886
+ *kargs
887
+ ) -> torch.Tensor:
888
+ """
889
+ different from inference_tts, this implementation uses kvcache, which should have significant speed up
890
+ Args:
891
+ x:
892
+ A 2-D tensor of shape (1, L).
893
+ x_lens:
894
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
895
+ before padding.
896
+ y:
897
+ A 3-D tensor of shape (1, T, K).
898
+ top_k: (`optional`) int
899
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
900
+ top_p: (`optional`) float
901
+ For Neucleus sampling
902
+ temperature: (`optional`) float
903
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
904
+ """
905
+ eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
906
+ assert x.ndim == 2, x.shape
907
+ assert x_lens.ndim == 1, x_lens.shape
908
+ assert y.ndim == 3, y.shape
909
+ if self.args.special_first:
910
+ y = y + int(self.args.n_special)
911
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
912
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
913
+
914
+ # make x attention mask and x_input
915
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
916
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
917
+ x_input = self.text_embedding(x)
918
+ x_input = self.text_positional_embedding(x_input)
919
+
920
+ y_len = y.shape[2]
921
+ y_lens = torch.LongTensor([y_len]).to(y.device)
922
+
923
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
924
+ rearranged_y = [[y[0]]]
925
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
926
+
927
+ # shift y to create the delayed pattern
928
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
929
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
930
+ assert len(shifted_y[0]) == 1, len(shifted_y[0])
931
+
932
+ # below is different from forward or inference
933
+ # where we cut this shifted part
934
+ shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
935
+ assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
936
+
937
+ # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
938
+ # next section is concate tensors of each sample to one tensor, which we also don't need
939
+ cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
940
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
941
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
942
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
943
+
944
+ # replace tokens in y with the embeddings, add sum codebooks up
945
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
946
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
947
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
948
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
949
+ embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
950
+
951
+ # positional embedding
952
+ y_input = self.audio_positional_embedding(embedded_y)
953
+
954
+ # make attention mask and padding mask
955
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
956
+
957
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
958
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
959
+
960
+ # entering the generation stage
961
+ # starting from line 708
962
+ codebook_eog = [False] * self.args.n_codebooks
963
+ generated = [] # doesn't contain any empty token, contain eog
964
+ cur_generated = []
965
+ # say 0 is empty, 4 is eog
966
+ # tensor([[ 1, 2, 3, 4, 0, 0],
967
+ # [ 0, 1, 2, 3, 4, 0],
968
+ # [ 0, 0, 1, 2, 3, 4]])
969
+ num_gen = []
970
+ cur_num_gen = 0
971
+ ##################### silence repetition handling #####################
972
+ ##################### silence repetition handling #####################
973
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
974
+ consec_silence_count = 0
975
+ prev_token = None
976
+ ##################### silence repetition handling #####################
977
+ ##################### silence repetition handling #####################
978
+
979
+ # prepare the cache placeholder
980
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
981
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
982
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
983
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
984
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
985
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
986
+ if n_eog == 0:
987
+ logits_adjust = logits
988
+ for jj in range(1,self.args.n_codebooks):
989
+ logits_adjust[jj][eog_inference] = -10000
990
+ logits_adjust[jj][self.args.empty_token] = -10000
991
+ if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
992
+ logits_adjust[0][eog_inference] = -10000
993
+ ##################### silence repetition handling #####################
994
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
995
+ if logits_adjust[0, prev_token] < 0:
996
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
997
+ else:
998
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
999
+ ##################### silence repetition handling #####################
1000
+ samples = topk_sampling(
1001
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1002
+ ) # [K, 1]
1003
+ assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
1004
+ if cur_num_gen < self.args.n_codebooks-1:
1005
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
1006
+ samples[-jj, 0] = self.args.empty_token
1007
+
1008
+ if (
1009
+ samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//5)
1010
+ ): # last one means y is already too long, shouldn't happen, but put it here
1011
+ samples[0,0] = eog_inference
1012
+ codebook_eog[0] = True
1013
+ ##################### silence repetition handling #####################
1014
+ if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
1015
+ consec_silence_count += 1
1016
+ else:
1017
+ consec_silence_count = 0
1018
+ prev_token = samples[0,0]
1019
+ ##################### silence repetition handling #####################
1020
+ return samples, codebook_eog, prev_token, consec_silence_count
1021
+ else:
1022
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
1023
+ logits_adjust = logits
1024
+ for jj in range(n_eog+1,self.args.n_codebooks):
1025
+ logits_adjust[jj][eog_inference] = -10000
1026
+ logits_adjust[jj][self.args.empty_token] = -10000
1027
+ samples = topk_sampling(
1028
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1029
+ ) # [K, 1]
1030
+ for jj in range(n_eog):
1031
+ samples[jj, 0] = self.args.empty_token
1032
+ samples[n_eog, 0] = eog_inference
1033
+ codebook_eog[n_eog] = True
1034
+ return samples, codebook_eog, prev_token, consec_silence_count
1035
+ while True:
1036
+ y_out, present = self.dec_forward(
1037
+ x_input,
1038
+ x_lens,
1039
+ x_attention_mask,
1040
+ x_padding_mask,
1041
+ y_input,
1042
+ new_y_lens,
1043
+ y_attention_mask,
1044
+ y_padding_mask,
1045
+ past=past
1046
+ )
1047
+ if past != None:
1048
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
1049
+
1050
+
1051
+ y_out = y_out[:, -1:] # only take the last token
1052
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
1053
+ logits = logits.squeeze(0).squeeze(1) # [K card]
1054
+ assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
1055
+
1056
+ n_eog = sum(codebook_eog)
1057
+ assert n_eog < self.args.n_codebooks
1058
+ if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
1059
+ for jj in range(self.args.n_codebooks):
1060
+ logits[jj][self.args.eog] = -10000.
1061
+
1062
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
1063
+
1064
+ cur_num_gen += 1
1065
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
1066
+
1067
+ # samples.shape is [K,1]
1068
+ # ge samples_emb
1069
+ samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
1070
+ samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
1071
+
1072
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
1073
+ codebook_eog = [False] * self.args.n_codebooks
1074
+ num_gen.append(cur_num_gen)
1075
+ cur_num_gen = 0
1076
+ generated.append(cur_generated)
1077
+ cur_generated = []
1078
+ break
1079
+ else:
1080
+ assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
1081
+
1082
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
1083
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
1084
+ # make attention mask and padding mask
1085
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1086
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
1087
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
1088
+
1089
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
1090
+
1091
+ # revert the pattern
1092
+ flatten_gen = []
1093
+ for l, orig_span in enumerate(generated):
1094
+ span = torch.stack(orig_span, dim=0) # [T, K]
1095
+ span = span.transpose(1,0) # [K, T]
1096
+ assert span.shape[0] == self.args.n_codebooks, span.shape
1097
+ unshifted_span = []
1098
+ for j, s in enumerate(span):
1099
+ start_from = j
1100
+ end_at = - (self.args.n_codebooks - start_from)
1101
+ unshifted_span.append(s[start_from:end_at])
1102
+ unshifted_span = torch.stack(unshifted_span, dim=0)
1103
+
1104
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
1105
+
1106
+ flatten_gen.append(unshifted_span)
1107
+ assert len(flatten_gen) == 1, len(flatten_gen)
1108
+
1109
+ # combine
1110
+ res = [y[0], flatten_gen[0]]
1111
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
1112
+
1113
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
1114
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
1115
+
1116
+ if self.args.special_first:
1117
+ res = res - int(self.args.n_special)
1118
+ flatten_gen = flatten_gen - int(self.args.n_special)
1119
+
1120
+ return res, flatten_gen[0].unsqueeze(0)
1121
+
1122
+
1123
+ def inference_tts_batch(
1124
+ self,
1125
+ x: torch.Tensor,
1126
+ x_lens: torch.Tensor,
1127
+ y: torch.Tensor,
1128
+ top_k: int=-100,
1129
+ top_p: float=1.0,
1130
+ temperature: float=1.0,
1131
+ stop_repetition: int=3,
1132
+ kvcache: int=1,
1133
+ batch_size: int=5,
1134
+ silence_tokens: list[int]=[1388,1898,131],
1135
+ *kargs
1136
+ ) -> torch.Tensor:
1137
+ """
1138
+ have a batch size when forward passing, but they are equivalant to same example but different random seed, therefore as long as one example generated eog, we can drop all other samlpes
1139
+ different from inference_tts, this implementation uses kvcache, which should have significant speed up
1140
+ Args:
1141
+ x:
1142
+ A 2-D tensor of shape (1, L).
1143
+ x_lens:
1144
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
1145
+ before padding.
1146
+ y:
1147
+ A 3-D tensor of shape (1, T, K).
1148
+ top_k: (`optional`) int
1149
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
1150
+ top_p: (`optional`) float
1151
+ For Neucleus sampling
1152
+ temperature: (`optional`) float
1153
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
1154
+ """
1155
+ eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
1156
+ assert x.ndim == 2, x.shape
1157
+ assert x_lens.ndim == 1, x_lens.shape
1158
+ assert y.ndim == 3, y.shape
1159
+ if self.args.special_first:
1160
+ y = y + int(self.args.n_special)
1161
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
1162
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
1163
+
1164
+ # make x attention mask and x_input
1165
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
1166
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
1167
+ x_input = self.text_embedding(x)
1168
+ x_input = self.text_positional_embedding(x_input)
1169
+
1170
+ y_len = y.shape[2]
1171
+ y_lens = torch.LongTensor([y_len]).to(y.device)
1172
+
1173
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
1174
+ rearranged_y = [[y[0]]]
1175
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
1176
+
1177
+ # shift y to create the delayed pattern
1178
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
1179
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
1180
+ assert len(shifted_y[0]) == 1, len(shifted_y[0])
1181
+
1182
+ # below is different from forward or inference
1183
+ # where we cut this shifted part
1184
+ shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
1185
+ assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
1186
+
1187
+ # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
1188
+ # next section is concate tensors of each sample to one tensor, which we also don't need
1189
+ cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
1190
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
1191
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
1192
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
1193
+
1194
+ # replace tokens in y with the embeddings, add sum codebooks up
1195
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
1196
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
1197
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
1198
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
1199
+ embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
1200
+
1201
+ # positional embedding
1202
+ y_input = self.audio_positional_embedding(embedded_y)
1203
+
1204
+ # make attention mask and padding mask
1205
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1206
+
1207
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
1208
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
1209
+
1210
+ # entering the generation stage
1211
+ # starting from line 708
1212
+ codebook_eog = [False] * self.args.n_codebooks
1213
+ generated = [] # doesn't contain any empty token, contain eog
1214
+ cur_generated = [[] for _ in range(batch_size)]
1215
+ # say 0 is empty, 4 is eog
1216
+ # tensor([[ 1, 2, 3, 4, 0, 0],
1217
+ # [ 0, 1, 2, 3, 4, 0],
1218
+ # [ 0, 0, 1, 2, 3, 4]])
1219
+ num_gen = []
1220
+ cur_num_gen = 0
1221
+ ##################### silence repetition handling #####################
1222
+ ##################### silence repetition handling #####################
1223
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
1224
+ consec_silence_counts = [0 for _ in range(batch_size)]
1225
+ prev_tokens = [None for _ in range(batch_size)]
1226
+ ##################### silence repetition handling #####################
1227
+ ##################### silence repetition handling #####################
1228
+
1229
+ # prepare the cache placeholder
1230
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
1231
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
1232
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1233
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1234
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1235
+ keep = None # NOTE: this very important, tells which sample to keep
1236
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep):
1237
+ if n_eog == 0:
1238
+ logits_adjust = logits
1239
+ for jj in range(1,self.args.n_codebooks):
1240
+ logits_adjust[:,jj,eog_inference] = -10000
1241
+ logits_adjust[:,jj,self.args.empty_token] = -10000
1242
+ if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
1243
+ logits_adjust[:,:,eog_inference] = -10000
1244
+ ##################### silence repetition handling #####################
1245
+ for b in range(batch_size):
1246
+ prev_token = prev_tokens[b]
1247
+ consec_silence_count = consec_silence_counts[b]
1248
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
1249
+ if logits_adjust[b, 0, prev_token] < 0:
1250
+ logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] * (consec_silence_count - (stop_repetition-1))
1251
+ else:
1252
+ logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] / (consec_silence_count - (stop_repetition-1))
1253
+ ##################### silence repetition handling #####################
1254
+ samples = topk_sampling(
1255
+ logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
1256
+ ) # [B*K, 1]
1257
+ samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
1258
+ assert samples.shape == torch.Size((batch_size, self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
1259
+ for b in range(batch_size):
1260
+ if cur_num_gen < self.args.n_codebooks-1:
1261
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
1262
+ samples[b, -jj, 0] = self.args.empty_token
1263
+
1264
+ if (
1265
+ samples[b,0,0] == eog_inference or torch.argmax(logits[b,0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr//5)
1266
+ ): # last one means y is already too long, shouldn't happen, but put it here
1267
+ samples[b,0,0] = eog_inference
1268
+ codebook_eog[0] = True
1269
+ keep = b # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one)
1270
+ ##################### silence repetition handling #####################
1271
+ if samples[b,0,0] in silence_tokens and samples[b,0,0] == prev_tokens[b]:
1272
+ consec_silence_counts[b] += 1
1273
+ else:
1274
+ consec_silence_counts[b] = 0
1275
+ prev_tokens[b] = samples[b,0,0]
1276
+ ##################### silence repetition handling #####################
1277
+ return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
1278
+ else:
1279
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
1280
+ logits_adjust = logits
1281
+ for jj in range(n_eog+1,self.args.n_codebooks):
1282
+ logits_adjust[:,jj,eog_inference] = -10000
1283
+ logits_adjust[:,jj,self.args.empty_token] = -10000
1284
+ samples = topk_sampling(
1285
+ logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
1286
+ ) # [B, K, 1]
1287
+ samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
1288
+ for jj in range(n_eog):
1289
+ samples[keep, jj, 0] = self.args.empty_token
1290
+ samples[keep, n_eog, 0] = eog_inference
1291
+ codebook_eog[n_eog] = True
1292
+ return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
1293
+ while True:
1294
+ # if cur_num_gen > 0, should have everything in kvcache, so only pass in the last token
1295
+ # in the first generation step, we repeat each tensor to make their first dimension of length the batch size
1296
+ if cur_num_gen == 0:
1297
+ assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape
1298
+ assert x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1, x_padding_mask.shape
1299
+ assert y_input.ndim == 3 and y_input.shape[0] == 1 and y_input.shape[1] == new_y_lens[0], y_input.shape
1300
+ assert embedded_y.ndim == 3 and embedded_y.shape[0] == 1 and embedded_y.shape[1] == new_y_lens[0], embedded_y.shape
1301
+ x_input = x_input.repeat(batch_size, 1, 1)
1302
+ x_lens = x_lens.repeat(batch_size)
1303
+ # x_attention_mask = x_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
1304
+ x_padding_mask = x_padding_mask.repeat(batch_size, 1)
1305
+ y_input = y_input.repeat(batch_size, 1, 1)
1306
+ new_y_lens = new_y_lens.repeat(batch_size)
1307
+ # y_attention_mask = y_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
1308
+ y_padding_mask = y_padding_mask.repeat(batch_size, 1)
1309
+ embedded_y = embedded_y.repeat(batch_size, 1, 1) # will be used to concat with newly generated token embedding
1310
+ past = past.repeat(1, 1, batch_size) if past != None else None
1311
+ else:
1312
+ assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
1313
+ y_out, present = self.dec_forward(
1314
+ x_input,
1315
+ x_lens,
1316
+ x_attention_mask,
1317
+ x_padding_mask,
1318
+ y_input,
1319
+ new_y_lens,
1320
+ y_attention_mask,
1321
+ y_padding_mask,
1322
+ past=past
1323
+ )
1324
+ if past != None:
1325
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
1326
+
1327
+ # if no eog emerges, y_out should have batch size of batch_size
1328
+ if sum(codebook_eog) == 0:
1329
+ assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape
1330
+ y_out = y_out[:, -1:] # only take the last token
1331
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], S==1, so [B K 1 card]
1332
+ logits = logits.squeeze(2) # [B K card]
1333
+ assert logits.shape == torch.Size((batch_size, self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
1334
+
1335
+ n_eog = sum(codebook_eog)
1336
+ if self.args.eos > 0:
1337
+ for jj in range(self.args.n_codebooks):
1338
+ logits[:,jj,self.args.eog] = -10000.
1339
+ samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep)
1340
+
1341
+ cur_num_gen += 1
1342
+ if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
1343
+ assert keep == None
1344
+ for b in range(batch_size):
1345
+ cur_generated[b].append(samples[b].squeeze(-1))
1346
+ elif sum(codebook_eog) == 1: # the first eog just showed up in this step
1347
+ assert keep != None
1348
+ cur_generated = cur_generated[keep]
1349
+ cur_generated.append(samples[keep].squeeze(-1))
1350
+ else: # we are generating the rest eogs for the 'keep' sample
1351
+ cur_generated.append(samples[keep].squeeze(-1))
1352
+
1353
+ # samples.shape is [K,1]
1354
+ # ge samples_emb
1355
+ samples_emb = torch.stack([self.audio_embedding[k](samples[:, k]) for k in range(self.args.n_codebooks)], dim=1) # [B, K,1,D]
1356
+ assert samples_emb.shape == torch.Size([batch_size, self.args.n_codebooks, 1, self.args.d_model])
1357
+ samples_emb = samples_emb.sum(dim=1,keepdim=False) # [B,1,D]
1358
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
1359
+ codebook_eog = [False] * self.args.n_codebooks
1360
+ num_gen.append(cur_num_gen)
1361
+ cur_num_gen = 0
1362
+ generated.append(cur_generated)
1363
+ cur_generated = [[] for _ in range(batch_size)]
1364
+ break
1365
+ else:
1366
+ assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
1367
+
1368
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
1369
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
1370
+ # make attention mask and padding mask
1371
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1372
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
1373
+ y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device)
1374
+
1375
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
1376
+
1377
+ # revert the pattern
1378
+ flatten_gen = []
1379
+ for l, orig_span in enumerate(generated):
1380
+ span = torch.stack(orig_span, dim=0) # [T, K]
1381
+ span = span.transpose(1,0) # [K, T]
1382
+ assert span.shape[0] == self.args.n_codebooks, span.shape
1383
+ unshifted_span = []
1384
+ for j, s in enumerate(span):
1385
+ start_from = j
1386
+ end_at = - (self.args.n_codebooks - start_from)
1387
+ unshifted_span.append(s[start_from:end_at])
1388
+ unshifted_span = torch.stack(unshifted_span, dim=0)
1389
+
1390
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
1391
+
1392
+ flatten_gen.append(unshifted_span)
1393
+ assert len(flatten_gen) == 1, len(flatten_gen)
1394
+
1395
+ # combine
1396
+ res = [y[0], flatten_gen[0]]
1397
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
1398
+
1399
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
1400
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
1401
+
1402
+ if self.args.special_first:
1403
+ res = res - int(self.args.n_special)
1404
+ flatten_gen = flatten_gen - int(self.args.n_special)
1405
+
1406
+ return res, flatten_gen[0].unsqueeze(0)
lib/voicecraft/pretrained_models/.gitkeep ADDED
File without changes
lib/voicecraft/start-jupyter.bat ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ echo Creating and running the Jupyter container...
4
+
5
+ docker run -it -d ^
6
+ --gpus all ^
7
+ -p 8888:8888 ^
8
+ --name jupyter ^
9
+ --user root ^
10
+ -e NB_USER="%username%" ^
11
+ -e CHOWN_HOME=yes ^
12
+ -e GRANT_SUDO=yes ^
13
+ -e JUPYTER_TOKEN=mytoken ^
14
+ -w "/home/%username%" ^
15
+ -v "%cd%":"/home/%username%/work" ^
16
+ jupyter/base-notebook
17
+
18
+ if %errorlevel% == 0 (
19
+ echo Jupyter container created and running.
20
+
21
+ echo Jupyter container is running.
22
+ echo To access the Jupyter web UI, please follow these steps:
23
+ echo 1. Open your web browser
24
+ echo 2. Navigate to http://localhost:8888/?token=mytoken
25
+ echo 3. !! The default token is "mytoken" and should be changed. !!
26
+ pause
27
+ ) else (
28
+ echo Failed to create and run the Jupyter container.
29
+ )
lib/voicecraft/start-jupyter.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ ## Assumes you have docker installed with nvidia container container-toolkit
3
+ # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.13.5/install-guide.html
4
+ # sudo apt-get install -y nvidia-container-toolkit-base || yay -Syu nvidia-container-toolkit || echo etc...
5
+ ## Try to start an existing container otherwise create a new one
6
+ docker start jupyter 2> /dev/null || \
7
+ docker run -it \
8
+ -d \
9
+ --gpus all \
10
+ -p 8888:8888 \
11
+ --name jupyter \
12
+ --user root \
13
+ -e NB_USER="$USER" \
14
+ -e CHOWN_HOME=yes \
15
+ -e GRANT_SUDO=yes \
16
+ -w "/home/${NB_USER}" \
17
+ -v "$PWD":"/home/$USER/work" \
18
+ jupyter/base-notebook
19
+
20
+ ## `docker logs jupyter` to get the URL link and token e.g.
21
+ ## http://127.0.0.1:8888/lab?token=blahblahblahblabhlaabhalbhalbhal
lib/voicecraft/steps/__init__.py ADDED
File without changes
lib/voicecraft/steps/optim.py ADDED
@@ -0,0 +1,1123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ import random
20
+ from collections import defaultdict
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+
28
+
29
+ class BatchedOptimizer(Optimizer):
30
+ """
31
+ This class adds to class Optimizer the capability to optimize parameters in batches:
32
+ it will stack the parameters and their grads for you so the optimizer can work
33
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
34
+ as it reduces the number of kernels launched in the optimizer.
35
+
36
+ Args:
37
+ params:
38
+ """
39
+
40
+ def __init__(self, params, defaults):
41
+ super(BatchedOptimizer, self).__init__(params, defaults)
42
+
43
+ @contextlib.contextmanager
44
+ def batched_params(self, param_group, group_params_names):
45
+ """
46
+ This function returns (technically, yields) a list of
47
+ of tuples (p, state), where
48
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
49
+ that share the same shape, and its gradient is also stacked;
50
+ `state` is the state corresponding to this batch of parameters
51
+ (it will be physically located in the "state" for one of the real
52
+ parameters, the last one that has any particular shape and dtype).
53
+
54
+ This function is decorated as a context manager so that it can
55
+ write parameters back to their "real" locations.
56
+
57
+ The idea is, instead of doing:
58
+ <code>
59
+ for p in group["params"]:
60
+ state = self.state[p]
61
+ ...
62
+ </code>
63
+ you can do:
64
+ <code>
65
+ with self.batched_params(group["params"]) as batches:
66
+ for p, state, p_names in batches:
67
+ ...
68
+ </code>
69
+
70
+ Args:
71
+ group: a parameter group, which is a list of parameters; should be
72
+ one of self.param_groups.
73
+ group_params_names: name for each parameter in group,
74
+ which is List[str].
75
+ """
76
+ batches = defaultdict(
77
+ list
78
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
79
+ batches_names = defaultdict(
80
+ list
81
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
82
+
83
+ assert len(param_group) == len(group_params_names), f"len(param_group): {len(param_group)}, len(group_params_names): {len(group_params_names)}"
84
+ for p, named_p in zip(param_group, group_params_names):
85
+ key = (str(p.dtype), *p.shape)
86
+ batches[key].append(p)
87
+ batches_names[key].append(named_p)
88
+
89
+ batches_names_keys = list(batches_names.keys())
90
+ sorted_idx = sorted(
91
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
92
+ )
93
+ batches_names = [
94
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
95
+ ]
96
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
97
+
98
+ stacked_params_dict = dict()
99
+
100
+ # turn batches into a list, in deterministic order.
101
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
102
+ # one for each batch in `batches`.
103
+ tuples = []
104
+
105
+ for batch, batch_names in zip(batches, batches_names):
106
+ p = batch[0]
107
+ # we arbitrarily store the state in the
108
+ # state corresponding to the 1st parameter in the
109
+ # group. class Optimizer will take care of saving/loading state.
110
+ state = self.state[p]
111
+ p_stacked = torch.stack(batch)
112
+ grad = torch.stack(
113
+ [
114
+ torch.zeros_like(p) if p.grad is None else p.grad
115
+ for p in batch
116
+ ]
117
+ )
118
+ p_stacked.grad = grad
119
+ stacked_params_dict[key] = p_stacked
120
+ tuples.append((p_stacked, state, batch_names))
121
+
122
+ yield tuples # <-- calling code will do the actual optimization here!
123
+
124
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
125
+ for i, p in enumerate(batch): # batch is list of Parameter
126
+ p.copy_(stacked_params[i])
127
+
128
+
129
+ class ScaledAdam(BatchedOptimizer):
130
+ """
131
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
132
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
133
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
134
+ param = underlying_param * log_scale.exp())
135
+
136
+
137
+ Args:
138
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
139
+ lr: The learning rate. We will typically use a learning rate schedule that starts
140
+ at 0.03 and decreases over time, i.e. much higher than other common
141
+ optimizers.
142
+ clipping_scale: (e.g. 2.0)
143
+ A scale for gradient-clipping: if specified, the normalized gradients
144
+ over the whole model will be clipped to have 2-norm equal to
145
+ `clipping_scale` times the median 2-norm over the most recent period
146
+ of `clipping_update_period` minibatches. By "normalized gradients",
147
+ we mean after multiplying by the rms parameter value for this tensor
148
+ [for non-scalars]; this is appropriate because our update is scaled
149
+ by this quantity.
150
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
151
+ Must satisfy 0 < beta <= beta2 < 1.
152
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
153
+ scale of each parameter tensor and scalar parameters of the mode..
154
+ If each parameter were decomposed
155
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
156
+ would be a the scaling factor on the learning rate of p_scale.
157
+ eps: A general-purpose epsilon to prevent division by zero
158
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
159
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
160
+ parameter tensor to be >= this value)
161
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
162
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
163
+ parameter tensor to be <= this value)
164
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
165
+ model has any parameters with numel() == 1).
166
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
167
+ of the parameter tensor. This is provided to save a little time
168
+ in the update.
169
+ clipping_update_period: if clipping_scale is specified, this is the period
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ params,
175
+ lr=3e-02,
176
+ clipping_scale=None,
177
+ betas=(0.9, 0.98),
178
+ scalar_lr_scale=0.1,
179
+ eps=1.0e-08,
180
+ param_min_rms=1.0e-05,
181
+ param_max_rms=3.0,
182
+ scalar_max=10.0,
183
+ size_update_period=4,
184
+ clipping_update_period=100,
185
+ parameters_names=None,
186
+ show_dominant_parameters=True,
187
+ ):
188
+
189
+ assert parameters_names is not None, (
190
+ "Please prepare parameters_names,"
191
+ "which is a List[List[str]]. Each List[str] is for a group"
192
+ "and each str is for a parameter"
193
+ )
194
+ defaults = dict(
195
+ lr=lr,
196
+ clipping_scale=clipping_scale,
197
+ betas=betas,
198
+ scalar_lr_scale=scalar_lr_scale,
199
+ eps=eps,
200
+ param_min_rms=param_min_rms,
201
+ param_max_rms=param_max_rms,
202
+ scalar_max=scalar_max,
203
+ size_update_period=size_update_period,
204
+ clipping_update_period=clipping_update_period,
205
+ )
206
+
207
+ super(ScaledAdam, self).__init__(params, defaults)
208
+ assert len(self.param_groups) == len(parameters_names)
209
+ self.parameters_names = parameters_names
210
+ self.show_dominant_parameters = show_dominant_parameters
211
+
212
+ def __setstate__(self, state):
213
+ super(ScaledAdam, self).__setstate__(state)
214
+
215
+ @torch.no_grad()
216
+ def step(self, closure=None):
217
+ """Performs a single optimization step.
218
+
219
+ Arguments:
220
+ closure (callable, optional): A closure that reevaluates the model
221
+ and returns the loss.
222
+ """
223
+ loss = None
224
+ if closure is not None:
225
+ with torch.enable_grad():
226
+ loss = closure()
227
+
228
+ batch = True
229
+
230
+ for group, group_params_names in zip(
231
+ self.param_groups, self.parameters_names
232
+ ):
233
+
234
+ with self.batched_params(
235
+ group["params"], group_params_names
236
+ ) as batches:
237
+
238
+ # batches is list of pairs (stacked_param, state). stacked_param is like
239
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
240
+ # a stacking dim, it is not a real dim.
241
+
242
+ if (
243
+ len(batches[0][1]) == 0
244
+ ): # if len(first state) == 0: not yet initialized
245
+ clipping_scale = 1
246
+ else:
247
+ clipping_scale = self._get_clipping_scale(group, batches)
248
+
249
+ for p, state, _ in batches:
250
+ # Perform optimization step.
251
+ # grad is not going to be None, we handled that when creating the batches.
252
+ grad = p.grad
253
+ if grad.is_sparse:
254
+ raise RuntimeError(
255
+ "ScaledAdam optimizer does not support sparse gradients"
256
+ )
257
+ # State initialization
258
+ if len(state) == 0:
259
+ self._init_state(group, p, state)
260
+
261
+ self._step_one_batch(group, p, state, clipping_scale)
262
+
263
+ return loss
264
+
265
+ def _init_state(self, group: dict, p: Tensor, state: dict):
266
+ """
267
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
268
+ is actually the batch dimension, corresponding to batched-together
269
+ parameters of a given shape.
270
+
271
+
272
+ Args:
273
+ group: Dict to look up configuration values.
274
+ p: The parameter that we are initializing the state for
275
+ state: Dict from string to whatever state we are initializing
276
+ """
277
+ size_update_period = group["size_update_period"]
278
+
279
+ state["step"] = 0
280
+
281
+ kwargs = {"device": p.device, "dtype": p.dtype}
282
+
283
+ # 'delta' implements conventional momentum. There are
284
+ # several different kinds of update going on, so rather than
285
+ # compute "exp_avg" like in Adam, we store and decay a
286
+ # parameter-change "delta", which combines all forms of
287
+ # update. this is equivalent to how it's done in Adam,
288
+ # except for the first few steps.
289
+ state["delta"] = torch.zeros_like(
290
+ p, memory_format=torch.preserve_format
291
+ )
292
+
293
+ batch_size = p.shape[0]
294
+ numel = p.numel() // batch_size
295
+ numel = p.numel()
296
+
297
+ if numel > 1:
298
+ # "param_rms" just periodically records the scalar root-mean-square value of
299
+ # the parameter tensor.
300
+ # it has a shape like (batch_size, 1, 1, 1, 1)
301
+ param_rms = (
302
+ (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
303
+ )
304
+ state["param_rms"] = param_rms
305
+
306
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
307
+ state["scale_grads"] = torch.zeros(
308
+ size_update_period, *param_rms.shape, **kwargs
309
+ )
310
+
311
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
312
+ state["exp_avg_sq"] = torch.zeros_like(
313
+ p, memory_format=torch.preserve_format
314
+ )
315
+
316
+ def _get_clipping_scale(
317
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
318
+ ) -> float:
319
+ """
320
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
321
+ by this amount before applying the rest of the update.
322
+
323
+ Args:
324
+ group: the parameter group, an item in self.param_groups
325
+ tuples: a list of tuples of (param, state, param_names)
326
+ where param is a batched set of parameters,
327
+ with a .grad (1st dim is batch dim)
328
+ and state is the state-dict where optimization parameters are kept.
329
+ param_names is a List[str] while each str is name for a parameter
330
+ in batched set of parameters "param".
331
+ """
332
+ assert len(tuples) >= 1
333
+ clipping_scale = group["clipping_scale"]
334
+ (first_p, first_state, _) = tuples[0]
335
+ step = first_state["step"]
336
+ if clipping_scale is None or step == 0:
337
+ # no clipping. return early on step == 0 because the other
338
+ # parameters' state won't have been initialized yet.
339
+ return 1.0
340
+ clipping_update_period = group["clipping_update_period"]
341
+
342
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
343
+ for (p, state, param_names) in tuples:
344
+ grad = p.grad
345
+ if grad.is_sparse:
346
+ raise RuntimeError(
347
+ "ScaledAdam optimizer does not support sparse gradients"
348
+ )
349
+ if p.numel() == p.shape[0]: # a batch of scalars
350
+ tot_sumsq += (
351
+ grad ** 2
352
+ ).sum() # sum() to change shape [1] to []
353
+ else:
354
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
355
+
356
+ tot_norm = tot_sumsq.sqrt()
357
+ if "model_norms" not in first_state:
358
+ first_state["model_norms"] = torch.zeros(
359
+ clipping_update_period, device=p.device
360
+ )
361
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
362
+
363
+ if step % clipping_update_period == 0:
364
+ # Print some stats.
365
+ # We don't reach here if step == 0 because we would have returned
366
+ # above.
367
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
368
+ quartiles = []
369
+ for n in range(0, 5):
370
+ index = min(
371
+ clipping_update_period - 1,
372
+ (clipping_update_period // 4) * n,
373
+ )
374
+ quartiles.append(sorted_norms[index].item())
375
+
376
+ median = quartiles[2]
377
+ threshold = clipping_scale * median
378
+ first_state["model_norm_threshold"] = threshold
379
+ percent_clipped = (
380
+ first_state["num_clipped"] * 100.0 / clipping_update_period
381
+ if "num_clipped" in first_state
382
+ else 0.0
383
+ )
384
+ first_state["num_clipped"] = 0
385
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
386
+ logging.info(
387
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
388
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
389
+ )
390
+
391
+ if step < clipping_update_period:
392
+ return 1.0 # We have not yet estimated a norm to clip to.
393
+ else:
394
+ try:
395
+ model_norm_threshold = first_state["model_norm_threshold"]
396
+ except KeyError:
397
+ logging.info(
398
+ "Warning: model_norm_threshold not in state: possibly "
399
+ "you changed config when restarting, adding clipping_scale option?"
400
+ )
401
+ return 1.0
402
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
403
+ if ans < 1.0:
404
+ first_state["num_clipped"] += 1
405
+ if ans < 0.1:
406
+ logging.warn(
407
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
408
+ )
409
+ if self.show_dominant_parameters:
410
+ assert p.shape[0] == len(param_names)
411
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
412
+ return ans
413
+
414
+ def _show_gradient_dominating_parameter(
415
+ self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
416
+ ):
417
+ """
418
+ Show information of parameter wihch dominanting tot_sumsq.
419
+
420
+ Args:
421
+ tuples: a list of tuples of (param, state, param_names)
422
+ where param is a batched set of parameters,
423
+ with a .grad (1st dim is batch dim)
424
+ and state is the state-dict where optimization parameters are kept.
425
+ param_names is a List[str] while each str is name for a parameter
426
+ in batched set of parameters "param".
427
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
428
+ from tuples, we still pass it to save some time.
429
+ """
430
+ all_sumsq_orig = {}
431
+ for (p, state, batch_param_names) in tuples:
432
+ # p is a stacked batch parameters.
433
+ batch_grad = p.grad
434
+ if p.numel() == p.shape[0]: # a batch of scalars
435
+ batch_sumsq_orig = batch_grad ** 2
436
+ # Dummpy values used by following `zip` statement.
437
+ batch_rms_orig = torch.ones(p.shape[0])
438
+ else:
439
+ batch_rms_orig = state["param_rms"]
440
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
441
+ dim=list(range(1, batch_grad.ndim))
442
+ )
443
+
444
+ for name, sumsq_orig, rms, grad in zip(
445
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
446
+ ):
447
+
448
+ proportion_orig = sumsq_orig / tot_sumsq
449
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
450
+
451
+ assert torch.isclose(
452
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
453
+ torch.tensor(1.0),
454
+ )
455
+ sorted_by_proportion = {
456
+ k: v
457
+ for k, v in sorted(
458
+ all_sumsq_orig.items(),
459
+ key=lambda item: item[1][0],
460
+ reverse=True,
461
+ )
462
+ }
463
+ dominant_param_name = next(iter(sorted_by_proportion))
464
+ (
465
+ dominant_proportion,
466
+ dominant_sumsq,
467
+ dominant_rms,
468
+ dominant_grad,
469
+ ) = sorted_by_proportion[dominant_param_name]
470
+ logging.info(
471
+ f"Parameter Dominanting tot_sumsq {dominant_param_name}"
472
+ f" with proportion {dominant_proportion:.2f},"
473
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
474
+ f"={dominant_sumsq:.3e},"
475
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
476
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
477
+ )
478
+
479
+ def _step_one_batch(
480
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
481
+ ):
482
+ """
483
+ Do the step for one parameter, which is actually going to be a batch of
484
+ `real` parameters, with dim 0 as the batch dim.
485
+ Args:
486
+ group: dict to look up configuration values
487
+ p: parameter to update (actually multiple parameters stacked together
488
+ as a batch)
489
+ state: state-dict for p, to look up the optimizer state
490
+ """
491
+ lr = group["lr"]
492
+ size_update_period = group["size_update_period"]
493
+ beta1 = group["betas"][0]
494
+
495
+ grad = p.grad
496
+ if clipping_scale != 1.0:
497
+ grad = grad * clipping_scale
498
+ step = state["step"]
499
+ delta = state["delta"]
500
+
501
+ delta.mul_(beta1)
502
+ batch_size = p.shape[0]
503
+ numel = p.numel() // batch_size
504
+ if numel > 1:
505
+ # Update the size/scale of p, and set param_rms
506
+ scale_grads = state["scale_grads"]
507
+ scale_grads[step % size_update_period] = (p * grad).sum(
508
+ dim=list(range(1, p.ndim)), keepdim=True
509
+ )
510
+ if step % size_update_period == size_update_period - 1:
511
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
512
+ param_rms.copy_(
513
+ (p ** 2)
514
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
515
+ .sqrt()
516
+ )
517
+ if step > 0:
518
+ # self._size_update() learns the overall scale on the
519
+ # parameter, by shrinking or expanding it.
520
+ self._size_update(group, scale_grads, p, state)
521
+
522
+ if numel == 1:
523
+ # For parameters with 1 element we just use regular Adam.
524
+ # Updates delta.
525
+ self._step_scalar(group, p, state)
526
+ else:
527
+ self._step(group, p, state)
528
+
529
+ state["step"] = step + 1
530
+
531
+ def _size_update(
532
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
533
+ ) -> None:
534
+ """
535
+ Called only where p.numel() > 1, this updates the scale of the parameter.
536
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
537
+ gradient descent on underlying param and on scale, this function does the update
538
+ on `scale`.
539
+
540
+ Args:
541
+ group: dict to look up configuration values
542
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
543
+ grads w.r.t. the scales.
544
+ p: The parameter to update
545
+ state: The state-dict of p
546
+ """
547
+
548
+ param_rms = state["param_rms"]
549
+ beta1, beta2 = group["betas"]
550
+ size_lr = group["lr"] * group["scalar_lr_scale"]
551
+ param_min_rms = group["param_min_rms"]
552
+ param_max_rms = group["param_max_rms"]
553
+ eps = group["eps"]
554
+ step = state["step"]
555
+ batch_size = p.shape[0]
556
+
557
+ size_update_period = scale_grads.shape[0]
558
+ # correct beta2 for the size update period: we will have
559
+ # faster decay at this level.
560
+ beta2_corr = beta2 ** size_update_period
561
+
562
+ scale_exp_avg_sq = state[
563
+ "scale_exp_avg_sq"
564
+ ] # shape: (batch_size, 1, 1, ..)
565
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
566
+ (scale_grads ** 2).mean(
567
+ dim=0
568
+ ), # mean over dim `size_update_period`
569
+ alpha=1 - beta2_corr,
570
+ ) # shape is (batch_size, 1, 1, ...)
571
+
572
+ # The 1st time we reach here is when size_step == 1.
573
+ size_step = (step + 1) // size_update_period
574
+ bias_correction2 = 1 - beta2_corr ** size_step
575
+ # we don't bother with bias_correction1; this will help prevent divergence
576
+ # at the start of training.
577
+
578
+ denom = scale_exp_avg_sq.sqrt() + eps
579
+
580
+ scale_step = (
581
+ -size_lr
582
+ * (bias_correction2 ** 0.5)
583
+ * scale_grads.sum(dim=0)
584
+ / denom
585
+ )
586
+
587
+ is_too_small = param_rms < param_min_rms
588
+ is_too_large = param_rms > param_max_rms
589
+
590
+ # when the param gets too small, just don't shrink it any further.
591
+ scale_step.masked_fill_(is_too_small, 0.0)
592
+ # when it gets too large, stop it from getting any larger.
593
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
594
+ delta = state["delta"]
595
+ # the factor of (1-beta1) relates to momentum.
596
+ delta.add_(p * scale_step, alpha=(1 - beta1))
597
+
598
+ def _step(self, group: dict, p: Tensor, state: dict):
599
+ """
600
+ This function does the core update of self.step(), in the case where the members of
601
+ the batch have more than 1 element.
602
+
603
+ Args:
604
+ group: A dict which will be used to look up configuration values
605
+ p: The parameter to be updated
606
+ grad: The grad of p
607
+ state: The state-dict corresponding to parameter p
608
+
609
+ This function modifies p.
610
+ """
611
+ grad = p.grad
612
+ lr = group["lr"]
613
+ beta1, beta2 = group["betas"]
614
+ eps = group["eps"]
615
+ param_min_rms = group["param_min_rms"]
616
+ step = state["step"]
617
+
618
+ exp_avg_sq = state["exp_avg_sq"]
619
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
620
+
621
+ this_step = state["step"] - (
622
+ state["zero_step"] if "zero_step" in state else 0
623
+ )
624
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
625
+ if bias_correction2 < 0.99:
626
+ # note: not in-place.
627
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
628
+
629
+ denom = exp_avg_sq.sqrt()
630
+ denom += eps
631
+ grad = grad / denom
632
+
633
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
634
+
635
+ delta = state["delta"]
636
+ delta.add_(grad * alpha)
637
+ p.add_(delta)
638
+
639
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
640
+ """
641
+ A simplified form of the core update for scalar tensors, where we cannot get a good
642
+ estimate of the parameter rms.
643
+ """
644
+ beta1, beta2 = group["betas"]
645
+ scalar_max = group["scalar_max"]
646
+ eps = group["eps"]
647
+ lr = group["lr"] * group["scalar_lr_scale"]
648
+ grad = p.grad
649
+
650
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
651
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
652
+
653
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
654
+ # slower update at the start will help stability anyway.
655
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
656
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
657
+
658
+ delta = state["delta"]
659
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
660
+ p.clamp_(min=-scalar_max, max=scalar_max)
661
+ p.add_(delta)
662
+
663
+
664
+ class LRScheduler(object):
665
+ """
666
+ Base-class for learning rate schedulers where the learning-rate depends on both the
667
+ batch and the epoch.
668
+ """
669
+
670
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
671
+ # Attach optimizer
672
+ if not isinstance(optimizer, Optimizer):
673
+ raise TypeError(
674
+ "{} is not an Optimizer".format(type(optimizer).__name__)
675
+ )
676
+ self.optimizer = optimizer
677
+ self.verbose = verbose
678
+
679
+ for group in optimizer.param_groups:
680
+ group.setdefault("base_lr", group["lr"])
681
+
682
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
683
+
684
+ self.epoch = 0
685
+ self.batch = 0
686
+
687
+ def state_dict(self):
688
+ """Returns the state of the scheduler as a :class:`dict`.
689
+
690
+ It contains an entry for every variable in self.__dict__ which
691
+ is not the optimizer.
692
+ """
693
+ return {
694
+ "base_lrs": self.base_lrs,
695
+ "epoch": self.epoch,
696
+ "batch": self.batch,
697
+ }
698
+
699
+ def load_state_dict(self, state_dict):
700
+ """Loads the schedulers state.
701
+
702
+ Args:
703
+ state_dict (dict): scheduler state. Should be an object returned
704
+ from a call to :meth:`state_dict`.
705
+ """
706
+ self.__dict__.update(state_dict)
707
+
708
+ def get_last_lr(self) -> List[float]:
709
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
710
+ return self._last_lr
711
+
712
+ def get_lr(self):
713
+ # Compute list of learning rates from self.epoch and self.batch and
714
+ # self.base_lrs; this must be overloaded by the user.
715
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
716
+ raise NotImplementedError
717
+
718
+ def step_batch(self, batch: Optional[int] = None) -> None:
719
+ # Step the batch index, or just set it. If `batch` is specified, it
720
+ # must be the batch index from the start of training, i.e. summed over
721
+ # all epochs.
722
+ # You can call this in any order; if you don't provide 'batch', it should
723
+ # of course be called once per batch.
724
+ if batch is not None:
725
+ self.batch = batch
726
+ else:
727
+ self.batch = self.batch + 1
728
+ self._set_lrs()
729
+
730
+ def step_epoch(self, epoch: Optional[int] = None):
731
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
732
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
733
+ # arg, you should call it at the end of the epoch.
734
+ if epoch is not None:
735
+ self.epoch = epoch
736
+ else:
737
+ self.epoch = self.epoch + 1
738
+ self._set_lrs()
739
+
740
+ def _set_lrs(self):
741
+ values = self.get_lr()
742
+ assert len(values) == len(self.optimizer.param_groups)
743
+
744
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
745
+ param_group, lr = data
746
+ param_group["lr"] = lr
747
+ self.print_lr(self.verbose, i, lr)
748
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
749
+
750
+ def print_lr(self, is_verbose, group, lr):
751
+ """Display the current learning rate."""
752
+ if is_verbose:
753
+ logging.info(
754
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
755
+ f" of group {group} to {lr:.4e}."
756
+ )
757
+
758
+
759
+ class Eden(LRScheduler):
760
+ """
761
+ Eden scheduler.
762
+ The basic formula (before warmup) is:
763
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
764
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
765
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
766
+ and then stays constant at 1.
767
+
768
+
769
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
770
+
771
+ Args:
772
+ optimizer: the optimizer to change the learning rates on
773
+ lr_batches: the number of batches after which we start significantly
774
+ decreasing the learning rate, suggest 5000.
775
+ lr_epochs: the number of epochs after which we start significantly
776
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
777
+ 20 to 40 epochs, but may need smaller number if dataset is huge
778
+ and you will do few epochs.
779
+ """
780
+
781
+ def __init__(
782
+ self,
783
+ optimizer: Optimizer,
784
+ lr_batches: Union[int, float],
785
+ lr_epochs: Union[int, float],
786
+ warmup_batches: Union[int, float] = 500.0,
787
+ verbose: bool = False,
788
+ ):
789
+ super(Eden, self).__init__(optimizer, verbose)
790
+ self.lr_batches = lr_batches
791
+ self.lr_epochs = lr_epochs
792
+ self.warmup_batches = warmup_batches
793
+
794
+ def get_lr(self):
795
+ factor = (
796
+ (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
797
+ ) ** -0.25 * (
798
+ ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
799
+ ** -0.25
800
+ )
801
+ warmup_factor = (
802
+ 1.0
803
+ if self.batch >= self.warmup_batches
804
+ else 0.5 + 0.5 * (self.batch / self.warmup_batches)
805
+ )
806
+
807
+ return [x * factor * warmup_factor for x in self.base_lrs]
808
+
809
+
810
+ def _test_eden():
811
+ m = torch.nn.Linear(100, 100)
812
+ optim = ScaledAdam(m.parameters(), lr=0.03)
813
+
814
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
815
+
816
+ for epoch in range(10):
817
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
818
+
819
+ for step in range(20):
820
+ x = torch.randn(200, 100).detach()
821
+ x.requires_grad = True
822
+ y = m(x)
823
+ dy = torch.randn(200, 100).detach()
824
+ f = (y * dy).sum()
825
+ f.backward()
826
+
827
+ optim.step()
828
+ scheduler.step_batch()
829
+ optim.zero_grad()
830
+
831
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
832
+ logging.info(f"state dict = {scheduler.state_dict()}")
833
+
834
+
835
+ # This is included mostly as a baseline for ScaledAdam.
836
+ class Eve(Optimizer):
837
+ """
838
+ Implements Eve algorithm. This is a modified version of AdamW with a special
839
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
840
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
841
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
842
+ will be close to invariant to the absolute scale on the parameter matrix.
843
+
844
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
845
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
846
+ Eve is unpublished so far.
847
+
848
+ Arguments:
849
+ params (iterable): iterable of parameters to optimize or dicts defining
850
+ parameter groups
851
+ lr (float, optional): learning rate (default: 1e-3)
852
+ betas (Tuple[float, float], optional): coefficients used for computing
853
+ running averages of gradient and its square (default: (0.9, 0.999))
854
+ eps (float, optional): term added to the denominator to improve
855
+ numerical stability (default: 1e-8)
856
+ weight_decay (float, optional): weight decay coefficient (default: 3e-4;
857
+ this value means that the weight would decay significantly after
858
+ about 3k minibatches. Is not multiplied by learning rate, but
859
+ is conditional on RMS-value of parameter being > target_rms.
860
+ target_rms (float, optional): target root-mean-square value of
861
+ parameters, if they fall below this we will stop applying weight decay.
862
+
863
+
864
+ .. _Adam: A Method for Stochastic Optimization:
865
+ https://arxiv.org/abs/1412.6980
866
+ .. _Decoupled Weight Decay Regularization:
867
+ https://arxiv.org/abs/1711.05101
868
+ .. _On the Convergence of Adam and Beyond:
869
+ https://openreview.net/forum?id=ryQu7f-RZ
870
+ """
871
+
872
+ def __init__(
873
+ self,
874
+ params,
875
+ lr=1e-3,
876
+ betas=(0.9, 0.98),
877
+ eps=1e-8,
878
+ weight_decay=1e-3,
879
+ target_rms=0.1,
880
+ ):
881
+ if not 0.0 <= lr:
882
+ raise ValueError("Invalid learning rate: {}".format(lr))
883
+ if not 0.0 <= eps:
884
+ raise ValueError("Invalid epsilon value: {}".format(eps))
885
+ if not 0.0 <= betas[0] < 1.0:
886
+ raise ValueError(
887
+ "Invalid beta parameter at index 0: {}".format(betas[0])
888
+ )
889
+ if not 0.0 <= betas[1] < 1.0:
890
+ raise ValueError(
891
+ "Invalid beta parameter at index 1: {}".format(betas[1])
892
+ )
893
+ if not 0 <= weight_decay <= 0.1:
894
+ raise ValueError(
895
+ "Invalid weight_decay value: {}".format(weight_decay)
896
+ )
897
+ if not 0 < target_rms <= 10.0:
898
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
899
+ defaults = dict(
900
+ lr=lr,
901
+ betas=betas,
902
+ eps=eps,
903
+ weight_decay=weight_decay,
904
+ target_rms=target_rms,
905
+ )
906
+ super(Eve, self).__init__(params, defaults)
907
+
908
+ def __setstate__(self, state):
909
+ super(Eve, self).__setstate__(state)
910
+
911
+ @torch.no_grad()
912
+ def step(self, closure=None):
913
+ """Performs a single optimization step.
914
+
915
+ Arguments:
916
+ closure (callable, optional): A closure that reevaluates the model
917
+ and returns the loss.
918
+ """
919
+ loss = None
920
+ if closure is not None:
921
+ with torch.enable_grad():
922
+ loss = closure()
923
+
924
+ for group in self.param_groups:
925
+ for p in group["params"]:
926
+ if p.grad is None:
927
+ continue
928
+
929
+ # Perform optimization step
930
+ grad = p.grad
931
+ if grad.is_sparse:
932
+ raise RuntimeError(
933
+ "AdamW does not support sparse gradients"
934
+ )
935
+
936
+ state = self.state[p]
937
+
938
+ # State initialization
939
+ if len(state) == 0:
940
+ state["step"] = 0
941
+ # Exponential moving average of gradient values
942
+ state["exp_avg"] = torch.zeros_like(
943
+ p, memory_format=torch.preserve_format
944
+ )
945
+ # Exponential moving average of squared gradient values
946
+ state["exp_avg_sq"] = torch.zeros_like(
947
+ p, memory_format=torch.preserve_format
948
+ )
949
+
950
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
951
+
952
+ beta1, beta2 = group["betas"]
953
+
954
+ state["step"] += 1
955
+ bias_correction1 = 1 - beta1 ** state["step"]
956
+ bias_correction2 = 1 - beta2 ** state["step"]
957
+
958
+ # Decay the first and second moment running average coefficient
959
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
960
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
961
+ denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
962
+ group["eps"]
963
+ )
964
+
965
+ step_size = group["lr"] / bias_correction1
966
+ target_rms = group["target_rms"]
967
+ weight_decay = group["weight_decay"]
968
+
969
+ if p.numel() > 1:
970
+ # avoid applying this weight-decay on "scaling factors"
971
+ # (which are scalar).
972
+ is_above_target_rms = p.norm() > (
973
+ target_rms * (p.numel() ** 0.5)
974
+ )
975
+ p.mul_(1 - (weight_decay * is_above_target_rms))
976
+
977
+ p.addcdiv_(exp_avg, denom, value=-step_size)
978
+
979
+ # if random.random() < 0.0005:
980
+ # step = (exp_avg / denom) * step_size
981
+ # logging.info(
982
+ # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
983
+ # )
984
+
985
+ return loss
986
+
987
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
988
+ """
989
+ Behaves like a constructor of a modified version of nn.Linear
990
+ that gives an easy way to set the default initial parameter scale.
991
+
992
+ Args:
993
+ Accepts the standard args and kwargs that nn.Linear accepts
994
+ e.g. in_features, out_features, bias=False.
995
+
996
+ initial_scale: you can override this if you want to increase
997
+ or decrease the initial magnitude of the module's output
998
+ (affects the initialization of weight_scale and bias_scale).
999
+ Another option, if you want to do something like this, is
1000
+ to re-initialize the parameters.
1001
+ """
1002
+ ans = nn.Linear(*args, **kwargs)
1003
+ with torch.no_grad():
1004
+ ans.weight[:] *= initial_scale
1005
+ if ans.bias is not None:
1006
+ torch.nn.init.uniform_(
1007
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
1008
+ )
1009
+ return ans
1010
+ def _test_scaled_adam(hidden_dim: int):
1011
+ import timeit
1012
+
1013
+ E = 100
1014
+ B = 4
1015
+ T = 2
1016
+ logging.info("in test_eve_cain")
1017
+ # device = torch.device('cuda')
1018
+ device = torch.device("cpu")
1019
+ dtype = torch.float32
1020
+
1021
+ # these input_magnitudes and output_magnitudes are to test that
1022
+ # Abel is working as we expect and is able to adjust scales of
1023
+ # different dims differently.
1024
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1025
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1026
+
1027
+ for iter in [1, 0]:
1028
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
1029
+
1030
+ m = torch.nn.Sequential(
1031
+ Linear(E, hidden_dim),
1032
+ torch.nn.PReLU(),
1033
+ Linear(hidden_dim, hidden_dim),
1034
+ torch.nn.PReLU(),
1035
+ Linear(hidden_dim, E),
1036
+ ).to(device)
1037
+
1038
+ train_pairs = [
1039
+ (
1040
+ 100.0
1041
+ * torch.randn(B, T, E, device=device, dtype=dtype)
1042
+ * input_magnitudes,
1043
+ torch.randn(B, T, E, device=device, dtype=dtype)
1044
+ * output_magnitudes,
1045
+ )
1046
+ for _ in range(20)
1047
+ ]
1048
+
1049
+ if iter == 0:
1050
+ optim = Eve(m.parameters(), lr=0.003)
1051
+ elif iter == 1:
1052
+ optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
1053
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
1054
+
1055
+ start = timeit.default_timer()
1056
+ avg_loss = 0.0
1057
+ for epoch in range(180):
1058
+ scheduler.step_epoch()
1059
+ # if epoch == 100 and iter in [2,3]:
1060
+ # optim.reset_speedup() # check it doesn't crash.
1061
+
1062
+ # if epoch == 130:
1063
+ # opts = diagnostics.TensorDiagnosticOptions(
1064
+ # 2 ** 22
1065
+ # ) # allow 4 megabytes per sub-module
1066
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
1067
+
1068
+ for n, (x, y) in enumerate(train_pairs):
1069
+ y_out = m(x)
1070
+ loss = ((y_out - y) ** 2).mean() * 100.0
1071
+ if epoch == 0 and n == 0:
1072
+ avg_loss = loss.item()
1073
+ else:
1074
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
1075
+ if n == 0 and epoch % 5 == 0:
1076
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
1077
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
1078
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
1079
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
1080
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
1081
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
1082
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
1083
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
1084
+ lr = scheduler.get_last_lr()[0]
1085
+ logging.info(
1086
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
1087
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
1088
+ loss.log().backward()
1089
+ optim.step()
1090
+ optim.zero_grad()
1091
+ scheduler.step_batch()
1092
+
1093
+ # diagnostic.print_diagnostics()
1094
+
1095
+ stop = timeit.default_timer()
1096
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
1097
+
1098
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
1099
+ # logging.info("state dict = ", scheduler.state_dict())
1100
+ # logging.info("optim state_dict = ", optim.state_dict())
1101
+ logging.info(f"input_magnitudes = {input_magnitudes}")
1102
+ logging.info(f"output_magnitudes = {output_magnitudes}")
1103
+
1104
+
1105
+ if __name__ == "__main__":
1106
+ torch.set_num_threads(1)
1107
+ torch.set_num_interop_threads(1)
1108
+ logging.getLogger().setLevel(logging.INFO)
1109
+ import subprocess
1110
+
1111
+ s = subprocess.check_output(
1112
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
1113
+ )
1114
+ logging.info(s)
1115
+ import sys
1116
+
1117
+ if len(sys.argv) > 1:
1118
+ hidden_dim = int(sys.argv[1])
1119
+ else:
1120
+ hidden_dim = 200
1121
+
1122
+ _test_scaled_adam(hidden_dim)
1123
+ _test_eden()
lib/voicecraft/steps/trainer.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os, random
3
+ import torch
4
+ import math, pickle
5
+ from tqdm import tqdm
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ import torch.nn as nn
9
+ import torch.distributed as dist
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ import numpy as np
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ import logging
14
+ from data import gigaspeech
15
+ from models import voicecraft
16
+
17
+ from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info
18
+ from .optim import ScaledAdam, Eden
19
+
20
+
21
+ class Trainer:
22
+
23
+ def __init__(self, args, world_size, rank):
24
+ self.start_time = time.time()
25
+ self.args = args
26
+ self.world_size, self.rank = world_size, rank
27
+ self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
28
+ if self.rank == 0:
29
+ self.writer = SummaryWriter(args.exp_dir)
30
+ self.seed_everything(seed=self.args.seed)
31
+ self.meters = self._setup_meters()
32
+
33
+ self.progress, self.total_progress = self._setup_progress()
34
+
35
+ self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models()
36
+
37
+ self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader()
38
+ if self.args.num_steps != None:
39
+ self.total_step = self.args.num_steps
40
+ self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
41
+ else:
42
+ self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs
43
+
44
+ self.optimizer, self.scheduler = self._setup_optimizer()
45
+ self.scaler = torch.cuda.amp.GradScaler()
46
+ self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False)
47
+
48
+ if self.rank == 0:
49
+ self.early_stop_accu_steps = 0
50
+ if self.args.dynamic_batching:
51
+ logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}")
52
+ else:
53
+ logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}")
54
+
55
+ def train(self):
56
+ flag = True
57
+ skip_flag = False
58
+ data_start_time = time.time()
59
+ while flag:
60
+ self.train_sampler.set_epoch(self.progress['epoch'])
61
+ for i, batch in enumerate(self.train_loader):
62
+ data_end_time = time.time()
63
+ self.model.train()
64
+ if self.progress['step'] > self.total_step:
65
+ flag = False
66
+ self.validate_and_save()
67
+ if self.rank == 0:
68
+ self.writer.close()
69
+ break
70
+ if isinstance(self.scheduler, Eden):
71
+ self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1)
72
+ if self.args.optimizer_name == "ScaledAdam":
73
+ cur_lr = self.scheduler.get_last_lr()[0]
74
+ else:
75
+ lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
76
+ assert lrs[0] == lrs[1]
77
+ cur_lr = lrs[0]
78
+
79
+ if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
80
+ self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
81
+ self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])
82
+
83
+ all_inds = list(range(len(batch['y'])))
84
+ sum_losses = 0
85
+ sum_top10acc = 0
86
+ sum_ntoken = 0
87
+ sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
88
+ for j in range(self.args.gradient_accumulation_steps):
89
+ cur_ind = all_inds[j::self.args.gradient_accumulation_steps]
90
+ cur_batch = {key: batch[key][cur_ind] for key in batch}
91
+ with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
92
+ out = self.model(cur_batch)
93
+
94
+ record_loss = out['loss'].detach().to(self.rank)
95
+ top10acc = out['top10acc'].to(self.rank)
96
+ effective_ntoken = out['effective_ntoken'].to(self.rank)
97
+ is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank)
98
+
99
+ dist.all_reduce(record_loss, op=dist.ReduceOp.SUM)
100
+ dist.all_reduce(top10acc, op=dist.ReduceOp.SUM)
101
+ dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM)
102
+ dist.all_reduce(is_nan, op=dist.ReduceOp.SUM)
103
+
104
+ # check if loss is nan
105
+ if is_nan.item() > 0:
106
+ logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch")
107
+ skip_flag = True
108
+ continue
109
+
110
+ sum_losses += record_loss.item()
111
+ sum_top10acc += top10acc.item()
112
+ sum_ntoken += effective_ntoken.item()
113
+
114
+ if 'top10acc_by_codebook' in out:
115
+ for cb in range(self.args.n_codebooks):
116
+ top10acc_cbi = out['top10acc_by_codebook'][cb]
117
+ dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM)
118
+ sum_top10acc_cbi[cb] += top10acc_cbi.item()
119
+
120
+ if self.rank == 0:
121
+ average_loss = sum_losses / sum_ntoken
122
+ average_top10acc = sum_top10acc / sum_ntoken
123
+ self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size)
124
+ self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size)
125
+ self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size)
126
+ average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
127
+ for cb in range(self.args.n_codebooks):
128
+ self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size)
129
+
130
+ if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
131
+ self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
132
+ self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
133
+ self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
134
+ for cb in range(self.args.n_codebooks):
135
+ self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])
136
+
137
+ if self.args.optimizer_name == "ScaledAdam":
138
+ self.scaler.scale(out['loss']).backward()
139
+ else:
140
+ self.scaler.scale(out['loss']/out['effective_ntoken']).backward()
141
+
142
+ if skip_flag:
143
+ self.optimizer.zero_grad()
144
+ skip_flag = False
145
+ continue
146
+
147
+ if self.args.optimizer_name != "ScaledAdam":
148
+ self.scaler.unscale_(self.optimizer)
149
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
150
+ self.scaler.step(self.optimizer)
151
+ self.scaler.update()
152
+
153
+ self.optimizer.zero_grad()
154
+
155
+ if self.args.optimizer_name == "ScaledAdam":
156
+ self.scheduler.step_batch(self.progress['step'])
157
+ else:
158
+ self.scheduler.step()
159
+
160
+ if self.rank == 0:
161
+ self.meters['data_time'].update(data_end_time - data_start_time)
162
+ self.meters['train_time'].update(time.time() - data_end_time)
163
+ if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
164
+ self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
165
+ self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])
166
+
167
+
168
+ # logging
169
+ if self.progress['step'] % self.args.print_every_n_steps == 0:
170
+ log_out = {}
171
+ log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
172
+ log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
173
+ log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
174
+ log_out['lr'] = f"{cur_lr:.7f}"
175
+ log_out['ntokens'] = f"{sum_ntoken}"
176
+ for key in self.meters:
177
+ if self.meters[key].val != 0 or self.meters[key].avg != 0:
178
+ log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
179
+ logging.info(log_out)
180
+ if np.isnan(self.meters['train_loss'].avg):
181
+ logging.warning("training diverged...")
182
+ raise RuntimeError("training diverged...")
183
+
184
+ # validation and save models
185
+ if self.progress['step'] % self.args.val_every_n_steps == 0:
186
+ dist.barrier()
187
+ self.validate_and_save()
188
+
189
+ self.progress['step'] += 1
190
+ self.progress['cur_step'] += 1
191
+
192
+ data_start_time = time.time()
193
+ self.progress['epoch'] += 1
194
+ self.progress['cur_step'] = 0 # reset cur_step to be 0
195
+ dist.destroy_process_group()
196
+
197
+ def validate_and_save(self):
198
+ self.model.eval()
199
+
200
+ score = self.validate(self.valid_loader)
201
+
202
+ if self.rank == 0:
203
+ if self.args.early_stop_threshold > 0:
204
+ if self.progress['best_score'] - score < self.args.early_stop_threshold:
205
+ self.early_stop_accu_steps += self.args.val_every_n_steps
206
+ if self.early_stop_accu_steps >= self.args.early_stop_step-1:
207
+ logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
208
+ logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
209
+ dist.destroy_process_group()
210
+ raise RuntimeError("early stop")
211
+ else:
212
+ self.early_stop_accu_steps = 0
213
+
214
+ if (score < self.progress['best_score']):
215
+ self.progress['best_step'] = self.progress['step']
216
+ self.progress['best_score'] = score
217
+ save_path = os.path.join(self.args.exp_dir,"best_bundle.pth")
218
+ torch.save(
219
+ {
220
+ "model": self.model.module.state_dict(),
221
+ "optimizer": self.optimizer.state_dict(),
222
+ "scheduler": self.scheduler.state_dict(),
223
+ "config": self.args,
224
+ "phn2num": self.train_loader.dataset.phn2num
225
+ },save_path
226
+ )
227
+ logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")
228
+ self._save_progress()
229
+ save_path = os.path.join(self.args.exp_dir,"bundle.pth")
230
+ torch.save(
231
+ {
232
+ "model": self.model.module.state_dict(),
233
+ "optimizer": self.optimizer.state_dict(),
234
+ "scheduler": self.scheduler.state_dict(),
235
+ "config": self.args,
236
+ "phn2num": self.train_loader.dataset.phn2num
237
+ },save_path
238
+ )
239
+ logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")
240
+
241
+ dist.barrier()
242
+
243
+ def validate(self, valid_loader=None, hide_progress=True):
244
+ if valid_loader == None:
245
+ valid_loader = self.valid_loader
246
+ self.model.eval()
247
+
248
+ start_val_time = time.time()
249
+ sum_losses = 0
250
+ sum_top10acc = 0
251
+ sum_ntoken = 0
252
+ sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
253
+
254
+ with torch.no_grad():
255
+ for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
256
+ out = self.model(batch)
257
+ sum_losses += out['loss']
258
+ sum_top10acc += out['top10acc']
259
+ sum_ntoken += out['effective_ntoken']
260
+ if 'top10acc_by_codebook' in out:
261
+ for cb in range(self.args.n_codebooks):
262
+ sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]
263
+
264
+ dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
265
+ dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
266
+ dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
267
+
268
+ if 'top10acc_by_codebook' in out:
269
+ for cb in range(self.args.n_codebooks):
270
+ dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)
271
+
272
+ if self.rank == 0:
273
+ val_loss = sum_losses / sum_ntoken
274
+ val_top10acc = sum_top10acc / sum_ntoken
275
+ # logging
276
+ self.meters['val_loss'].update(val_loss)
277
+ logging.info(f"val loss: {val_loss:.5f}")
278
+ self.writer.add_scalar("val/loss", val_loss, self.progress['step'])
279
+
280
+ self.meters['val_top10acc'].update(val_top10acc)
281
+ logging.info(f"val top10acc: {val_top10acc:.5f}")
282
+ self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
283
+ for cb in range(self.args.n_codebooks):
284
+ average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
285
+ self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
286
+ self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])
287
+
288
+ logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
289
+ logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
290
+ return val_loss.item()
291
+ else:
292
+ return None
293
+
294
+ def _setup_meters(self):
295
+ meters = {}
296
+ meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
297
+ meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
298
+ meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
299
+ meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
300
+ for name in meter_names:
301
+ meters[name] = AverageMeter()
302
+ return meters
303
+ def _setup_progress(self):
304
+ progress = {}
305
+ progress['best_step'] = 1
306
+ progress['best_score'] = np.inf # this records loss value
307
+ progress['step'] = 1
308
+ progress['epoch'] = 1
309
+ progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler
310
+ total_progress = []
311
+ # if self.args.resume or self.args.validate:
312
+ if self.args.resume:
313
+ progress_pkl = "%s/progress.pkl" % self.args.exp_dir
314
+ with open(progress_pkl, "rb") as f:
315
+ total_progress = pickle.load(f)
316
+ progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
317
+ if self.rank == 0:
318
+ logging.info("\nResume training from:")
319
+ logging.info(" epoch = %s" % progress['epoch'])
320
+ logging.info(" cur_step = %s" % progress['cur_step'])
321
+ logging.info(" step = %s" % progress['step'])
322
+ logging.info(" best_step = %s" % progress['best_step'])
323
+ logging.info(" best_score = %s" % progress['best_score'])
324
+ return progress, total_progress
325
+
326
+ def _save_progress(self):
327
+ self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
328
+ with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f:
329
+ pickle.dump(self.total_progress, f)
330
+
331
+ def _setup_dataloader(self):
332
+ assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now"
333
+ train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation')
334
+
335
+ if self.args.dynamic_batching:
336
+ train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
337
+ valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
338
+ else:
339
+ train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
340
+ valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)
341
+
342
+ if self.progress['step'] > 1:
343
+ train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])
344
+
345
+ if self.args.dynamic_batching:
346
+ train_loader = torch.utils.data.DataLoader(train_dataset,
347
+ batch_sampler=train_sampler,
348
+ num_workers=self.args.num_workers//self.world_size,
349
+ collate_fn=train_dataset.collate, persistent_workers=True
350
+ )
351
+ valid_loader = torch.utils.data.DataLoader(val_dataset,
352
+ batch_sampler=valid_sampler,
353
+ num_workers=self.args.num_workers//self.world_size,
354
+ collate_fn=val_dataset.collate, persistent_workers=True
355
+ )
356
+ else:
357
+ train_loader = torch.utils.data.DataLoader(train_dataset,
358
+ batch_size=self.args.batch_size//self.world_size, sampler=train_sampler, num_workers=self.args.num_workers//self.world_size,
359
+ collate_fn=train_dataset.collate, persistent_workers=True
360
+ )
361
+ valid_loader = torch.utils.data.DataLoader(val_dataset,
362
+ batch_size=self.args.batch_size//self.world_size, sampler=valid_sampler,
363
+ num_workers=self.args.num_workers//self.world_size,
364
+ collate_fn=val_dataset.collate, persistent_workers=True
365
+ )
366
+ return len(train_dataset), train_sampler, train_loader, valid_loader
367
+
368
+
369
+
370
+ def _setup_models(self):
371
+ model = voicecraft.VoiceCraft(self.args)
372
+
373
+ if self.rank == 0:
374
+ logging.info(model)
375
+ logging.info("model parameters")
376
+ print_model_info(model)
377
+
378
+ if self.progress['step'] > 1:
379
+ bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
380
+ model.load_state_dict(bundle['model'])
381
+ optim_states = bundle['optimizer']
382
+ scheduler_states = bundle['scheduler']
383
+ if self.rank == 0:
384
+ logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
385
+ del bundle['model']
386
+ else:
387
+ optim_states = None
388
+ scheduler_states = None
389
+
390
+ if self.args.load_model_from != None and self.progress['step'] <= 1:
391
+ sd = torch.load(self.args.load_model_from, map_location="cpu")['model']
392
+ model.load_state_dict(sd)
393
+ del sd
394
+
395
+ if self.args.optimizer_name == "ScaledAdam":
396
+ trainables = [p for p in model.parameters() if p.requires_grad]
397
+ else:
398
+ no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
399
+ optimizer_grouped_parameters = [
400
+ {
401
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
402
+ "weight_decay": self.args.weight_decay,
403
+ },
404
+ {
405
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
406
+ "weight_decay": 0.0,
407
+ },
408
+ ]
409
+ if len(optimizer_grouped_parameters[1]['params']) == 0:
410
+ logging.info("there is no embedding weights, bias, and layernorm parameters in the model, which should be True, check model parameter names")
411
+ trainables = optimizer_grouped_parameters[0]
412
+ else:
413
+ trainables = optimizer_grouped_parameters
414
+ model.to(self.device)
415
+
416
+ return model, trainables, optim_states, scheduler_states
417
+
418
+
419
+ def _setup_optimizer(self):
420
+ if self.args.optimizer_name == "ScaledAdam":
421
+ parameters_names = []
422
+ parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad])
423
+ optimizer = ScaledAdam(
424
+ self.trainables,
425
+ lr=self.args.lr,
426
+ betas=(0.9, 0.95),
427
+ clipping_scale=2.0,
428
+ parameters_names=parameters_names,
429
+ show_dominant_parameters=False,
430
+ clipping_update_period=self.args.clipping_update_period,
431
+ )
432
+ scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)
433
+
434
+ else:
435
+ optimizer = AdamW(self.trainables, lr=self.args.lr)
436
+ warmup_steps = self.total_step * self.args.warmup_fraction
437
+ def lr_lambda(current_step: int):
438
+ if current_step < warmup_steps:
439
+ return float(current_step) / float(max(1, warmup_steps))
440
+ return max(
441
+ 0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
442
+ )
443
+
444
+ scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
445
+
446
+ # if resume
447
+ if self.progress['step'] > 1:
448
+ optimizer.load_state_dict(self.optim_states)
449
+ for state in optimizer.state.values():
450
+ for k, v in state.items():
451
+ if isinstance(v, torch.Tensor):
452
+ state[k] = v.cuda()
453
+ del self.optim_states
454
+
455
+ scheduler.load_state_dict(self.scheduler_states)
456
+
457
+ optimizer.zero_grad()
458
+ return optimizer, scheduler
459
+
460
+ def seed_everything(self, seed=1):
461
+ os.environ['PYTHONHASHSEED'] = str(seed)
462
+ random.seed(seed)
463
+ np.random.seed(seed)
464
+ torch.manual_seed(seed)
465
+ torch.cuda.manual_seed(seed)
466
+ torch.backends.cudnn.benchmark = False
467
+ torch.backends.cudnn.deterministic = True
lib/voicecraft/steps/trainer_utils.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import math
4
+ import torch.distributed as dist
5
+ from torch.utils.data.sampler import Sampler
6
+ import copy
7
+ import numpy as np
8
+ from typing import List
9
+ from scipy.stats import lognorm
10
+ import logging
11
+
12
+ class StatefulDistributedSampler(Sampler[int]):
13
+ def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffle = True, seed = 0, drop_last = False):
14
+ if num_replicas is None:
15
+ if not dist.is_available():
16
+ raise RuntimeError("Requires distributed package to be available")
17
+ num_replicas = dist.get_world_size()
18
+ if rank is None:
19
+ if not dist.is_available():
20
+ raise RuntimeError("Requires distributed package to be available")
21
+ rank = dist.get_rank()
22
+ if rank >= num_replicas or rank < 0:
23
+ raise ValueError(
24
+ "Invalid rank {}, rank should be in the interval"
25
+ " [0, {}]".format(rank, num_replicas - 1))
26
+ self.dataset = dataset
27
+ self.batch_size = batch_size
28
+ self.num_replicas = num_replicas
29
+ self.rank = rank
30
+ self.epoch = 0
31
+ self.cur_epoch = 0
32
+ self.drop_last = drop_last
33
+ # If the dataset length is evenly divisible by # of replicas, then there
34
+ # is no need to drop any data, since the dataset will be split equally.
35
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
36
+ # Split to nearest available length that is evenly divisible.
37
+ # This is to ensure each rank receives the same amount of data when
38
+ # using this Sampler.
39
+ self.num_samples = math.ceil(
40
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
41
+ )
42
+ else:
43
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
44
+ self.total_size = self.num_samples * self.num_replicas
45
+ self.shuffle = shuffle
46
+ self.seed = seed
47
+ self.continue_flag = False
48
+ def __len__(self):
49
+ return self.num_samples
50
+
51
+ def set_epoch(self, epoch):
52
+ r"""
53
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
54
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
55
+ sampler will yield the same ordering.
56
+
57
+ Args:
58
+ epoch (int): Epoch number.
59
+ """
60
+ self.epoch = epoch
61
+
62
+ if self.shuffle:
63
+ # deterministically shuffle based on epoch and seed
64
+ g = torch.Generator()
65
+ g.manual_seed(self.seed + self.epoch)
66
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
67
+ else:
68
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
69
+
70
+ if not self.drop_last:
71
+ # add extra samples to make it evenly divisible
72
+ padding_size = self.total_size - len(indices)
73
+ if padding_size <= len(indices):
74
+ indices += indices[:padding_size]
75
+ else:
76
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
77
+ else:
78
+ # remove tail of data to make it evenly divisible.
79
+ indices = indices[:self.total_size]
80
+ assert len(indices) == self.total_size
81
+
82
+ # subsample
83
+ indices = indices[self.rank:self.total_size:self.num_replicas]
84
+ assert len(indices) == self.num_samples
85
+ self.indices = indices
86
+
87
+ if self.continue_flag:
88
+ self.indices = self.indices[int(self.cur_step*self.batch_size):]
89
+ self.num_samples = len(self.indices)
90
+ self.continue_flag = False
91
+
92
+ def __iter__(self):
93
+ for idx in self.indices:
94
+ yield idx
95
+
96
+ def set_epoch_resume(self, epoch, cur_step):
97
+ self.epoch = epoch
98
+ self.cur_step = cur_step
99
+ self.continue_flag = True
100
+
101
+
102
+ class StatefulSampler(Sampler):
103
+ def __init__(self, data_source_length, batch_size, use_random=True, seed=1, epoch=0):
104
+ self.use_random = use_random
105
+ self.data_source_length = data_source_length
106
+ self.num_samples = self.data_source_length
107
+ self.batch_size = batch_size
108
+ self.continue_flag = False
109
+ self.seed = seed
110
+ self.epoch = epoch
111
+ self.cur_step = 0
112
+
113
+ def __len__(self):
114
+ return self.num_samples
115
+
116
+ def __iter__(self):
117
+
118
+ for idx in self.indices:
119
+ yield idx
120
+
121
+ def set_epoch(self, epoch):
122
+ self.epoch = epoch
123
+ if self.use_random:
124
+ # deterministically shuffle based on epoch and seed
125
+ g = torch.Generator()
126
+ g.manual_seed(self.seed + self.epoch)
127
+ self.indices = torch.randperm(self.data_source_length, generator=g).tolist() # type: ignore[arg-type]
128
+ else:
129
+ self.indices = list(range(self.data_source_length)) # type: ignore[arg-type]
130
+ if self.continue_flag == True:
131
+ self.continue_flag = False
132
+ self.indices = self.indices[int(self.cur_step*self.batch_size):]
133
+
134
+ self.num_samples = len(self.indices)
135
+
136
+ def set_epoch_resume(self, epoch, cur_step):
137
+ self.epoch = epoch
138
+ self.cur_step = cur_step
139
+ self.continue_flag = True
140
+
141
+
142
+ class AverageMeter:
143
+ """Computes and stores the average and current value"""
144
+ def __init__(self):
145
+ self.reset()
146
+
147
+ def reset(self):
148
+ self.val = 0
149
+ self.avg = 0
150
+ self.sum = 0
151
+ self.count = 0
152
+
153
+ def update(self, val, n=1):
154
+ self.val = val
155
+ self.sum += val * n
156
+ self.count += n
157
+ self.avg = self.sum / self.count
158
+
159
+ def print_model_info(model, print_model = False, print_params = True):
160
+ if print_model:
161
+ logging.info(model)
162
+ if print_params:
163
+ all_params = {}
164
+ for name, p in model.named_parameters():
165
+ name = name.split(".")[0]
166
+ if name in all_params:
167
+ all_params[name] += p.numel()
168
+ else:
169
+ all_params[name] = p.numel()
170
+ logging.info("num of parameters of each components:")
171
+ for name in all_params:
172
+ logging.info(f"{name}: {all_params[name]/1000000.:.2f}m")
173
+
174
+
175
+ class DistributedDynamicBatchSampler(Sampler):
176
+ """
177
+ modified from SpeechBrian, https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/sampler.py#L307
178
+ This BatchSampler batches examples together by grouping them by their length.
179
+
180
+ Every example in the batch have approximately the same length and
181
+ thus padding is minimized.
182
+ This enables faster training on datasets
183
+ where length of examples can vary significantly (e.g Librispeech).
184
+ Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length
185
+
186
+ Dynamic batching is performed by specifying a max_batch_length which is the
187
+ upper limit for the sum of the length of examples in a batch:
188
+ e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6
189
+ ex1 and ex2 will be placed, alone, in two distinct batches.
190
+
191
+ Length for each example can be obtained in two manners.
192
+ If the input dataset is a DynamicItemDataset it can be obtained by specifying a
193
+ length_func. Default assumes a "duration" entry is in the annotation.
194
+ Length for each example can also be passed to this class upon instantiation
195
+ by specifying a list containing the length for each example and passing it to
196
+ lengths_list.
197
+
198
+ Examples are grouped together by defining a set of possible discrete intervals
199
+ (buckets). Examples whose length fall into these intervals can be batched together.
200
+
201
+ The number of buckets can be specified by using the arg num_buckets.
202
+ There is usually an optimal range for the value of this argument.
203
+
204
+ If num_buckets == 1, all examples can be batched together. You have maximum randomization
205
+ but your training speed will be slower due to the fact that a large amount of the values will be padding
206
+ as long and short examples can be batched together.
207
+ As the number of buckets grows only examples with similar
208
+ length can be grouped together.
209
+ This trades-off speed with randomization.
210
+ TLDR: Low number -> better randomization, High number -> faster training.
211
+ NOTE THAT: if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size
212
+ will be small impacting training speed and possibly performance.
213
+
214
+ The buckets can also be specified by passing a list to the bucket_boundaries
215
+ argument instead of specifying a left_bucket_length and a bucket_length_multiplier.
216
+
217
+ Example
218
+ -------
219
+ >>> import torch
220
+ >>> import speechbrain as sb
221
+ >>> from speechbrain.dataio.sampler import DynamicBatchSampler
222
+ >>> from speechbrain.dataio.dataset import DynamicItemDataset
223
+ >>> from speechbrain.dataio.dataloader import SaveableDataLoader
224
+ >>> from speechbrain.dataio.batch import PaddedBatch
225
+ >>> import numpy as np
226
+ >>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)])
227
+ >>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths}
228
+ >>> dataset = DynamicItemDataset(dataset)
229
+ >>> dataset.set_output_keys(["wav"])
230
+ >>> length_func = lambda x : len(x) # trivial in this example
231
+ >>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending')
232
+ >>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch)
233
+ >>> for i, b in enumerate(dataloader):
234
+ ... data, length = b["wav"]
235
+ >>> assert data.shape[-1] == max(item_lengths)
236
+
237
+ Arguments
238
+ ---------
239
+ dataset : torch.utils.data.Dataset
240
+ Pytorch Dataset from which elements will be sampled.
241
+ max_batch_length : int
242
+ Upper limit for the sum of the length of examples in a batch.
243
+ Should be chosen based on your GPU memory.
244
+ num_buckets : int
245
+ Number of discrete buckets used to group examples together.
246
+ If num_buckets == 1, all examples can be batched together. As the number of buckets grows only examples with similar
247
+ length can be grouped together. This trades-off speed with randomization.
248
+ Low number -> better randomization, High number -> faster training.
249
+ However if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size
250
+ will be small impacting training speed and possibly performance.
251
+ NOTE: you have either to specify manually the bucket_boundaries or the number of buckets.
252
+ length_func : callable
253
+ Function used to get length of each example from the dataset.
254
+ This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object.
255
+ Can be anything: e.g. lambda x: x["duration"]*16000 returns number of samples
256
+ if duration key in the annotation is in seconds and the file has 16kHz sampling freq.
257
+ shuffle : bool
258
+ Whether or not shuffle examples between each epoch.
259
+ batch_ordering : string
260
+ If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length.
261
+ max_batch_ex: int
262
+ If set, it limits the maximum number of examples that can be in a batch superseeding max_batch_length
263
+ in instances where the amount of examples will exceeed the value specified here.
264
+ E.g. you have a lot of short examples and the batch size for those will be too high, you can use this argument
265
+ to limit the batch size for these short examples.
266
+ bucket_boundaries : list
267
+ Overrides bucket_length_multiplier and left_bucket_length by specifying manually
268
+ the buckets right boundaries.
269
+ lengths_list: list
270
+ Overrides length_func by passing a list containing the length of each example
271
+ in the dataset. This argument must be set when the dataset is a plain
272
+ Pytorch Dataset object and not a DynamicItemDataset object as length_func
273
+ cannot be used on Pytorch Datasets.
274
+ epoch : int
275
+ The epoch to start at.
276
+ drop_last : bool
277
+ If ``True``, the sampler will drop the last examples which
278
+ have not been grouped.
279
+ verbose: bool
280
+ If ``True``, log also the stats for each batch at the first epoch.
281
+ """
282
+
283
+ def __init__(
284
+ self,
285
+ dataset,
286
+ args,
287
+ num_replicas = None,
288
+ rank = None,
289
+ shuffle = True,
290
+ seed = 0,
291
+ drop_last = False,
292
+ length_func=lambda x: x["duration"],
293
+ batch_ordering: str = "random",
294
+ max_batch_ex: int = None,
295
+ bucket_boundaries: List[int] = [],
296
+ lengths_list: List[int] = None,
297
+ epoch: int = 0,
298
+ verbose: bool = False,
299
+ ):
300
+ self.args = args
301
+ if num_replicas is None:
302
+ if not dist.is_available():
303
+ raise RuntimeError("Requires distributed package to be available")
304
+ num_replicas = dist.get_world_size()
305
+ if rank is None:
306
+ if not dist.is_available():
307
+ raise RuntimeError("Requires distributed package to be available")
308
+ rank = dist.get_rank()
309
+ if rank >= num_replicas or rank < 0:
310
+ raise ValueError(
311
+ "Invalid rank {}, rank should be in the interval"
312
+ " [0, {}]".format(rank, num_replicas - 1))
313
+ self.num_replicas = num_replicas
314
+ self.rank = rank
315
+ max_batch_length = self.args.max_num_tokens if dataset.split == "train" else self.args.val_max_num_tokens
316
+ logging.info(f"max_num_tokens per GPU for {dataset.split} split: {max_batch_length}")
317
+ num_buckets = self.args.num_buckets
318
+ #############
319
+
320
+
321
+
322
+
323
+ self._dataset = dataset
324
+ self._ex_lengths = {}
325
+ # ex_ids = self._dataset.data_ids
326
+ self.verbose = verbose
327
+
328
+ # We do not put a default on num_buckets to encourage users to play with this parameter
329
+ if num_buckets is None and len(bucket_boundaries) == 0:
330
+ raise RuntimeError(
331
+ "Please specify either num_buckets or bucket boundaries."
332
+ "Check the docs, and/or the tutorial !"
333
+ )
334
+ assert lengths_list != None
335
+ max_len = int(self.args.audio_max_length * self.args.encodec_sr)
336
+ lengths_list = [min(l, max_len) for l in lengths_list] # replace all utt whose length is longer than max_len to max_len, will also do this in __getitem__ in dataset
337
+ for indx in range(len(lengths_list)):
338
+ self._ex_lengths[str(indx)] = lengths_list[indx]
339
+ # if lengths_list is not None:
340
+ # # take length of examples from this argument and bypass length_key
341
+ # for indx in range(len(lengths_list)):
342
+ # self._ex_lengths[str(indx)] = lengths_list[indx]
343
+ # else:
344
+ # # use length func
345
+ # if not isinstance(dataset, DynamicItemDataset):
346
+ # raise NotImplementedError(
347
+ # "Dataset should be a Speechbrain DynamicItemDataset when using length function"
348
+ # )
349
+ # for indx in range(len(self._dataset)):
350
+ # self._ex_lengths[str(indx)] = length_func(
351
+ # self._dataset.data[ex_ids[indx]]
352
+ # )
353
+
354
+ if len(bucket_boundaries) > 0:
355
+ if not all([x >= 0 for x in bucket_boundaries]):
356
+ raise ValueError(
357
+ "All elements in bucket boundaries should be non-negative (>= 0)."
358
+ )
359
+ if not len(set(bucket_boundaries)) == len(bucket_boundaries):
360
+ raise ValueError(
361
+ "Bucket_boundaries should not contain duplicates."
362
+ )
363
+ np.testing.assert_array_equal(
364
+ np.array(bucket_boundaries),
365
+ np.array(sorted(bucket_boundaries)),
366
+ err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!",
367
+ )
368
+ self._bucket_boundaries = np.array(sorted(bucket_boundaries))
369
+ else:
370
+ # use num_buckets
371
+ self._bucket_boundaries = np.array(
372
+ self._get_boundaries_through_warping(
373
+ # max_batch_length=max_batch_length,
374
+ max_batch_length=max(lengths_list),
375
+ num_quantiles=num_buckets,
376
+ )
377
+ )
378
+
379
+ self._max_batch_length = max_batch_length
380
+ self._shuffle_ex = shuffle
381
+ self._batch_ordering = batch_ordering
382
+ self._seed = seed
383
+ self._drop_last = drop_last
384
+ if max_batch_ex is None:
385
+ max_batch_ex = np.inf
386
+ self._max_batch_ex = max_batch_ex
387
+ # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
388
+ self._bucket_lens = [
389
+ max(1, int(max_batch_length / self._bucket_boundaries[i]))
390
+ for i in range(len(self._bucket_boundaries))
391
+ ] + [1]
392
+ self._epoch = epoch
393
+ self._cur_step = 0
394
+ self.continue_flag = False
395
+ self._generate_batches()
396
+ self.num_samples = int(math.floor(len(self._batches) / self.num_replicas))
397
+ self.total_size = int(self.num_samples * self.num_replicas)
398
+ self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas]
399
+ assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}"
400
+ logging.info(f"len(self._batches): {len(self._batches)}")
401
+ logging.info(f"self.num_replicas: {self.num_replicas}")
402
+ logging.info(f"num of batches on each replica: {self.num_samples}")
403
+
404
+ def get_durations(self, batch):
405
+ """Gets durations of the elements in the batch."""
406
+ return [self._ex_lengths[str(idx)] for idx in batch]
407
+
408
+ def _get_boundaries_through_warping(
409
+ self, max_batch_length: int, num_quantiles: int,
410
+ ) -> List[int]:
411
+
412
+ # NOTE: the following lines do not cover that there is only one example in the dataset
413
+ # warp frames (duration) distribution of train data
414
+ logging.info("Batch quantisation in latent space")
415
+ # linspace set-up
416
+ num_boundaries = num_quantiles + 1
417
+ # create latent linearly equal spaced buckets
418
+ latent_boundaries = np.linspace(
419
+ 1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles,
420
+ )
421
+ # get quantiles using lognormal distribution
422
+ quantiles = lognorm.ppf(latent_boundaries, 1)
423
+ # scale up to to max_batch_length
424
+ bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
425
+ # compute resulting bucket length multipliers
426
+ length_multipliers = [
427
+ bucket_boundaries[x + 1] / bucket_boundaries[x]
428
+ for x in range(num_quantiles - 1)
429
+ ]
430
+ # logging
431
+ logging.debug(
432
+ "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
433
+ list(map("{:.2f}".format, bucket_boundaries)),
434
+ list(map("{:.2f}".format, length_multipliers)),
435
+ )
436
+ )
437
+ return list(sorted(bucket_boundaries))
438
+
439
+ def _permute_batches(self):
440
+
441
+ if self._batch_ordering == "random":
442
+ # deterministically shuffle based on epoch and seed
443
+ g = torch.Generator()
444
+ g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset
445
+ sampler = torch.randperm(
446
+ len(self._batches), generator=g
447
+ ).tolist() # type: ignore
448
+ tmp = []
449
+ for idx in sampler:
450
+ tmp.append(self._batches[idx])
451
+ self._batches = tmp
452
+
453
+ elif self._batch_ordering == "ascending":
454
+ self._batches = sorted(
455
+ self._batches,
456
+ key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
457
+ )
458
+ elif self._batch_ordering == "descending":
459
+ self._batches = sorted(
460
+ self._batches,
461
+ key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
462
+ reverse=True,
463
+ )
464
+ else:
465
+ raise NotImplementedError
466
+
467
+ def _generate_batches(self):
468
+ logging.info("DynamicBatchSampler: Generating dynamic batches")
469
+ if self._shuffle_ex:
470
+ # deterministically shuffle based on epoch and seed
471
+ g = torch.Generator()
472
+ g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset
473
+ sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore
474
+ # pyp note: this is actually randomly permoted indices
475
+ else:
476
+ # take examples as they are: e.g. they have been sorted
477
+ sampler = range(len(self._dataset)) # type: ignore
478
+
479
+ self._batches = []
480
+ bucket_batches = [[] for i in self._bucket_lens]
481
+
482
+ stats_tracker = [
483
+ {"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0}
484
+ for i in self._bucket_lens
485
+ ]
486
+
487
+ for idx in sampler:
488
+ # length of pre-sampled audio
489
+ item_len = self._ex_lengths[str(idx)]
490
+ # bucket to fill up most padding
491
+ bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
492
+ # fill audio's duration into that bucket
493
+ bucket_batches[bucket_id].append(idx)
494
+
495
+ stats_tracker[bucket_id]["min"] = min(
496
+ stats_tracker[bucket_id]["min"], item_len
497
+ )
498
+ stats_tracker[bucket_id]["max"] = max(
499
+ stats_tracker[bucket_id]["max"], item_len
500
+ )
501
+ stats_tracker[bucket_id]["tot"] += item_len
502
+ stats_tracker[bucket_id]["n_ex"] += 1
503
+ # track #samples - why not duration/#frames; rounded up?
504
+ # keep track of durations, if necessary
505
+
506
+ if (
507
+ len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
508
+ or len(bucket_batches[bucket_id]) >= self._max_batch_ex
509
+ ):
510
+ self._batches.append(bucket_batches[bucket_id])
511
+ bucket_batches[bucket_id] = []
512
+ # keep track of durations
513
+
514
+ # Dump remaining batches
515
+ if not self._drop_last:
516
+ for batch in bucket_batches:
517
+ if batch:
518
+ self._batches.append(batch)
519
+
520
+ self._permute_batches() # possibly reorder batches
521
+
522
+ if self._epoch == 0: # only log at first epoch
523
+ # frames per batch & their padding remaining
524
+ boundaries = [0] + self._bucket_boundaries.tolist()
525
+
526
+ for bucket_indx in range(len(self._bucket_boundaries)):
527
+ try:
528
+ num_batches = stats_tracker[bucket_indx]["tot"] // (
529
+ self._max_batch_length
530
+ )
531
+ pad_factor = (
532
+ stats_tracker[bucket_indx]["max"]
533
+ - stats_tracker[bucket_indx]["min"]
534
+ ) / (
535
+ stats_tracker[bucket_indx]["tot"]
536
+ / stats_tracker[bucket_indx]["n_ex"]
537
+ )
538
+ except ZeroDivisionError:
539
+ num_batches = 0
540
+ pad_factor = 0
541
+
542
+ logging.debug(
543
+ (
544
+ "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
545
+ + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
546
+ ).format(
547
+ bucket_indx,
548
+ boundaries[bucket_indx],
549
+ boundaries[bucket_indx + 1],
550
+ self._bucket_lens[bucket_indx],
551
+ stats_tracker[bucket_indx]["n_ex"],
552
+ num_batches,
553
+ pad_factor * 100,
554
+ )
555
+ )
556
+
557
+ if self.verbose:
558
+ batch_stats = {
559
+ "tot_frames": [],
560
+ "tot_pad_frames": [],
561
+ "pad_%": [],
562
+ }
563
+ for batch in self._batches:
564
+ tot_frames = sum(
565
+ [self._ex_lengths[str(idx)] for idx in batch]
566
+ )
567
+ batch_stats["tot_frames"].append(tot_frames)
568
+ max_frames = max(
569
+ [self._ex_lengths[str(idx)] for idx in batch]
570
+ )
571
+ tot_pad = sum(
572
+ [
573
+ max_frames - self._ex_lengths[str(idx)]
574
+ for idx in batch
575
+ ]
576
+ )
577
+ batch_stats["tot_pad_frames"].append(tot_pad)
578
+ batch_stats["pad_%"].append(tot_pad / tot_frames * 100)
579
+
580
+ padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
581
+ padding_details = "DynamicBatchSampler: " + padding_details
582
+ for i in range(len(self._batches)):
583
+ logging.debug(
584
+ padding_details.format(
585
+ i,
586
+ batch_stats["tot_frames"][i],
587
+ len(self._batches[i]),
588
+ batch_stats["tot_pad_frames"][i],
589
+ batch_stats["pad_%"][i],
590
+ )
591
+ )
592
+
593
+ def __iter__(self):
594
+
595
+ for batch in self._replica_batches:
596
+ yield batch
597
+
598
+
599
+ # if self._shuffle_ex: # re-generate examples if ex_ordering == "random"
600
+ # self._generate_batches()
601
+ # if self._batch_ordering == "random":
602
+ # # we randomly permute the batches only --> faster
603
+ # self._permute_batches()
604
+
605
+ def set_epoch(self, epoch):
606
+ """
607
+ You can also just access self.epoch, but we maintain this interface
608
+ to mirror torch.utils.data.distributed.DistributedSampler
609
+ """
610
+ self._epoch = epoch
611
+ self._generate_batches()
612
+ self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas]
613
+ self.num_samples = int(math.floor(len(self._batches) / self.num_replicas))
614
+ assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}"
615
+
616
+ if self.continue_flag:
617
+ self.continue_flag = False
618
+ self._replica_batches = self._replica_batches[self._cur_step:]
619
+ self.num_samples = len(self._replica_batches)
620
+
621
+
622
+ def __len__(self):
623
+ return self.num_samples
624
+
625
+ def set_epoch_resume(self, epoch, cur_step):
626
+ self.continue_flag = True
627
+ self._epoch = epoch
628
+ self._cur_step = cur_step
lib/voicecraft/z_scripts/e830M.sh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ source ~/miniconda3/etc/profile.d/conda.sh
3
+ conda activate voicecraft
4
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
5
+ export WORLD_SIZE=4
6
+
7
+ dataset=gigaspeech
8
+ mkdir -p ./logs/${dataset}
9
+
10
+ exp_root="path/to/store/exp_results"
11
+ exp_name=e830M
12
+ dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step
13
+ encodec_codes_folder_name="encodec_16khz_4codebooks"
14
+
15
+ # export CUDA_LAUNCH_BLOCKING=1 # for debugging
16
+
17
+ torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_per_node=${WORLD_SIZE} \
18
+ ../main.py \
19
+ --reduced_eog 1 \
20
+ --drop_long 1 \
21
+ --eos 2051 \
22
+ --n_special 4 \
23
+ --pad_x 0 \
24
+ --codebook_weight "[5,1,0.5,0.1]" \
25
+ --encodec_sr 50 \
26
+ --num_steps 50000 \
27
+ --lr 0.05 \
28
+ --warmup_fraction 0.01 \
29
+ --optimizer_name "ScaledAdam" \
30
+ --pseudo_epoch_size 3000 \
31
+ --reduce_lr_start_step 3000 \
32
+ --reduce_lr_start_epoch 4 \
33
+ --clipping_update_period 1000 \
34
+ --d_model 2048 \
35
+ --audio_embedding_dim 2048 \
36
+ --nhead 16 \
37
+ --num_decoder_layers 16 \
38
+ --max_num_tokens 100000 \
39
+ --gradient_accumulation_steps 26 \
40
+ --val_max_num_tokens 6000 \
41
+ --num_buckets 6 \
42
+ --audio_max_length 20 \
43
+ --audio_min_length 2 \
44
+ --text_max_length 400 \
45
+ --text_min_length 10 \
46
+ --mask_len_min 1 \
47
+ --mask_len_max 600 \
48
+ --tb_write_every_n_steps 10 \
49
+ --print_every_n_steps 400 \
50
+ --val_every_n_steps 1600 \
51
+ --text_vocab_size 100 \
52
+ --text_pad_token 100 \
53
+ --phn_folder_name "phonemes" \
54
+ --manifest_name "manifest" \
55
+ --encodec_folder_name ${encodec_codes_folder_name} \
56
+ --audio_vocab_size 2048 \
57
+ --empty_token 2048 \
58
+ --eog 2049 \
59
+ --audio_pad_token 2050 \
60
+ --n_codebooks 4 \
61
+ --max_n_spans 3 \
62
+ --shuffle_mask_embedding 0 \
63
+ --mask_sample_dist poisson1 \
64
+ --max_mask_portion 0.9 \
65
+ --min_gap 5 \
66
+ --num_workers 8 \
67
+ --dynamic_batching 1 \
68
+ --dataset $dataset \
69
+ --exp_dir "${exp_root}/${dataset}/${exp_name}" \
70
+ --dataset_dir ${dataset_dir}
71
+ # >> ./logs/${dataset}/${exp_name}.log 2>&1
requirements.txt CHANGED
@@ -5,6 +5,7 @@ stftpitchshift==1.5.1
5
  torchcrepe
6
  setuptools
7
  wheel
 
8
  httpx==0.23.0
9
  faiss-gpu
10
  fairseq
@@ -20,4 +21,14 @@ mega.py
20
  gdown==5.1.0
21
  onnxruntime
22
  pyngrok==4.1.12
23
- torch
 
 
 
 
 
 
 
 
 
 
 
5
  torchcrepe
6
  setuptools
7
  wheel
8
+ whisper
9
  httpx==0.23.0
10
  faiss-gpu
11
  fairseq
 
21
  gdown==5.1.0
22
  onnxruntime
23
  pyngrok==4.1.12
24
+ xformers==0.0.22
25
+ torchaudio==2.0.2
26
+ torch==2.0.1 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
27
+ tensorboard==2.16.2
28
+ phonemizer==3.2.1
29
+ datasets==2.16.0
30
+ torchmetrics==0.11.1
31
+ # install MFA for getting forced-alignment, this could take a few minutes
32
+ montreal-forced-aligner=2.2.17
33
+ openfst=1.8.2
34
+ kaldi=5.5.1068
run.sh CHANGED
@@ -1,6 +1,6 @@
1
  # Install Debian packages
2
  sudo apt-get update
3
- sudo apt-get install -qq -y build-essential ffmpeg aria2
4
 
5
  # Upgrade pip and setuptools
6
  pip install --upgrade pip
@@ -9,6 +9,9 @@ pip install --upgrade setuptools
9
  # Install wheel package (built-package format for Python)
10
  pip install wheel
11
 
 
 
 
12
  # Install Python packages using pip
13
  pip install -r requirements.txt
14
 
 
1
  # Install Debian packages
2
  sudo apt-get update
3
+ sudo apt-get install -qq -y build-essential ffmpeg aria2 espeak-ng
4
 
5
  # Upgrade pip and setuptools
6
  pip install --upgrade pip
 
9
  # Install wheel package (built-package format for Python)
10
  pip install wheel
11
 
12
+ # Install audiocraft
13
+ pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
14
+
15
  # Install Python packages using pip
16
  pip install -r requirements.txt
17