mrfakename commited on
Commit
c971ea2
·
verified ·
1 Parent(s): 897409a

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/train/finetune_gradio.py +14 -8
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -34,6 +34,7 @@ python_executable = sys.executable or "python"
34
  tts_api = None
35
  last_checkpoint = ""
36
  last_device = ""
 
37
 
38
  path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
39
  path_data = os.path.join(path_basic, "data")
@@ -800,7 +801,7 @@ def vocab_extend(project_name, symbols, model_type):
800
  return "Symbols are okay no need to extend."
801
 
802
  size_vocab = len(vocab)
803
-
804
  for item in miss_symbols:
805
  vocab.append(item)
806
 
@@ -915,8 +916,8 @@ def get_random_sample_infer(project_name):
915
  )
916
 
917
 
918
- def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
919
- global last_checkpoint, last_device, tts_api
920
 
921
  if not os.path.isfile(file_checkpoint):
922
  return None, "checkpoint not found!"
@@ -926,15 +927,19 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
926
  else:
927
  device_test = None
928
 
929
- if last_checkpoint != file_checkpoint or last_device != device_test:
930
  if last_checkpoint != file_checkpoint:
931
  last_checkpoint = file_checkpoint
 
932
  if last_device != device_test:
933
  last_device = device_test
934
 
935
- tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
 
 
 
936
 
937
- print("update", device_test, file_checkpoint)
938
 
939
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
940
  tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
@@ -1273,7 +1278,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1273
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1274
 
1275
  nfe_step = gr.Number(label="n_step", value=32)
1276
-
1277
  with gr.Row():
1278
  cm_checkpoint = gr.Dropdown(
1279
  choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
@@ -1285,6 +1290,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1285
  ref_text = gr.Textbox(label="ref text")
1286
  ref_audio = gr.Audio(label="audio ref", type="filepath")
1287
  gen_text = gr.Textbox(label="gen text")
 
1288
  random_sample_infer.click(
1289
  fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
1290
  )
@@ -1297,7 +1303,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1297
 
1298
  check_button_infer.click(
1299
  fn=infer,
1300
- inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
1301
  outputs=[gen_audio, txt_info_gpu],
1302
  )
1303
 
 
34
  tts_api = None
35
  last_checkpoint = ""
36
  last_device = ""
37
+ last_ema = None
38
 
39
  path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
40
  path_data = os.path.join(path_basic, "data")
 
801
  return "Symbols are okay no need to extend."
802
 
803
  size_vocab = len(vocab)
804
+ vocab.pop() # fix empty space leave
805
  for item in miss_symbols:
806
  vocab.append(item)
807
 
 
916
  )
917
 
918
 
919
+ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
920
+ global last_checkpoint, last_device, tts_api, last_ema
921
 
922
  if not os.path.isfile(file_checkpoint):
923
  return None, "checkpoint not found!"
 
927
  else:
928
  device_test = None
929
 
930
+ if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema:
931
  if last_checkpoint != file_checkpoint:
932
  last_checkpoint = file_checkpoint
933
+
934
  if last_device != device_test:
935
  last_device = device_test
936
 
937
+ if last_ema != use_ema:
938
+ last_ema = use_ema
939
+
940
+ tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test, use_ema=use_ema)
941
 
942
+ print("update >> ", device_test, file_checkpoint, use_ema)
943
 
944
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
945
  tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
 
1278
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1279
 
1280
  nfe_step = gr.Number(label="n_step", value=32)
1281
+ ch_use_ema = gr.Checkbox(label="use ema", value=True)
1282
  with gr.Row():
1283
  cm_checkpoint = gr.Dropdown(
1284
  choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
 
1290
  ref_text = gr.Textbox(label="ref text")
1291
  ref_audio = gr.Audio(label="audio ref", type="filepath")
1292
  gen_text = gr.Textbox(label="gen text")
1293
+
1294
  random_sample_infer.click(
1295
  fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
1296
  )
 
1303
 
1304
  check_button_infer.click(
1305
  fn=infer,
1306
+ inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
1307
  outputs=[gen_audio, txt_info_gpu],
1308
  )
1309