VanYsa committed on
Commit 14b3eb9 · 1 Parent(s): c3f6601

Added and organized comments

Files changed (1): app.py +23 -19
app.py CHANGED
@@ -18,10 +18,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from transformers import pipeline
 
+#### Variables ###
+
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-# Variables
 SAMPLE_RATE = 16000 # Hz
 MAX_AUDIO_SECONDS = 40 # wont try to transcribe if longer than this
 DESCRIPTION = '''
@@ -29,8 +30,8 @@ DESCRIPTION = '''
 <h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
 <p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
 <p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech-to-text Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Insruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text to speech (TTS).</p>
-<p>This demo accepts audio inputs not more than 40 seconds long.</p>
-<p>Transcription and responses are limited to the English language.</p>
+<p>This demo accepts audio inputs not more than 40 seconds long. Transcription and responses are limited to the English language.</p>
+<p>The LLM max_new_tokens, temperature and top_p are set to 512, 0.6 and 0.9 respectively</p>
 </div>
 '''
 PLACEHOLDER = """
@@ -42,7 +43,7 @@ PLACEHOLDER = """
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-### ASR model
+### ASR model ###
 canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
 canary_model.eval()
 # make sure beam size always 1 for consistency
@@ -51,7 +52,7 @@ decoding_cfg = canary_model.cfg.decoding
 decoding_cfg.beam.beam_size = 1
 canary_model.change_decoding_strategy(decoding_cfg)
 
-### LLM model
+### LLM model ###
 # Load the tokenizer and model
 llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
 llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0")
@@ -64,11 +65,11 @@ terminators = [
     llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
 ]
 
-### TTS model
+### TTS model ###
 pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)
 
 
-
+### Start of functions ###
 def convert_audio(audio_filepath, tmpdir, utt_id):
     """
     Convert all files to monochannel 16 kHz wav files.
@@ -99,8 +100,8 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
 
 def transcribe(audio_filepath):
     """
-    Transcribes a converted audio file.
-    Set to english language with punctuations.
+    Transcribes a converted audio file using the asr model.
+    Set to the english language with punctuations.
     Returns the transcribed text as a string.
     """
 
@@ -136,15 +137,15 @@ def transcribe(audio_filepath):
 def add_message(history, message):
     """
     Adds the input message in the chatbot.
-    Returns the updated chatbot history.
+    Returns the updated chatbot.
     """
     history.append((message, None))
     return history
 
 def bot(history, message):
     """
-    Gets the bot's response and places the user and bot messages in the chatbot
-    Returns the appended chatbot history.
+    Gets the bot's response and adds it in the chatbot.
+    Returns the appended chatbot.
     """
     response = bot_response(message, history)
     lines = response.split("\n")
@@ -162,8 +163,8 @@ def bot(history, message):
 @spaces.GPU()
 def bot_response(message, history):
     """
-    Generates a streaming response using the llama3-8b model.
-    Set max_new_tokens = 100, temperature=0.6, and top_p=0.9
+    Generates a streaming response using the llm model.
+    Set max_new_tokens = 512, temperature=0.6, and top_p=0.9
     Returns the generated response in string format.
     """
     conversation = []
@@ -175,7 +176,7 @@ def bot_response(message, history):
 
     outputs = llama3_model.generate(
         input_ids,
-        max_new_tokens = 100,
+        max_new_tokens = 512,
         eos_token_id = terminators,
         do_sample=True,
         temperature=0.6,
@@ -190,7 +191,7 @@ def bot_response(message, history):
 @spaces.GPU()
 def voice_player(history):
     """
-    Plays the generated response using the VITS-ljs model.
+    Plays the generated response using the tts model.
     Returns the audio player with the generated response.
     """
     _, text = history[-1]
@@ -205,7 +206,9 @@ def voice_player(history):
         visible=True)
     return voice
 
+### End of functions ###
 
+### Interface using Blocks###
 with gr.Blocks(
     title="MyAlexa",
     css="""
@@ -251,13 +254,13 @@ with gr.Blocks(
         visible=False # set to True to see processing time of asr transcription
     )
 
-    gr.HTML("<p><b>Step 3 [Optional]:</b> Replay MyAlexa's voice response.</p>")
+    gr.HTML("<p><b>[Optional]:</b> Replay MyAlexa's voice response.</p>")
 
     out_audio = gr.Audio( # Shows an audio player for the generated response
        value = None,
-        label="Response Voice Player",
+        label="Response Audio Player",
        show_label=True,
-        visible=False # set to True to see processing time of initial tts audio generation
+        visible=False # set to True to see processing time of the first tts audio generation
    )
 
    chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
@@ -270,6 +273,7 @@ with gr.Blocks(
        outputs = [chat_input]
    )
 
+### Queue and launch the demo ###
 demo.queue()
 if __name__ == "__main__":
     demo.launch()
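A few notes on the code paths this commit touches. The hunks show only convert_audio's signature and docstring, not its body. A monochannel 16 kHz conversion typically looks like the sketch below; the use of pydub, the helper name, and the returned duration are assumptions for illustration, not the commit's actual implementation.

from pydub import AudioSegment  # assumption: any resampling library would do

def convert_audio_sketch(audio_filepath, tmpdir, utt_id):
    # Load any input format, downmix to mono, resample to 16 kHz (SAMPLE_RATE)
    audio = AudioSegment.from_file(audio_filepath)
    audio = audio.set_channels(1).set_frame_rate(16000)
    out_filepath = f"{tmpdir}/{utt_id}.wav"
    audio.export(out_filepath, format="wav")
    # Duration in seconds, e.g. so transcribe() can enforce MAX_AUDIO_SECONDS
    duration = len(audio) / 1000.0
    return out_filepath, duration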
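The one functional change in the commit is max_new_tokens, raised from 100 to 512 to match the updated docstring and DESCRIPTION. Assembled from the fragments visible in the hunks, the full generation path plausibly reads as below; the apply_chat_template and decode steps are assumptions inferred from the conversation and input_ids names in the diff.

# Sketch of bot_response's core, assuming a standard chat-template flow
conversation = [{"role": "user", "content": message}]
input_ids = llm_tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(llama3_model.device)

outputs = llama3_model.generate(
    input_ids,
    max_new_tokens=512,        # was 100 before this commit
    eos_token_id=terminators,  # list built near the top of app.py
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# Strip the prompt tokens; keep only the newly generated reply
response = llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)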
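voice_player's docstring now says "tts model" rather than naming VITS-ljs. The transformers text-to-speech pipeline returns a dict with "audio" and "sampling_rate" keys, so the body likely resembles this sketch; everything beyond the visible=True and return voice fragments in the diff is an assumption.

import numpy as np
import gradio as gr

def voice_player_sketch(history):
    _, text = history[-1]  # latest bot reply from the chat log
    speech = pipe(text)    # {"audio": np.ndarray, "sampling_rate": int}
    voice = gr.Audio(
        value=(speech["sampling_rate"], np.squeeze(speech["audio"])),
        autoplay=True,
        visible=True)
    return voice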
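Only the add_message binding is visible in the final context lines. In a Gradio Blocks app like this one, the bot and voice_player steps are typically chained off that event with .then(); the sketch below shows that pattern under the assumption that the hidden lines follow it.

# Hypothetical event chain: transcribed text -> chat message -> bot reply -> audio
chat_msg = chat_input.change(
    add_message, [chatbot, chat_input], [chatbot],
    api_name="add_message_in_chatbot")
bot_msg = chat_msg.then(bot, [chatbot, chat_input], [chatbot])
bot_msg.then(voice_player, [chatbot], [out_audio])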