Spaces:
Running
Running
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -34,6 +34,7 @@ python_executable = sys.executable or "python"
|
|
34 |
tts_api = None
|
35 |
last_checkpoint = ""
|
36 |
last_device = ""
|
|
|
37 |
|
38 |
path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
|
39 |
path_data = os.path.join(path_basic, "data")
|
@@ -800,7 +801,7 @@ def vocab_extend(project_name, symbols, model_type):
|
|
800 |
return "Symbols are okay no need to extend."
|
801 |
|
802 |
size_vocab = len(vocab)
|
803 |
-
|
804 |
for item in miss_symbols:
|
805 |
vocab.append(item)
|
806 |
|
@@ -915,8 +916,8 @@ def get_random_sample_infer(project_name):
|
|
915 |
)
|
916 |
|
917 |
|
918 |
-
def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
919 |
-
global last_checkpoint, last_device, tts_api
|
920 |
|
921 |
if not os.path.isfile(file_checkpoint):
|
922 |
return None, "checkpoint not found!"
|
@@ -926,15 +927,19 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
|
926 |
else:
|
927 |
device_test = None
|
928 |
|
929 |
-
if last_checkpoint != file_checkpoint or last_device != device_test:
|
930 |
if last_checkpoint != file_checkpoint:
|
931 |
last_checkpoint = file_checkpoint
|
|
|
932 |
if last_device != device_test:
|
933 |
last_device = device_test
|
934 |
|
935 |
-
|
|
|
|
|
|
|
936 |
|
937 |
-
print("update", device_test, file_checkpoint)
|
938 |
|
939 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
940 |
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
@@ -1273,7 +1278,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1273 |
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
1274 |
|
1275 |
nfe_step = gr.Number(label="n_step", value=32)
|
1276 |
-
|
1277 |
with gr.Row():
|
1278 |
cm_checkpoint = gr.Dropdown(
|
1279 |
choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
|
@@ -1285,6 +1290,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1285 |
ref_text = gr.Textbox(label="ref text")
|
1286 |
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
1287 |
gen_text = gr.Textbox(label="gen text")
|
|
|
1288 |
random_sample_infer.click(
|
1289 |
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
|
1290 |
)
|
@@ -1297,7 +1303,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1297 |
|
1298 |
check_button_infer.click(
|
1299 |
fn=infer,
|
1300 |
-
inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
|
1301 |
outputs=[gen_audio, txt_info_gpu],
|
1302 |
)
|
1303 |
|
|
|
34 |
tts_api = None
|
35 |
last_checkpoint = ""
|
36 |
last_device = ""
|
37 |
+
last_ema = None
|
38 |
|
39 |
path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
|
40 |
path_data = os.path.join(path_basic, "data")
|
|
|
801 |
return "Symbols are okay no need to extend."
|
802 |
|
803 |
size_vocab = len(vocab)
|
804 |
+
vocab.pop() # fix empty space leave
|
805 |
for item in miss_symbols:
|
806 |
vocab.append(item)
|
807 |
|
|
|
916 |
)
|
917 |
|
918 |
|
919 |
+
def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
|
920 |
+
global last_checkpoint, last_device, tts_api, last_ema
|
921 |
|
922 |
if not os.path.isfile(file_checkpoint):
|
923 |
return None, "checkpoint not found!"
|
|
|
927 |
else:
|
928 |
device_test = None
|
929 |
|
930 |
+
if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema:
|
931 |
if last_checkpoint != file_checkpoint:
|
932 |
last_checkpoint = file_checkpoint
|
933 |
+
|
934 |
if last_device != device_test:
|
935 |
last_device = device_test
|
936 |
|
937 |
+
if last_ema != use_ema:
|
938 |
+
last_ema = use_ema
|
939 |
+
|
940 |
+
tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test, use_ema=use_ema)
|
941 |
|
942 |
+
print("update >> ", device_test, file_checkpoint, use_ema)
|
943 |
|
944 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
945 |
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
|
|
1278 |
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
1279 |
|
1280 |
nfe_step = gr.Number(label="n_step", value=32)
|
1281 |
+
ch_use_ema = gr.Checkbox(label="use ema", value=True)
|
1282 |
with gr.Row():
|
1283 |
cm_checkpoint = gr.Dropdown(
|
1284 |
choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
|
|
|
1290 |
ref_text = gr.Textbox(label="ref text")
|
1291 |
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
1292 |
gen_text = gr.Textbox(label="gen text")
|
1293 |
+
|
1294 |
random_sample_infer.click(
|
1295 |
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
|
1296 |
)
|
|
|
1303 |
|
1304 |
check_button_infer.click(
|
1305 |
fn=infer,
|
1306 |
+
inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
|
1307 |
outputs=[gen_audio, txt_info_gpu],
|
1308 |
)
|
1309 |
|