TArtx committed on
Commit
f38c401
·
verified ·
1 Parent(s): 5fdd9cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -75
app.py CHANGED
@@ -7,12 +7,12 @@ import re
7
  from parler_tts import ParlerTTSForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
9
 
10
- # Set device
11
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load Mini model and associated components
14
  repo_id = "TArtx/parler-tts-mini-v1-finetuned-12"
15
- model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
16
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
17
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
18
 
@@ -33,7 +33,6 @@ def preprocess(text):
33
  text = text.replace("-", " ")
34
  if text[-1] not in punctuation:
35
  text = f"{text}."
36
-
37
  abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
38
 
39
  def separate_abb(chunk):
@@ -48,78 +47,31 @@ def preprocess(text):
48
 
49
  # TTS generation function
50
  def gen_tts(text, description):
51
- inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
52
- prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
53
-
54
- set_seed(SEED)
55
- generation = model.generate(
56
- input_ids=inputs.input_ids,
57
- prompt_input_ids=prompt.input_ids,
58
- attention_mask=inputs.attention_mask,
59
- prompt_attention_mask=prompt.attention_mask,
60
- do_sample=True,
61
- temperature=1.0,
62
- )
63
- audio_arr = generation.cpu().numpy().squeeze()
64
- return SAMPLE_RATE, audio_arr
65
-
66
- # CSS for styling
67
- css = """
68
- #share-btn-container {
69
- display: flex;
70
- padding-left: 0.5rem !important;
71
- padding-right: 0.5rem !important;
72
- background-color: #000000;
73
- justify-content: center;
74
- align-items: center;
75
- border-radius: 9999px !important;
76
- width: 13rem;
77
- margin-top: 10px;
78
- margin-left: auto;
79
- flex: unset !important;
80
- }
81
- #share-btn {
82
- all: initial;
83
- color: #ffffff;
84
- font-weight: 600;
85
- cursor: pointer;
86
- font-family: 'IBM Plex Sans', sans-serif;
87
- margin-left: 0.5rem !important;
88
- padding-top: 0.25rem !important;
89
- padding-bottom: 0.25rem !important;
90
- right:0;
91
- }
92
- #share-btn * {
93
- all: unset !important;
94
- }
95
- #share-btn-container div:nth-child(-n+2){
96
- width: auto !important;
97
- min-height: 0px !important;
98
- }
99
- #share-btn-container .wrap {
100
- display: none !important;
101
- }
102
- """
103
 
104
  # Gradio interface
105
- with gr.Blocks(css=css) as block:
106
- gr.HTML(
107
- """
108
- <div style="text-align: center; max-width: 700px; margin: 0 auto;">
109
- <div
110
- style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
111
- >
112
- <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
113
- Parler-TTS 🗣️
114
- </h1>
115
- </div>
116
- </div>
117
  """
118
- )
119
- gr.HTML(
120
- f"""
121
- <p><a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> is a training and inference library for
122
- high-fidelity text-to-speech (TTS) models. The demo uses the Mini v1 model by default.</p>
123
  """
124
  )
125
  with gr.Row():
@@ -136,4 +88,4 @@ with gr.Blocks(css=css) as block:
136
 
137
  # Launch the interface
138
  block.queue()
139
- block.launch(share=True)
 
7
  from parler_tts import ParlerTTSForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
9
 
10
# Inference runs on CPU only (no CUDA check — this Space has no GPU).
device = "cpu"

# Fine-tuned Parler-TTS Mini checkpoint and its companion components.
# `low_cpu_mem_usage=True` streams weights in to keep peak RAM low
# while loading on a memory-constrained host.
repo_id = "TArtx/parler-tts-mini-v1-finetuned-12"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, low_cpu_mem_usage=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
18
 
 
33
  text = text.replace("-", " ")
34
  if text[-1] not in punctuation:
35
  text = f"{text}."
 
36
  abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
37
 
38
  def separate_abb(chunk):
 
47
 
48
# TTS generation function
def gen_tts(text, description):
    """Synthesize speech for *text*, conditioned on a voice *description*.

    Parameters
    ----------
    text : str
        The sentence to speak; normalized by ``preprocess`` before tokenizing.
    description : str
        Natural-language description of the desired voice/style.

    Returns
    -------
    tuple[int, numpy.ndarray]
        ``(SAMPLE_RATE, audio)`` suitable for a Gradio Audio output.
    """
    # Truncate both sequences to bound memory use on the CPU-only host.
    inputs = tokenizer(description.strip(), return_tensors="pt", truncation=True, max_length=128).to(device)
    prompt = tokenizer(preprocess(text), return_tensors="pt", truncation=True, max_length=128).to(device)

    set_seed(SEED)
    generation = model.generate(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        attention_mask=inputs.attention_mask,
        # BUG FIX: the tokenizer's BatchEncoding exposes `attention_mask`,
        # not `prompt_attention_mask`; the old `prompt.prompt_attention_mask`
        # raised on every call.
        prompt_attention_mask=prompt.attention_mask,
        do_sample=True,
        temperature=1.0,
    )
    # NOTE(review): the old broad `except Exception` returned
    # (SAMPLE_RATE, "Error: ...") — a string where Gradio's Audio component
    # expects a numeric array — which both broke the UI and hid the attribute
    # bug above. Let exceptions propagate so Gradio surfaces them instead.
    audio_arr = generation.cpu().numpy().squeeze()
    return SAMPLE_RATE, audio_arr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  # Gradio interface
70
+ with gr.Blocks() as block:
71
+ gr.Markdown(
 
 
 
 
 
 
 
 
 
 
72
  """
73
+ ## Parler-TTS 🗣️
74
+ Parler-TTS is a training and inference library for high-fidelity text-to-speech (TTS) models. This demo uses the Mini v1 model.
 
 
 
75
  """
76
  )
77
  with gr.Row():
 
88
 
89
  # Launch the interface
90
  block.queue()
91
+ block.launch()