poemsforaphrodite committed on
Commit 72ff919
1 Parent(s): c2f0c2e

Upload openvoice_app.py with huggingface_hub

Files changed (1)
  1. openvoice_app.py +18 -60
openvoice_app.py CHANGED
@@ -6,88 +6,45 @@ import langid
 from openvoice import se_extractor
 from openvoice.api import BaseSpeakerTTS, ToneColorConverter
 from dotenv import load_dotenv
+from openai import OpenAI
+from elevenlabs.client import ElevenLabs
+from elevenlabs import play,save
+load_dotenv()

 # Argument parsing
 parser = argparse.ArgumentParser()
 parser.add_argument("--share", action='store_true', default=False, help="make link public")
 args = parser.parse_args()
-load_dotenv()
-
-# Paths and device setup
-en_ckpt_base = 'base_speakers/EN'
-zh_ckpt_base = 'base_speakers/ZH'
-ckpt_converter = 'converter'
+client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 output_dir = 'outputs'
 os.makedirs(output_dir, exist_ok=True)

-# Load models
-en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
-en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
-zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
-zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
-tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
-tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
-
-# Load speaker embeddings
-en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
-en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
-zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)

-# Supported languages
 supported_languages = ['zh', 'en']

 def predict(prompt, style, audio_file_pth):
     text_hint = ''
-
-    # Detect the input language
-    language_predicted = langid.classify(prompt)[0].strip()
-    print(f"Detected language: {language_predicted}")
-
-    if language_predicted not in supported_languages:
-        text_hint += f"[ERROR] The detected language {language_predicted} is not supported. Supported languages: {supported_languages}\n"
-        return text_hint, None, None
-
-    if language_predicted == "zh":
-        tts_model = zh_base_speaker_tts
-        source_se = zh_source_se
-        language = 'Chinese'
-        if style != 'default':
-            text_hint += f"[ERROR] The style {style} is not supported for Chinese. Supported style: 'default'\n"
-            return text_hint, None, None
-    else:
-        tts_model = en_base_speaker_tts
-        source_se = en_source_default_se if style == 'default' else en_source_style_se
-        language = 'English'
-        if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
-            text_hint += f"[ERROR] The style {style} is not supported for English. Supported styles: ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
-            return text_hint, None, None
-
     if len(prompt) < 2:
         text_hint += "[ERROR] Please provide a longer prompt text.\n"
         return text_hint, None, None
     if len(prompt) > 200:
         text_hint += "[ERROR] Text length limited to 200 characters. Please try shorter text.\n"
         return text_hint, None, None
-
-    try:
-        target_se, audio_name = se_extractor.get_se(audio_file_pth, tone_color_converter, target_dir='processed', vad=True)
-    except Exception as e:
-        text_hint += f"[ERROR] Error extracting tone color: {str(e)}\n"
-        return text_hint, None, None
-
-    src_path = f'{output_dir}/tmp.wav'
-    tts_model.tts(prompt, src_path, speaker=style, language=language)
-
-    save_path = f'{output_dir}/output.wav'
-    encode_message = "@MyShell"
-    tone_color_converter.convert(audio_src_path=src_path, src_se=source_se, tgt_se=target_se, output_path=save_path, message=encode_message)
-
-    text_hint += "Response generated successfully.\n"
+
+    print(audio_file_pth)
+    voice = client.clone(
+        name="TrialVoice",
+        description="A trial voice model for testing",
+        files=[audio_file_pth],
+    )
+    #text should be prompt
+    audio = client.generate(text=prompt, voice=voice)
+    save(audio, "result.mp3")
+    save_path="result.mp3"
     return text_hint, save_path, audio_file_pth

-title = "MyShell OpenVoice"
-
 # Gradio interface setup
 with gr.Blocks(gr.themes.Glass()) as demo:
     with gr.Row():
@@ -123,3 +80,4 @@ demo.launch(debug=True, show_api=False, share=args.share)

 # Hide Gradio footer
 css = "footer {visibility: hidden}"
+
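For reference, this commit drops the OpenVoice TTS and tone-color-conversion pipeline from `predict` and replaces it with ElevenLabs instant voice cloning. Two details of the committed code are worth noting: `client` is first bound to an `OpenAI` instance and then immediately rebound to `ElevenLabs`, so the OpenAI client is never used in this revision, and `play` is imported but unused. Below is a minimal standalone sketch of the new flow, assuming the elevenlabs 1.x Python SDK (the version exposing the `clone`/`generate` helpers the commit calls) and an `ELEVENLABS_API_KEY` in the environment; the names `eleven_client` and `clone_and_speak` are illustrative and not part of the commit.

# Minimal sketch of the ElevenLabs flow introduced by this commit (assumptions noted above).
import os

from dotenv import load_dotenv
from elevenlabs import save
from elevenlabs.client import ElevenLabs

load_dotenv()
eleven_client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))

def clone_and_speak(prompt: str, reference_audio_path: str, out_path: str = "result.mp3") -> str:
    # Build an instant voice clone from the single uploaded reference clip.
    voice = eleven_client.clone(
        name="TrialVoice",
        description="A trial voice model for testing",
        files=[reference_audio_path],
    )
    # Synthesize the prompt in the cloned voice and write the audio to disk.
    audio = eleven_client.generate(text=prompt, voice=voice)
    save(audio, out_path)
    return out_path

In the committed Gradio handler, this logic sits between the prompt-length checks and the final return, with the saved file path handed back to the interface as `save_path`.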