Steveeeeeeen HF staff commited on
Commit
22bde2c
·
verified ·
1 Parent(s): 1272193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -13,7 +13,6 @@ device = "cuda"
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
16
-
17
  def load_model(model_name: str):
18
  """
19
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
@@ -35,25 +34,31 @@ def tts(text, speaker_audio, selected_language, model_choice):
35
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
36
  selected_language: str (language code)
37
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
38
-
39
- Returns (sample_rate, waveform) for Gradio audio output.
40
  """
41
- # Load the selected model
42
  model = load_model(model_choice)
43
 
44
  if not text:
45
  return None
 
 
46
  if speaker_audio is None:
47
  return None
48
 
49
- # Gradio gives audio in the format (sample_rate, numpy_array)
50
  sr, wav_np = speaker_audio
51
 
52
- # Convert to Torch tensor: shape (1, num_samples)
53
- wav_tensor = torch.from_numpy(wav_np).unsqueeze(0).float()
54
- if wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
55
- # If shape is transposed, fix it
56
- wav_tensor = wav_tensor.T
 
 
 
 
 
57
 
58
  # Get speaker embedding
59
  with torch.no_grad():
@@ -101,16 +106,16 @@ def build_demo():
101
  ref_audio_input = gr.Audio(
102
  label="Reference Audio (Speaker Cloning)",
103
  type="numpy"
 
 
104
  )
105
 
106
- # Model dropdown
107
  model_dropdown = gr.Dropdown(
108
  label="Model Choice",
109
  choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
110
  value="Zyphra/Zonos-v0.1-hybrid",
111
  interactive=True,
112
  )
113
- # Language dropdown (you can filter or use all from supported_language_codes)
114
  language_dropdown = gr.Dropdown(
115
  label="Language Code",
116
  choices=supported_language_codes,
 
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
 
16
  def load_model(model_name: str):
17
  """
18
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
 
34
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
35
  selected_language: str (language code)
36
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
37
+
38
+ Returns (sr_out, wav_out_numpy).
39
  """
 
40
  model = load_model(model_choice)
41
 
42
  if not text:
43
  return None
44
+
45
+ # If the user did not provide a reference audio, skip
46
  if speaker_audio is None:
47
  return None
48
 
49
+ # Gradio gives audio in (sample_rate, numpy_array) format
50
  sr, wav_np = speaker_audio
51
 
52
+ # Convert to Torch tensor
53
+ wav_tensor = torch.from_numpy(wav_np).float()
54
+
55
+ # If stereo (shape [channels, samples]) or multi-channel, downmix to mono
56
+ # e.g. shape (2, samples) -> shape (samples,) by averaging
57
+ if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
58
+ wav_tensor = wav_tensor.mean(dim=0) # shape => (samples,)
59
+
60
+ # Now add a batch dimension => shape (1, samples)
61
+ wav_tensor = wav_tensor.unsqueeze(0)
62
 
63
  # Get speaker embedding
64
  with torch.no_grad():
 
106
  ref_audio_input = gr.Audio(
107
  label="Reference Audio (Speaker Cloning)",
108
  type="numpy"
109
+ # Optionally add mono=True if you want Gradio to always downmix automatically:
110
+ # mono=True
111
  )
112
 
 
113
  model_dropdown = gr.Dropdown(
114
  label="Model Choice",
115
  choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
116
  value="Zyphra/Zonos-v0.1-hybrid",
117
  interactive=True,
118
  )
 
119
  language_dropdown = gr.Dropdown(
120
  label="Language Code",
121
  choices=supported_language_codes,