Spaces: kgout (Running on Zero)

kgout committed · Commit 9bc7f89 · verified · 1 parent: 80d3cc4

Update audiosr/pipeline.py

Files changed (1):
  1. audiosr/pipeline.py +176 -175
audiosr/pipeline.py CHANGED
@@ -1,175 +1,176 @@
-import os
-import re
-
-import yaml
-import torch
-import torchaudio
-import numpy as np
-
-import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
-from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
-from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
-from audiosr.utils import (
-    default_audioldm_config,
-    download_checkpoint,
-    read_audio_file,
-    lowpass_filtering_prepare_inference,
-    wav_feature_extraction,
-)
-import os
-
-
-def seed_everything(seed):
-    import random, os
-    import numpy as np
-    import torch
-
-    random.seed(seed)
-    os.environ["PYTHONHASHSEED"] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = True
-
-
-def text2phoneme(data):
-    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])
-
-
-def text_to_filename(text):
-    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
-
-
-def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
-    norm_mean = -4.2677393
-    norm_std = 4.5689974
-
-    if sampling_rate != 16000:
-        waveform_16k = torchaudio.functional.resample(
-            waveform, orig_freq=sampling_rate, new_freq=16000
-        )
-    else:
-        waveform_16k = waveform
-
-    waveform_16k = waveform_16k - waveform_16k.mean()
-    fbank = torchaudio.compliance.kaldi.fbank(
-        waveform_16k,
-        htk_compat=True,
-        sample_frequency=16000,
-        use_energy=False,
-        window_type="hanning",
-        num_mel_bins=128,
-        dither=0.0,
-        frame_shift=10,
-    )
-
-    TARGET_LEN = log_mel_spec.size(0)
-
-    # cut and pad
-    n_frames = fbank.shape[0]
-    p = TARGET_LEN - n_frames
-    if p > 0:
-        m = torch.nn.ZeroPad2d((0, 0, 0, p))
-        fbank = m(fbank)
-    elif p < 0:
-        fbank = fbank[:TARGET_LEN, :]
-
-    fbank = (fbank - norm_mean) / (norm_std * 2)
-
-    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
-
-
-def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
-    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
-
-    batch = {
-        "waveform": torch.FloatTensor(waveform),
-        "stft": torch.FloatTensor(stft),
-        "log_mel_spec": torch.FloatTensor(log_mel_spec),
-        "sampling_rate": 48000,
-    }
-
-    # print(batch["waveform"].size(), batch["stft"].size(), batch["log_mel_spec"].size())
-
-    batch.update(lowpass_filtering_prepare_inference(batch))
-
-    assert "waveform_lowpass" in batch.keys()
-    lowpass_mel, lowpass_stft = wav_feature_extraction(
-        batch["waveform_lowpass"], target_frame
-    )
-    batch["lowpass_mel"] = lowpass_mel
-
-    for k in batch.keys():
-        if type(batch[k]) == torch.Tensor:
-            batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0)
-
-    return batch, duration
-
-
-def round_up_duration(duration):
-    return int(round(duration / 2.5) + 1) * 2.5
-
-
-def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
-    if device is None or device == "auto":
-        if torch.cuda.is_available():
-            device = torch.device("cuda:0")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-
-    print("Loading AudioSR: %s" % model_name)
-    print("Loading model on %s" % device)
-
-    ckpt_path = download_checkpoint(model_name)
-
-    if config is not None:
-        assert type(config) is str
-        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
-    else:
-        config = default_audioldm_config(model_name)
-
-    # # Use text as condition instead of using waveform during training
-    config["model"]["params"]["device"] = device
-    # config["model"]["params"]["cond_stage_key"] = "text"
-
-    # No normalization here
-    latent_diffusion = LatentDiffusion(**config["model"]["params"])
-
-    resume_from_checkpoint = ckpt_path
-
-    checkpoint = torch.load(resume_from_checkpoint, map_location='cpu')
-
-    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
-
-    latent_diffusion.eval()
-    latent_diffusion = latent_diffusion.to(device)
-
-    return latent_diffusion
-
-
-def super_resolution(
-    latent_diffusion,
-    input_file,
-    seed=42,
-    ddim_steps=200,
-    guidance_scale=3.5,
-    latent_t_per_second=12.8,
-    config=None,
-):
-    seed_everything(int(seed))
-    waveform = None
-
-    batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)
-
-    with torch.no_grad():
-        waveform = latent_diffusion.generate_batch(
-            batch,
-            unconditional_guidance_scale=guidance_scale,
-            ddim_steps=ddim_steps,
-            duration=duration,
-        )
-
-    return waveform
+import os
+import re
+
+import yaml
+import torch
+import torchaudio
+import numpy as np
+import spaces
+
+import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
+from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
+from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
+from audiosr.utils import (
+    default_audioldm_config,
+    download_checkpoint,
+    read_audio_file,
+    lowpass_filtering_prepare_inference,
+    wav_feature_extraction,
+)
+import os
+
+
+def seed_everything(seed):
+    import random, os
+    import numpy as np
+    import torch
+
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+
+def text2phoneme(data):
+    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])
+
+
+def text_to_filename(text):
+    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+
+def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
+
+    if sampling_rate != 16000:
+        waveform_16k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=16000
+        )
+    else:
+        waveform_16k = waveform
+
+    waveform_16k = waveform_16k - waveform_16k.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform_16k,
+        htk_compat=True,
+        sample_frequency=16000,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=128,
+        dither=0.0,
+        frame_shift=10,
+    )
+
+    TARGET_LEN = log_mel_spec.size(0)
+
+    # cut and pad
+    n_frames = fbank.shape[0]
+    p = TARGET_LEN - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[:TARGET_LEN, :]
+
+    fbank = (fbank - norm_mean) / (norm_std * 2)
+
+    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
+    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
+
+    batch = {
+        "waveform": torch.FloatTensor(waveform),
+        "stft": torch.FloatTensor(stft),
+        "log_mel_spec": torch.FloatTensor(log_mel_spec),
+        "sampling_rate": 48000,
+    }
+
+    # print(batch["waveform"].size(), batch["stft"].size(), batch["log_mel_spec"].size())
+
+    batch.update(lowpass_filtering_prepare_inference(batch))
+
+    assert "waveform_lowpass" in batch.keys()
+    lowpass_mel, lowpass_stft = wav_feature_extraction(
+        batch["waveform_lowpass"], target_frame
+    )
+    batch["lowpass_mel"] = lowpass_mel
+
+    for k in batch.keys():
+        if type(batch[k]) == torch.Tensor:
+            batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0)
+
+    return batch, duration
+
+
+def round_up_duration(duration):
+    return int(round(duration / 2.5) + 1) * 2.5
+
+@spaces.GPU
+def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
+    if device is None or device == "auto":
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+    print("Loading AudioSR: %s" % model_name)
+    print("Loading model on %s" % device)
+
+    ckpt_path = download_checkpoint(model_name)
+
+    if config is not None:
+        assert type(config) is str
+        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+    else:
+        config = default_audioldm_config(model_name)
+
+    # # Use text as condition instead of using waveform during training
+    config["model"]["params"]["device"] = device
+    # config["model"]["params"]["cond_stage_key"] = "text"
+
+    # No normalization here
+    latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+    resume_from_checkpoint = ckpt_path
+
+    checkpoint = torch.load(resume_from_checkpoint, map_location='cpu')
+
+    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
+
+    latent_diffusion.eval()
+    latent_diffusion = latent_diffusion.to(device)
+
+    return latent_diffusion
+
+
+def super_resolution(
+    latent_diffusion,
+    input_file,
+    seed=42,
+    ddim_steps=200,
+    guidance_scale=3.5,
+    latent_t_per_second=12.8,
+    config=None,
+):
+    seed_everything(int(seed))
+    waveform = None
+
+    batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)
+
+    with torch.no_grad():
+        waveform = latent_diffusion.generate_batch(
+            batch,
+            unconditional_guidance_scale=guidance_scale,
+            ddim_steps=ddim_steps,
+            duration=duration,
+        )
+
+    return waveform
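
The only functional change in this commit is the new "import spaces" line and the @spaces.GPU decorator on build_model: on a Hugging Face ZeroGPU Space, the decorator requests a GPU for the duration of the decorated call, and it is inert when the code runs outside a Space. A minimal usage sketch of the updated pipeline follows; the input path is a hypothetical placeholder, and the argument values are the defaults visible in the diff above.

# Minimal usage sketch (hedged: "input.wav" is a placeholder path, not part
# of the commit; model_name and the sampler settings are the defaults above).
from audiosr.pipeline import build_model, super_resolution

# @spaces.GPU attaches a GPU while build_model runs on a ZeroGPU Space;
# device="auto" then resolves to CUDA, MPS, or CPU inside the function.
model = build_model(model_name="basic", device="auto")

# Builds the lowpass-conditioned batch from the file and runs
# latent_diffusion.generate_batch under torch.no_grad(), returning the
# 48 kHz super-resolved waveform.
waveform = super_resolution(
    model,
    "input.wav",  # hypothetical low-bandwidth input file
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
)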