Steveeeeeeen HF staff commited on
Commit
71064d4
·
verified ·
1 Parent(s): b8a3553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -31
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import spaces
5
 
6
  from zonos.model import Zonos
7
- from zonos.conditioning import make_cond_dict # Keep this; remove supported_language_codes
8
 
9
  # We'll keep a global dictionary of loaded models to avoid reloading
10
  MODELS_CACHE = {}
@@ -13,15 +13,6 @@ device = "cuda"
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
16
- # Define a list of tuples: (Display Label, Language Code)
17
- LANGUAGES = [
18
- ("English", "en-us"),
19
- ("Japanese", "ja"),
20
- ("Chinese", "cmn"),
21
- ("French", "fr-fr"),
22
- ("German", "de"),
23
- ]
24
-
25
  def load_model(model_name: str):
26
  """
27
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
@@ -37,20 +28,15 @@ def load_model(model_name: str):
37
  return MODELS_CACHE[model_name]
38
 
39
  @spaces.GPU(duration=90)
40
- def tts(text, speaker_audio, selected_language_label, model_choice):
41
  """
42
  text: str (Text prompt to synthesize)
43
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
44
- selected_language_label: str (the display name from the dropdown, e.g. "Chinese")
45
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
46
 
47
  Returns (sr_out, wav_out_numpy).
48
  """
49
- # Map from label -> actual language code
50
- label_to_code = dict(LANGUAGES)
51
- # Convert the human-readable label back to the code
52
- selected_language = label_to_code[selected_language_label]
53
-
54
  model = load_model(model_choice)
55
 
56
  if not text:
@@ -66,11 +52,12 @@ def tts(text, speaker_audio, selected_language_label, model_choice):
66
  # Convert to Torch tensor
67
  wav_tensor = torch.from_numpy(wav_np).float()
68
 
69
- # If stereo or multi-channel, downmix to mono
 
70
  if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
71
- wav_tensor = wav_tensor.mean(dim=0) # => (samples,)
72
 
73
- # Add batch dimension => (1, samples)
74
  wav_tensor = wav_tensor.unsqueeze(0)
75
 
76
  # Get speaker embedding
@@ -79,12 +66,12 @@ def tts(text, speaker_audio, selected_language_label, model_choice):
79
  spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
80
 
81
  # Prepare conditioning dictionary
82
- cond_dict = {
83
- "text": text,
84
- "speaker": spk_embedding,
85
- "language": selected_language, # Use the code here
86
- "device": device,
87
- }
88
  conditioning = model.prepare_conditioning(cond_dict)
89
 
90
  # Generate codes
@@ -119,6 +106,8 @@ def build_demo():
119
  ref_audio_input = gr.Audio(
120
  label="Reference Audio (Speaker Cloning)",
121
  type="numpy"
 
 
122
  )
123
 
124
  model_dropdown = gr.Dropdown(
@@ -127,12 +116,10 @@ def build_demo():
127
  value="Zyphra/Zonos-v0.1-hybrid",
128
  interactive=True,
129
  )
130
-
131
- # For the language dropdown, we display only the friendly label
132
  language_dropdown = gr.Dropdown(
133
- label="Language",
134
- choices=[label for (label, code) in LANGUAGES],
135
- value="English", # default display
136
  interactive=True,
137
  )
138
 
@@ -150,3 +137,5 @@ def build_demo():
150
  if __name__ == "__main__":
151
  demo_app = build_demo()
152
  demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
4
  import spaces
5
 
6
  from zonos.model import Zonos
7
+ from zonos.conditioning import make_cond_dict, supported_language_codes
8
 
9
  # We'll keep a global dictionary of loaded models to avoid reloading
10
  MODELS_CACHE = {}
 
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
 
 
 
 
 
 
 
 
 
16
  def load_model(model_name: str):
17
  """
18
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
 
28
  return MODELS_CACHE[model_name]
29
 
30
  @spaces.GPU(duration=90)
31
+ def tts(text, speaker_audio, selected_language, model_choice):
32
  """
33
  text: str (Text prompt to synthesize)
34
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
35
+ selected_language: str (language code)
36
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
37
 
38
  Returns (sr_out, wav_out_numpy).
39
  """
 
 
 
 
 
40
  model = load_model(model_choice)
41
 
42
  if not text:
 
52
  # Convert to Torch tensor
53
  wav_tensor = torch.from_numpy(wav_np).float()
54
 
55
+ # If stereo (shape [channels, samples]) or multi-channel, downmix to mono
56
+ # e.g. shape (2, samples) -> shape (samples,) by averaging
57
  if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
58
+ wav_tensor = wav_tensor.mean(dim=0) # shape => (samples,)
59
 
60
+ # Now add a batch dimension => shape (1, samples)
61
  wav_tensor = wav_tensor.unsqueeze(0)
62
 
63
  # Get speaker embedding
 
66
  spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
67
 
68
  # Prepare conditioning dictionary
69
+ cond_dict = make_cond_dict(
70
+ text=text,
71
+ speaker=spk_embedding,
72
+ language=selected_language,
73
+ device=device,
74
+ )
75
  conditioning = model.prepare_conditioning(cond_dict)
76
 
77
  # Generate codes
 
106
  ref_audio_input = gr.Audio(
107
  label="Reference Audio (Speaker Cloning)",
108
  type="numpy"
109
+ # Optionally add mono=True if you want Gradio to always downmix automatically:
110
+ # mono=True
111
  )
112
 
113
  model_dropdown = gr.Dropdown(
 
116
  value="Zyphra/Zonos-v0.1-hybrid",
117
  interactive=True,
118
  )
 
 
119
  language_dropdown = gr.Dropdown(
120
+ label="Language Code",
121
+ choices=["en-us", "ja", "cmn", "fr-fr", "de"]
122
+ value="en-us",
123
  interactive=True,
124
  )
125
 
 
137
  if __name__ == "__main__":
138
  demo_app = build_demo()
139
  demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
140
+
141
+ This is my code. replace supported_language_codes with a list of the languages i asked you and in the gr.Dropdown it display the name of the language instead of just "cmn'