# SSR-Speech / app.py
import os
os.system("bash setup.sh")
import requests
import re
from num2words import num2words
import gradio as gr
import torch
import torchaudio
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from edit_utils_en import parse_edit_en
from edit_utils_en import parse_tts_en
from edit_utils_zh import parse_edit_zh
from edit_utils_zh import parse_tts_zh
from inference_scale import inference_one_sample
import librosa
import soundfile as sf
from models import ssr
import io
import numpy as np
import random
import uuid
import opencc
import spaces
import nltk
nltk.download('punkt')
DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
os.makedirs(MODELS_PATH, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
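# Gradio demo for SSR-Speech: speech editing and zero-shot TTS for English and
# Mandarin, built on WhisperX transcription / forced alignment and a
# watermarked EnCodec ("wmencodec") neural codec.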
if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
    # download the watermarked EnCodec checkpoint
    url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
    filename = os.path.join(MODELS_PATH, "wmencodec.th")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("wmencodec model found")

if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
    # download the English model
    url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
    filename = os.path.join(MODELS_PATH, "English.pth")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("English model found")

if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
    # download the Mandarin model
    url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
    filename = os.path.join(MODELS_PATH, "Mandarin.pth")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("Mandarin model found")
def get_random_string():
return "".join(str(uuid.uuid4()).split("-"))
@spaces.GPU
def seed_everything(seed):
if seed != -1:
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
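# Map a word-index span in the aligned transcript to a (start, end) time
# interval in seconds, using WhisperX word-level timestamps. Handles insertion
# before the first word, after the last word, and between adjacent words.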
def get_mask_interval(transcribe_state, word_span):
print(transcribe_state)
seg_num = len(transcribe_state['segments'])
data = []
for i in range(seg_num):
words = transcribe_state['segments'][i]['words']
for item in words:
data.append([item['start'], item['end'], item['word']])
s, e = word_span[0], word_span[1]
assert s <= e, f"s:{s}, e:{e}"
assert s >= 0, f"s:{s}"
assert e <= len(data), f"e:{e}"
if e == 0: # start
start = 0.
end = float(data[0][0])
elif s == len(data): # end
start = float(data[-1][1])
end = float(data[-1][1]) # don't know the end yet
elif s == e: # insert
start = float(data[s-1][1])
end = float(data[s][0])
else:
start = float(data[s-1][1]) if s > 0 else float(data[s][0])
end = float(data[e][0]) if e < len(data) else float(data[-1][1])
return (start, end)
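# Convert WhisperX output from Traditional to Simplified Chinese, both for each
# word and for each segment's full text.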
def traditional_to_simplified(segments):
converter = opencc.OpenCC('t2s')
seg_num = len(segments)
for i in range(seg_num):
words = segments[i]['words']
for j in range(len(words)):
segments[i]['words'][j]['word'] = converter.convert(segments[i]['words'][j]['word'])
segments[i]['text'] = converter.convert(segments[i]['text'])
return segments
from whisperx import load_align_model, load_model, load_audio
from whisperx import align as align_func
# Load models
text_tokenizer_en = TextTokenizer(backend="espeak")
text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
# Load the SSR-Speech checkpoints and the watermarked EnCodec tokenizer used by
# the run_edit_* / run_tts_* functions below.
ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
ckpt_en = torch.load(ssrspeech_fn_en)
model_en = ssr.SSR_Speech(ckpt_en["config"])
model_en.load_state_dict(ckpt_en["model"])
config_en = model_en.args
phn2num_en = ckpt_en["phn2num"]
model_en.to(device)

ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
ckpt_zh = torch.load(ssrspeech_fn_zh)
model_zh = ssr.SSR_Speech(ckpt_zh["config"])
model_zh.load_state_dict(ckpt_zh["model"])
config_zh = model_zh.args
phn2num_zh = ckpt_zh["phn2num"]
model_zh.to(device)

encodec_fn = f"{MODELS_PATH}/wmencodec.th"
ssrspeech_model_en = {
    "config": config_en,
    "phn2num": phn2num_en,
    "model": model_en,
    "text_tokenizer": text_tokenizer_en,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
ssrspeech_model_zh = {
    "config": config_zh,
    "phn2num": phn2num_zh,
    "model": model_zh,
    "text_tokenizer": text_tokenizer_zh,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
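# Merge WhisperX segments into a single transcript string, keeping the raw
# segments alongside it.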
def get_transcribe_state(segments):
transcript = " ".join([segment["text"] for segment in segments])
    transcript = transcript.lstrip(" ")  # guard against an empty or leading-space transcript
return {
"segments": segments,
"transcript": transcript,
}
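# Transcribe the input audio with WhisperX, then run forced alignment to get
# word-level timestamps. Returns the transcript, the segments, the full
# transcribe state, and a status message for the UI.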
@spaces.GPU
def transcribe_en(audio_path):
language = "en"
transcribe_model_name = "medium.en"
transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
for segment in segments:
segment['text'] = replace_numbers_with_words(segment['text'])
_, segments = align_en(segments, audio_path)
state = get_transcribe_state(segments)
success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
return [
state["transcript"], state['segments'],
state, success_message
]
@spaces.GPU
def transcribe_zh(audio_path):
language = "zh"
transcribe_model_name = "medium"
transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
_, segments = align_zh(segments, audio_path)
state = get_transcribe_state(segments)
success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
converter = opencc.OpenCC('t2s')
state["transcript"] = converter.convert(state["transcript"])
return [
state["transcript"], state['segments'],
state, success_message
]
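# Forced alignment with WhisperX: re-align the transcribed segments against the
# audio to obtain word-level start/end times (character alignments disabled).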
@spaces.GPU
def align_en(segments, audio_path):
language = "en"
align_model, metadata = load_align_model(language_code=language, device=device)
audio = load_audio(audio_path)
segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
state = get_transcribe_state(segments)
return state, segments
@spaces.GPU
def align_zh(segments, audio_path):
language = "zh"
align_model, metadata = load_align_model(language_code=language, device=device)
audio = load_audio(audio_path)
segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
state = get_transcribe_state(segments)
return state, segments
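# Concatenate the generated audio tensors and encode them as an in-memory WAV
# so Gradio can play the result.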
def get_output_audio(audio_tensors, codec_audio_sr):
result = torch.cat(audio_tensors, 1)
buffer = io.BytesIO()
torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
buffer.seek(0)
return buffer.read()
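# Spell out digit sequences as words (e.g. "42" -> "forty-two") so the phoneme
# tokenizer never sees raw numerals.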
def replace_numbers_with_words(sentence):
    sentence = re.sub(r'(\d+)', r' \1 ', sentence)  # add spaces around numbers

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)  # convert the digit string to words
        except Exception:
            return num  # if num2words fails, keep the original token

    return re.sub(r'\b\d+\b', replace_with_words, sentence)
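# English speech editing: diff the WhisperX transcript of the input audio
# against the target transcript, widen each edited word span by sub_amount
# seconds, merge overlapping spans, convert them to codec-frame intervals, and
# let SSR-Speech regenerate only the masked regions.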
@spaces.GPU
def run_edit_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
audio_path, original_transcript, transcript):
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
top_p = 0.8
temperature = 1
kvcache = 1
stop_repetition = 2
    aug_text = (aug_text == 1)
seed_everything(seed)
# resample audio
audio, _ = librosa.load(audio_path, sr=16000)
sf.write(audio_path, audio, 16000)
# text normalization
    target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
[orig_transcript, segments, _, _] = transcribe_en(audio_path)
orig_transcript = orig_transcript.lower()
target_transcript = target_transcript.lower()
transcribe_state,_ = align_en(segments, audio_path)
print(orig_transcript)
print(target_transcript)
operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
print(operations)
print("orig_spans: ", orig_spans)
if len(orig_spans) > 3:
        raise gr.Error("The current model supports at most 3 edits")
starting_intervals = []
ending_intervals = []
for orig_span in orig_spans:
start, end = get_mask_interval(transcribe_state, orig_span)
starting_intervals.append(start)
ending_intervals.append(end)
print("intervals: ", starting_intervals, ending_intervals)
info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate
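    # Merge editing spans that overlap or sit within `threshold` seconds of
    # each other, so adjacent edits are regenerated as one region.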
def combine_spans(spans, threshold=0.2):
spans.sort(key=lambda x: x[0])
combined_spans = []
current_span = spans[0]
for i in range(1, len(spans)):
next_span = spans[i]
if current_span[1] >= next_span[0] - threshold:
current_span[1] = max(current_span[1], next_span[1])
else:
combined_spans.append(current_span)
current_span = next_span
combined_spans.append(current_span)
return combined_spans
morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
for start, end in zip(starting_intervals, ending_intervals)] # in seconds
morphed_span = combine_spans(morphed_span, threshold=0.2)
print("morphed_spans: ", morphed_span)
mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
new_audio = inference_one_sample(
ssrspeech_model_en["model"],
ssrspeech_model_en["config"],
ssrspeech_model_en["phn2num"],
ssrspeech_model_en["text_tokenizer"],
ssrspeech_model_en["audio_tokenizer"],
audio_path, orig_transcript, target_transcript, mask_interval,
cfg_coef, cfg_stride, aug_text, False, True, False,
device, decode_config
)
audio_tensors = []
# save segments for comparison
new_audio = new_audio[0].cpu()
torchaudio.save(audio_path, new_audio, codec_audio_sr)
audio_tensors.append(new_audio)
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
return output_audio, success_message
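# English zero-shot TTS: optionally trim the prompt audio to about
# prompt_length seconds at a word boundary, append the target text to the
# prompt transcript, place the mask at the very end of the prompt so the model
# generates a continuation, then cut the prompt portion out of the output using
# a fresh WhisperX alignment.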
@spaces.GPU
def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
audio_path, original_transcript, transcript):
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
top_p = 0.8
temperature = 1
kvcache = 1
stop_repetition = 2
    aug_text = (aug_text == 1)
seed_everything(seed)
# resample audio
audio, _ = librosa.load(audio_path, sr=16000)
sf.write(audio_path, audio, 16000)
# text normalization
    target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
[orig_transcript, segments, _, _] = transcribe_en(audio_path)
orig_transcript = orig_transcript.lower()
target_transcript = target_transcript.lower()
transcribe_state,_ = align_en(segments, audio_path)
print(orig_transcript)
print(target_transcript)
info = torchaudio.info(audio_path)
duration = info.num_frames / info.sample_rate
cut_length = duration
# Cut long audio for tts
if duration > prompt_length:
seg_num = len(transcribe_state['segments'])
for i in range(seg_num):
words = transcribe_state['segments'][i]['words']
for item in words:
if item['end'] >= prompt_length:
cut_length = min(item['end'], cut_length)
audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
sf.write(audio_path, audio, 16000)
[orig_transcript, segments, _, _] = transcribe_en(audio_path)
orig_transcript = orig_transcript.lower()
target_transcript = target_transcript.lower()
transcribe_state,_ = align_en(segments, audio_path)
print(orig_transcript)
target_transcript_copy = target_transcript # for tts cut out
target_transcript_copy = target_transcript_copy.split(' ')[0]
target_transcript = orig_transcript + ' ' + target_transcript
print(target_transcript)
info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate
morphed_span = [(audio_dur, audio_dur)] # in seconds
mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
print("mask_interval: ", mask_interval)
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
new_audio = inference_one_sample(
ssrspeech_model_en["model"],
ssrspeech_model_en["config"],
ssrspeech_model_en["phn2num"],
ssrspeech_model_en["text_tokenizer"],
ssrspeech_model_en["audio_tokenizer"],
audio_path, orig_transcript, target_transcript, mask_interval,
cfg_coef, cfg_stride, aug_text, False, True, True,
device, decode_config
)
audio_tensors = []
# save segments for comparison
new_audio = new_audio[0].cpu()
torchaudio.save(audio_path, new_audio, codec_audio_sr)
[new_transcript, new_segments, _, _] = transcribe_en(audio_path)
transcribe_state,_ = align_en(new_segments, audio_path)
tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
tmp2 = target_transcript_copy.lower()
if tmp1 == tmp2:
offset = transcribe_state['segments'][0]['words'][0]['start']
else:
offset = transcribe_state['segments'][0]['words'][1]['start']
new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
audio_tensors.append(new_audio)
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
return output_audio, success_message
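# Mandarin speech editing: same pipeline as run_edit_en, with OpenCC
# Traditional-to-Simplified conversion applied to the WhisperX output before
# the edit spans are computed.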
@spaces.GPU
def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
audio_path, original_transcript, transcript):
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
top_p = 0.8
temperature = 1
kvcache = 1
stop_repetition = 2
    aug_text = (aug_text == 1)
seed_everything(seed)
# resample audio
audio, _ = librosa.load(audio_path, sr=16000)
sf.write(audio_path, audio, 16000)
# text normalization
    target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = original_transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
[orig_transcript, segments, _, _] = transcribe_zh(audio_path)
print(orig_transcript)
converter = opencc.OpenCC('t2s')
orig_transcript = converter.convert(orig_transcript)
transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
print(orig_transcript)
print(target_transcript)
operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
print(operations)
print("orig_spans: ", orig_spans)
if len(orig_spans) > 3:
        raise gr.Error("The current model supports at most 3 edits")
starting_intervals = []
ending_intervals = []
for orig_span in orig_spans:
start, end = get_mask_interval(transcribe_state, orig_span)
starting_intervals.append(start)
ending_intervals.append(end)
print("intervals: ", starting_intervals, ending_intervals)
info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate
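    # Merge editing spans that overlap or sit within `threshold` seconds of
    # each other, so adjacent edits are regenerated as one region.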
def combine_spans(spans, threshold=0.2):
spans.sort(key=lambda x: x[0])
combined_spans = []
current_span = spans[0]
for i in range(1, len(spans)):
next_span = spans[i]
if current_span[1] >= next_span[0] - threshold:
current_span[1] = max(current_span[1], next_span[1])
else:
combined_spans.append(current_span)
current_span = next_span
combined_spans.append(current_span)
return combined_spans
morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
for start, end in zip(starting_intervals, ending_intervals)] # in seconds
morphed_span = combine_spans(morphed_span, threshold=0.2)
print("morphed_spans: ", morphed_span)
mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
new_audio = inference_one_sample(
ssrspeech_model_zh["model"],
ssrspeech_model_zh["config"],
ssrspeech_model_zh["phn2num"],
ssrspeech_model_zh["text_tokenizer"],
ssrspeech_model_zh["audio_tokenizer"],
audio_path, orig_transcript, target_transcript, mask_interval,
cfg_coef, cfg_stride, aug_text, False, True, False,
device, decode_config
)
audio_tensors = []
# save segments for comparison
new_audio = new_audio[0].cpu()
torchaudio.save(audio_path, new_audio, codec_audio_sr)
audio_tensors.append(new_audio)
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
return output_audio, success_message
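# Mandarin zero-shot TTS: same pipeline as run_tts_en, with OpenCC
# Traditional-to-Simplified conversion and character-level (rather than
# word-level) trimming of the prompt from the generated audio.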
@spaces.GPU
def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
audio_path, original_transcript, transcript):
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
top_p = 0.8
temperature = 1
kvcache = 1
stop_repetition = 2
    aug_text = (aug_text == 1)
seed_everything(seed)
# resample audio
audio, _ = librosa.load(audio_path, sr=16000)
sf.write(audio_path, audio, 16000)
# text normalization
    target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = original_transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
[orig_transcript, segments, _, _] = transcribe_zh(audio_path)
converter = opencc.OpenCC('t2s')
orig_transcript = converter.convert(orig_transcript)
transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
print(orig_transcript)
print(target_transcript)
info = torchaudio.info(audio_path)
duration = info.num_frames / info.sample_rate
cut_length = duration
# Cut long audio for tts
if duration > prompt_length:
seg_num = len(transcribe_state['segments'])
for i in range(seg_num):
words = transcribe_state['segments'][i]['words']
for item in words:
if item['end'] >= prompt_length:
cut_length = min(item['end'], cut_length)
audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
sf.write(audio_path, audio, 16000)
[orig_transcript, segments, _, _] = transcribe_zh(audio_path)
converter = opencc.OpenCC('t2s')
orig_transcript = converter.convert(orig_transcript)
transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
print(orig_transcript)
target_transcript_copy = target_transcript # for tts cut out
target_transcript_copy = target_transcript_copy[0]
target_transcript = orig_transcript + target_transcript
print(target_transcript)
info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate
morphed_span = [(audio_dur, audio_dur)] # in seconds
mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
print("mask_interval: ", mask_interval)
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
new_audio = inference_one_sample(
ssrspeech_model_zh["model"],
ssrspeech_model_zh["config"],
ssrspeech_model_zh["phn2num"],
ssrspeech_model_zh["text_tokenizer"],
ssrspeech_model_zh["audio_tokenizer"],
audio_path, orig_transcript, target_transcript, mask_interval,
cfg_coef, cfg_stride, aug_text, False, True, True,
device, decode_config
)
audio_tensors = []
# save segments for comparison
new_audio = new_audio[0].cpu()
torchaudio.save(audio_path, new_audio, codec_audio_sr)
[new_transcript, new_segments, _,_] = transcribe_zh(audio_path)
transcribe_state,_ = align_zh(traditional_to_simplified(new_segments), audio_path)
transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
tmp1 = transcribe_state['segments'][0]['words'][0]['word']
tmp2 = target_transcript_copy
if tmp1 == tmp2:
offset = transcribe_state['segments'][0]['words'][0]['start']
else:
offset = transcribe_state['segments'][0]['words'][1]['start']
new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
audio_tensors.append(new_audio)
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
return output_audio, success_message
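# Gradio UI: four tabs (English Speech Editing, English TTS, Mandarin Speech
# Editing, Mandarin TTS), each wiring an input audio file, its WhisperX
# transcript, and the target text to the corresponding run_* function.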
if __name__ == "__main__":
import argparse
    parser = argparse.ArgumentParser(description="SSR-Speech Gradio app.")
parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory")
parser.add_argument("--port", default=7860, type=int, help="App port")
parser.add_argument("--share", action="store_true", help="Launch with public url")
os.environ["USER"] = os.getenv("USER", "user")
args = parser.parse_args()
DEMO_PATH = args.demo_path
TMP_PATH = args.tmp_path
MODELS_PATH = args.models_path
# app = get_app()
# app.queue().launch(share=args.share, server_port=args.port)
# CSS styling (optional)
css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
"""
# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("""
# SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer
Generate and edit speech from text. Adjust advanced settings for more control.
Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
""")
# Tabs for Generate and Edit
with gr.Tab("English Speech Editing"):
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
info="Use whisperx model to get the transcript.")
transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3):
with gr.Group():
transcript = gr.Textbox(label="Text", lines=7, value="but when I saw the mirage of the lake in the distance, which the sense deceives, lost not by distance any of its marks.", interactive=True)
run_btn = gr.Button(value="Run")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance; change it if you don't like the results")
cfg_coef = gr.Number(label="cfg_coef", value=1.5,
info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
cfg_stride = gr.Number(label="cfg_stride", value=5,
info="cfg stride, 5 is a good value for English, change if you don't like the results")
prompt_length = gr.Number(label="prompt_length", value=3,
info="used for tts prompt, will automatically cut the prompt audio to this length")
sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
success_output = gr.HTML()
                segments = gr.State()  # transcription segments from WhisperX (not consumed downstream)
state = gr.State() # not used
audio_state = gr.State(value=f"{DEMO_PATH}/84_121550_000074_000000.wav")
input_audio.change(
lambda audio: audio,
inputs=[input_audio],
outputs=[audio_state]
)
transcribe_btn.click(fn=transcribe_en,
inputs=[audio_state],
                                     outputs=[original_transcript, segments, state, success_output])
run_btn.click(fn=run_edit_en,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output])
transcript.submit(fn=run_edit_en,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output]
)
with gr.Tab("English TTS"):
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
info="Use whisperx model to get the transcript.")
transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3):
with gr.Group():
transcript = gr.Textbox(label="Text", lines=7, value="I cannot believe that the same model can also do text to speech synthesis too!", interactive=True)
run_btn = gr.Button(value="Run")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance; change it if you don't like the results")
cfg_coef = gr.Number(label="cfg_coef", value=1.5,
info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
cfg_stride = gr.Number(label="cfg_stride", value=5,
info="cfg stride, 5 is a good value for English, change if you don't like the results")
prompt_length = gr.Number(label="prompt_length", value=3,
info="used for tts prompt, will automatically cut the prompt audio to this length")
sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
success_output = gr.HTML()
                segments = gr.State()  # transcription segments from WhisperX (not consumed downstream)
state = gr.State() # not used
audio_state = gr.State(value=f"{DEMO_PATH}/84_121550_000074_000000.wav")
input_audio.change(
lambda audio: audio,
inputs=[input_audio],
outputs=[audio_state]
)
transcribe_btn.click(fn=transcribe_en,
inputs=[audio_state],
                                     outputs=[original_transcript, segments, state, success_output])
run_btn.click(fn=run_tts_en,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output])
transcript.submit(fn=run_tts_en,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output]
)
with gr.Tab("Mandarin Speech Editing"):
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
info="Use whisperx model to get the transcript.")
transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3):
with gr.Group():
transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
run_btn = gr.Button(value="Run")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance; change it if you don't like the results")
cfg_coef = gr.Number(label="cfg_coef", value=1.5,
info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
cfg_stride = gr.Number(label="cfg_stride", value=1,
info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
prompt_length = gr.Number(label="prompt_length", value=3,
info="used for tts prompt, will automatically cut the prompt audio to this length")
sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
success_output = gr.HTML()
                segments = gr.State()  # transcription segments from WhisperX (not consumed downstream)
state = gr.State() # not used
audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
input_audio.change(
lambda audio: audio,
inputs=[input_audio],
outputs=[audio_state]
)
transcribe_btn.click(fn=transcribe_zh,
inputs=[audio_state],
                                     outputs=[original_transcript, segments, state, success_output])
run_btn.click(fn=run_edit_zh,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output])
transcript.submit(fn=run_edit_zh,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output]
)
with gr.Tab("Mandarin TTS"):
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
info="Use whisperx model to get the transcript.")
transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3):
with gr.Group():
transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
run_btn = gr.Button(value="Run")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance; change it if you don't like the results")
cfg_coef = gr.Number(label="cfg_coef", value=1.5,
info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
cfg_stride = gr.Number(label="cfg_stride", value=1,
info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
prompt_length = gr.Number(label="prompt_length", value=3,
info="used for tts prompt, will automatically cut the prompt audio to this length")
sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
success_output = gr.HTML()
                segments = gr.State()  # transcription segments from WhisperX (not consumed downstream)
state = gr.State() # not used
audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
input_audio.change(
lambda audio: audio,
inputs=[input_audio],
outputs=[audio_state]
)
transcribe_btn.click(fn=transcribe_zh,
inputs=[audio_state],
                                     outputs=[original_transcript, segments, state, success_output])
run_btn.click(fn=run_tts_zh,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output])
transcript.submit(fn=run_tts_zh,
inputs=[
seed, sub_amount,
aug_text, cfg_coef, cfg_stride, prompt_length,
audio_state, original_transcript, transcript,
],
outputs=[output_audio, success_output]
)
# Launch the Gradio demo
demo.launch()