Update app.py
app.py CHANGED

The removed lines referenced argparse-style values (args, pad_mode) that enhancement.py reads from the command line but app.py never defined; the commit replaces them with an explicit args dict at module scope and pulls every sampler setting from it:
@@ -4,38 +4,50 @@ from sgmse.model import ScoreModel
 import gradio as gr
 from sgmse.util.other import pad_spec
 
+# Define parameters based on the argparse configuration in enhancement.py
+args = {
+    "test_dir": "./test_data",  # example directory, adjust as needed
+    "enhanced_dir": "./enhanced_data",  # example directory, adjust as needed
+    "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
+    "corrector": "ald",
+    "corrector_steps": 1,
+    "snr": 0.5,
+    "N": 30,
+    "device": "cuda" if torch.cuda.is_available() else "cpu"
+}
+
 # Load the pre-trained model
-model = ScoreModel.load_from_checkpoint("
+model = ScoreModel.load_from_checkpoint(args["ckpt"])
 
 def enhance_speech(audio_file):
     # Load and process the audio file
     y, sr = torchaudio.load(audio_file)
-
-    T_orig = y.size(1)
+    T_orig = y.size(1)
 
     # Normalize
     norm_factor = y.abs().max()
     y = y / norm_factor
-
+
     # Prepare DNN input
-    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args
-    Y = pad_spec(Y, mode=pad_mode
-
+    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
+    Y = pad_spec(Y, mode="constant")  # Ensure pad_mode is defined; replace with actual pad_mode if needed
+
     # Reverse sampling
     sampler = model.get_pc_sampler(
-        'reverse_diffusion', args
-        corrector_steps=args
+        'reverse_diffusion', args["corrector"], Y.to(args["device"]),
+        N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
+    )
     sample, _ = sampler()
-
+
     # Backward transform in time domain
     x_hat = model.to_audio(sample.squeeze(), T_orig)
 
     # Renormalize
     x_hat = x_hat * norm_factor
-
+
     # Save the enhanced audio
     output_file = 'enhanced_output.wav'
-    torchaudio.save(output_file, x_hat.cpu()
+    torchaudio.save(output_file, x_hat.cpu(), sr)
 
     return output_file
 
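One caveat on the new code: y and Y are both moved to args["device"], but the model never is, so a checkpoint that loads onto the CPU will fail at the first forward pass once it receives CUDA tensors. A minimal guard, offered as a suggestion rather than part of the commit:

# Move the model to the configured device and disable training-mode
# behaviour; the hunk sends y and Y to args["device"] but not the model.
model = ScoreModel.load_from_checkpoint(args["ckpt"])
model.to(args["device"])
model.eval()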
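Another detail worth checking: for mono input, sample.squeeze() can leave x_hat one-dimensional, while torchaudio.save expects a (channels, time) tensor. Restoring the channel axis before saving is a cheap safeguard (again a suggestion, not in the commit):

# torchaudio.save expects shape (channels, time); squeeze() may have
# dropped the channel axis, so put it back for 1-D tensors.
if x_hat.dim() == 1:
    x_hat = x_hat.unsqueeze(0)
torchaudio.save(output_file, x_hat.cpu(), sr)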
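The hunk ends at return output_file, so the Gradio interface itself sits outside this diff. A minimal sketch of how enhance_speech could be wired up, assuming filepath-based audio components (the title and layout are illustrative, not taken from the repo):

import gradio as gr

# Hypothetical wiring; the real interface definition lives outside this hunk.
demo = gr.Interface(
    fn=enhance_speech,                  # function defined in the diff above
    inputs=gr.Audio(type="filepath"),   # pass the upload path to torchaudio.load
    outputs=gr.Audio(type="filepath"),  # return the path of the enhanced wav
    title="SGMSE Speech Enhancement",
)

if __name__ == "__main__":
    demo.launch()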
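Outside Gradio, the function can also be exercised directly; noisy.wav below is a placeholder path, not a file from the repo:

# Hypothetical smoke test with a placeholder input file.
enhanced_path = enhance_speech("noisy.wav")
print(f"Enhanced audio written to {enhanced_path}")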