Shokoufehhh committed
Commit 8a4b264 · verified · 1 Parent(s): d16671b

Update app.py

Files changed (1)
  1. app.py +24 -12
app.py CHANGED
@@ -4,38 +4,50 @@ from sgmse.model import ScoreModel
 import gradio as gr
 from sgmse.util.other import pad_spec

+# Define parameters based on the argparse configuration in enhancement.py
+args = {
+    "test_dir": "./test_data",          # example directory, adjust as needed
+    "enhanced_dir": "./enhanced_data",  # example directory, adjust as needed
+    "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
+    "corrector": "ald",
+    "corrector_steps": 1,
+    "snr": 0.5,
+    "N": 30,
+    "device": "cuda" if torch.cuda.is_available() else "cpu"
+}
+
 # Load the pre-trained model
-model = ScoreModel.load_from_checkpoint("https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt")
+model = ScoreModel.load_from_checkpoint(args["ckpt"])

 def enhance_speech(audio_file):
     # Load and process the audio file
     y, sr = torchaudio.load(audio_file)
-
-    T_orig = y.size(1)
+    T_orig = y.size(1)

     # Normalize
     norm_factor = y.abs().max()
     y = y / norm_factor
-
+
     # Prepare DNN input
-    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
-    Y = pad_spec(Y, mode=pad_mode)
-
+    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
+    Y = pad_spec(Y, mode="constant")  # Ensure pad_mode is defined; replace with actual pad_mode if needed
+
     # Reverse sampling
     sampler = model.get_pc_sampler(
-        'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
-        corrector_steps=args.corrector_steps, snr=args.snr)
+        'reverse_diffusion', args["corrector"], Y.to(args["device"]),
+        N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
+    )
     sample, _ = sampler()
-
+
     # Backward transform in time domain
     x_hat = model.to_audio(sample.squeeze(), T_orig)

     # Renormalize
     x_hat = x_hat * norm_factor
-
+
     # Save the enhanced audio
     output_file = 'enhanced_output.wav'
-    torchaudio.save(output_file, x_hat.cpu().numpy(), sr)
+    torchaudio.save(output_file, x_hat.cpu(), sr)

     return output_file
 
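For context on the main change: the old code referenced args.device, args.corrector, and so on, but app.py never built the argparse namespace that enhancement.py constructs, so those lookups could not resolve. The commit substitutes a plain dict indexed with args["device"]. A types.SimpleNamespace would be an equally workable choice that keeps the original attribute-style access; a minimal sketch of that alternative (not part of the commit, values copied from the dict above):

from types import SimpleNamespace

import torch

# Sketch only: mirrors the args dict introduced by the commit, but preserves
# the args.device / args.corrector attribute access the old code expected.
args = SimpleNamespace(
    ckpt="https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
    corrector="ald",
    corrector_steps=1,
    snr=0.5,
    N=30,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print(args.device)  # attribute access works, unlike with a plain dict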
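The hunk ends before any UI code, so the Gradio wiring is not shown. Since app.py imports gradio and enhance_speech takes a file path and returns one, the interface presumably looks something like the sketch below; the name demo, the labels, and the title are assumptions, not taken from the commit.

import gradio as gr

# Hypothetical wiring, not shown in this hunk: expose enhance_speech
# as a file-in / file-out audio demo.
demo = gr.Interface(
    fn=enhance_speech,
    inputs=gr.Audio(type="filepath", label="Noisy speech"),
    outputs=gr.Audio(type="filepath", label="Enhanced speech"),
    title="SGMSE Speech Enhancement",
)

if __name__ == "__main__":
    demo.launch()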