PoTaTo721 committed
Commit a4dbd99
Parent(s): c89e500

Fix API BUGS

Files changed (2):
  1. app.py +23 -36
  2. tools/api.py +21 -34
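
Both files get the same fix: the "global prompt_tokens, prompt_texts" declaration and the use_memory_cache branches are removed, so inference() rebuilds the reference prompts on every request, and the Gradio radio in app.py is narrowed to "never". (The removed "on-demand" path checked len(prompt_tokens) against module-level state that could still be unset, which is presumably the API bug being fixed.) A minimal sketch of the resulting reference handling, consolidated from the added lines below: encode_reference, audio_to_bytes, read_ref_text, decoder_model, and AUDIO_EXTENSIONS come from the surrounding modules, the call that produces ref_audios is cut off by the hunk so list_files is an assumption, and the build_prompts wrapper is purely illustrative rather than how either file is structured.

from pathlib import Path

def build_prompts(req):
    # Sketch of the per-request reference encoding after this commit.
    if req.reference_id is not None:
        # Disk-stored references: references/<id>/*.wav plus matching *.lab transcripts.
        ref_folder = Path("references") / req.reference_id
        ref_audios = list_files(  # assumed helper; its name is truncated in the hunk
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]
    else:
        # Inline references shipped with the request: raw audio bytes plus transcript.
        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=ref.audio,
                enable_reference_audio=True,
            )
            for ref in req.references
        ]
        prompt_texts = [ref.text for ref in req.references]
    return prompt_tokens, prompt_texts
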
app.py CHANGED
@@ -120,8 +120,6 @@ def build_html_error_message(error):
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
 
-    global prompt_tokens, prompt_texts
-
     idstr: str | None = req.reference_id
     if idstr is not None:
         ref_folder = Path("references") / idstr
@@ -130,43 +128,32 @@ def inference(req: ServeTTSRequest):
             ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
         )
 
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=audio_to_bytes(str(ref_audio)),
-                    enable_reference_audio=True,
-                )
-                for ref_audio in ref_audios
-            ]
-            prompt_texts = [
-                read_ref_text(str(ref_audio.with_suffix(".lab")))
-                for ref_audio in ref_audios
-            ]
-        else:
-            logger.info("Use same references")
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=audio_to_bytes(str(ref_audio)),
+                enable_reference_audio=True,
+            )
+            for ref_audio in ref_audios
+        ]
+        prompt_texts = [
+            read_ref_text(str(ref_audio.with_suffix(".lab")))
+            for ref_audio in ref_audios
+        ]
 
     else:
         # Parse reference audio aka prompt
         refs = req.references
 
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=ref.audio,
-                    enable_reference_audio=True,
-                )
-                for ref in refs
-            ]
-            prompt_texts = [ref.text for ref in refs]
-        else:
-            logger.info("Use same references")
-
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
     if req.seed is not None:
         set_seed(req.seed)
         logger.warning(f"set seed: {req.seed}")
@@ -421,8 +408,8 @@ def build_app():
             with gr.Row():
                 use_memory_cache = gr.Radio(
                     label=i18n("Use Memory Cache"),
-                    choices=["never", "on-demand", "always"],
-                    value="on-demand",
+                    choices=["never"],
+                    value="never",
                 )
 
             with gr.Row():
tools/api.py CHANGED
@@ -605,8 +605,6 @@ def api_invoke_chat(
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
 
-    global prompt_tokens, prompt_texts
-
     idstr: str | None = req.reference_id
     if idstr is not None:
         ref_folder = Path("references") / idstr
@@ -615,43 +613,32 @@ def inference(req: ServeTTSRequest):
             ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
 
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=audio_to_bytes(str(ref_audio)),
-                    enable_reference_audio=True,
-                )
-                for ref_audio in ref_audios
-            ]
-            prompt_texts = [
-                read_ref_text(str(ref_audio.with_suffix(".lab")))
-                for ref_audio in ref_audios
-            ]
-        else:
-            logger.info("Use same references")
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=audio_to_bytes(str(ref_audio)),
+                enable_reference_audio=True,
+            )
+            for ref_audio in ref_audios
+        ]
+        prompt_texts = [
+            read_ref_text(str(ref_audio.with_suffix(".lab")))
+            for ref_audio in ref_audios
+        ]
 
     else:
         # Parse reference audio aka prompt
         refs = req.references
 
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=ref.audio,
-                    enable_reference_audio=True,
-                )
-                for ref in refs
-            ]
-            prompt_texts = [ref.text for ref in refs]
-        else:
-            logger.info("Use same references")
-
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
     if req.seed is not None:
         set_seed(req.seed)
         logger.warning(f"set seed: {req.seed}")
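
With the caching branches gone, tools/api.py still accepts the same two request shapes as before: a reference_id naming a folder under references/, or inline references whose items carry audio bytes and a text transcript; both are now encoded fresh on every call. A hedged illustration using stand-in dataclasses follows (the real ServeTTSRequest and reference models are defined in the project and are not reproduced here; only the field names that appear in this diff are mirrored, and the "my_speaker" folder name is made up).

from dataclasses import dataclass, field

@dataclass
class ReferenceStandIn:
    audio: bytes  # raw reference audio bytes
    text: str     # transcript of that audio

@dataclass
class RequestStandIn:
    reference_id: str | None = None  # folder name under references/
    references: list[ReferenceStandIn] = field(default_factory=list)
    seed: int | None = None          # forwarded to set_seed() when given

# Disk-stored references: the server re-reads references/my_speaker/*.wav
# and the matching *.lab transcripts on every request now.
by_id = RequestStandIn(reference_id="my_speaker", seed=42)

# Inline references: likewise re-encoded per request.
inline = RequestStandIn(
    references=[ReferenceStandIn(audio=b"<wav bytes>", text="reference transcript")]
)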