Spaces:
Running
Running
poemsforaphrodite
commited on
Commit
•
72ff919
1
Parent(s):
c2f0c2e
Upload openvoice_app.py with huggingface_hub
Browse files- openvoice_app.py +18 -60
openvoice_app.py
CHANGED
@@ -6,88 +6,45 @@ import langid
|
|
6 |
from openvoice import se_extractor
|
7 |
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
|
8 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Argument parsing
|
11 |
parser = argparse.ArgumentParser()
|
12 |
parser.add_argument("--share", action='store_true', default=False, help="make link public")
|
13 |
args = parser.parse_args()
|
14 |
-
|
15 |
-
|
16 |
-
# Paths and device setup
|
17 |
-
en_ckpt_base = 'base_speakers/EN'
|
18 |
-
zh_ckpt_base = 'base_speakers/ZH'
|
19 |
-
ckpt_converter = 'converter'
|
20 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
output_dir = 'outputs'
|
22 |
os.makedirs(output_dir, exist_ok=True)
|
23 |
|
24 |
-
# Load models
|
25 |
-
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
|
26 |
-
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
|
27 |
-
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
|
28 |
-
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
|
29 |
-
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
|
30 |
-
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
|
31 |
-
|
32 |
-
# Load speaker embeddings
|
33 |
-
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
|
34 |
-
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
|
35 |
-
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
|
36 |
|
37 |
-
# Supported languages
|
38 |
supported_languages = ['zh', 'en']
|
39 |
|
40 |
def predict(prompt, style, audio_file_pth):
|
41 |
text_hint = ''
|
42 |
-
|
43 |
-
# Detect the input language
|
44 |
-
language_predicted = langid.classify(prompt)[0].strip()
|
45 |
-
print(f"Detected language: {language_predicted}")
|
46 |
-
|
47 |
-
if language_predicted not in supported_languages:
|
48 |
-
text_hint += f"[ERROR] The detected language {language_predicted} is not supported. Supported languages: {supported_languages}\n"
|
49 |
-
return text_hint, None, None
|
50 |
-
|
51 |
-
if language_predicted == "zh":
|
52 |
-
tts_model = zh_base_speaker_tts
|
53 |
-
source_se = zh_source_se
|
54 |
-
language = 'Chinese'
|
55 |
-
if style != 'default':
|
56 |
-
text_hint += f"[ERROR] The style {style} is not supported for Chinese. Supported style: 'default'\n"
|
57 |
-
return text_hint, None, None
|
58 |
-
else:
|
59 |
-
tts_model = en_base_speaker_tts
|
60 |
-
source_se = en_source_default_se if style == 'default' else en_source_style_se
|
61 |
-
language = 'English'
|
62 |
-
if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
|
63 |
-
text_hint += f"[ERROR] The style {style} is not supported for English. Supported styles: ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
|
64 |
-
return text_hint, None, None
|
65 |
-
|
66 |
if len(prompt) < 2:
|
67 |
text_hint += "[ERROR] Please provide a longer prompt text.\n"
|
68 |
return text_hint, None, None
|
69 |
if len(prompt) > 200:
|
70 |
text_hint += "[ERROR] Text length limited to 200 characters. Please try shorter text.\n"
|
71 |
return text_hint, None, None
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
save_path
|
83 |
-
encode_message = "@MyShell"
|
84 |
-
tone_color_converter.convert(audio_src_path=src_path, src_se=source_se, tgt_se=target_se, output_path=save_path, message=encode_message)
|
85 |
-
|
86 |
-
text_hint += "Response generated successfully.\n"
|
87 |
return text_hint, save_path, audio_file_pth
|
88 |
|
89 |
-
title = "MyShell OpenVoice"
|
90 |
-
|
91 |
# Gradio interface setup
|
92 |
with gr.Blocks(gr.themes.Glass()) as demo:
|
93 |
with gr.Row():
|
@@ -123,3 +80,4 @@ demo.launch(debug=True, show_api=False, share=args.share)
|
|
123 |
|
124 |
# Hide Gradio footer
|
125 |
css = "footer {visibility: hidden}"
|
|
|
|
6 |
from openvoice import se_extractor
|
7 |
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
|
8 |
from dotenv import load_dotenv
|
9 |
+
from openai import OpenAI
|
10 |
+
from elevenlabs.client import ElevenLabs
|
11 |
+
from elevenlabs import play,save
|
12 |
+
load_dotenv()
|
13 |
|
14 |
# Argument parsing
|
15 |
parser = argparse.ArgumentParser()
|
16 |
parser.add_argument("--share", action='store_true', default=False, help="make link public")
|
17 |
args = parser.parse_args()
|
18 |
+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
19 |
+
client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))
|
|
|
|
|
|
|
|
|
20 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
output_dir = 'outputs'
|
22 |
os.makedirs(output_dir, exist_ok=True)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
|
|
25 |
supported_languages = ['zh', 'en']
|
26 |
|
27 |
def predict(prompt, style, audio_file_pth):
|
28 |
text_hint = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
if len(prompt) < 2:
|
30 |
text_hint += "[ERROR] Please provide a longer prompt text.\n"
|
31 |
return text_hint, None, None
|
32 |
if len(prompt) > 200:
|
33 |
text_hint += "[ERROR] Text length limited to 200 characters. Please try shorter text.\n"
|
34 |
return text_hint, None, None
|
35 |
+
|
36 |
+
print(audio_file_pth)
|
37 |
+
voice = client.clone(
|
38 |
+
name="TrialVoice",
|
39 |
+
description="A trial voice model for testing",
|
40 |
+
files=[audio_file_pth],
|
41 |
+
)
|
42 |
+
#text should be prompt
|
43 |
+
audio = client.generate(text=prompt, voice=voice)
|
44 |
+
save(audio, "result.mp3")
|
45 |
+
save_path="result.mp3"
|
|
|
|
|
|
|
|
|
46 |
return text_hint, save_path, audio_file_pth
|
47 |
|
|
|
|
|
48 |
# Gradio interface setup
|
49 |
with gr.Blocks(gr.themes.Glass()) as demo:
|
50 |
with gr.Row():
|
|
|
80 |
|
81 |
# Hide Gradio footer
|
82 |
css = "footer {visibility: hidden}"
|
83 |
+
|