Added and organized comments
app.py CHANGED
@@ -18,10 +18,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from transformers import pipeline
 
+#### Variables ###
+
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-# Variables
 SAMPLE_RATE = 16000 # Hz
 MAX_AUDIO_SECONDS = 40 # wont try to transcribe if longer than this
 DESCRIPTION = '''
@@ -29,8 +30,8 @@ DESCRIPTION = '''
 <h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
 <p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
 <p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech-to-text Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Insruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text to speech (TTS).</p>
-<p>This demo accepts audio inputs not more than 40 seconds long.</p>
-<p>
+<p>This demo accepts audio inputs not more than 40 seconds long. Transcription and responses are limited to the English language.</p>
+<p>The LLM max_new_tokens, temperature and top_p are set to 512, 0.6 and 0.9 respectively</p>
 </div>
 '''
 PLACEHOLDER = """
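The description block above pins down the whole pipeline: Canary 1B turns speech into text, Llama 3 8B Instruct generates the reply, and VITS-ljs speaks it, with max_new_tokens=512, temperature=0.6 and top_p=0.9 on the LLM side. A minimal sketch of how those stages chain, reusing the names app.py defines further down (transcribe, bot_response, pipe); reply_to_wav itself is a hypothetical helper for illustration, not part of the Space:

import scipy.io.wavfile as wavfile

def reply_to_wav(audio_filepath, history, out_path="reply.wav"):
    user_text = transcribe(audio_filepath)    # Canary 1B: speech -> text
    reply = bot_response(user_text, history)  # Llama 3 8B Instruct: text -> text
    speech = pipe(reply)                      # VITS-ljs: text -> waveform
    wavfile.write(out_path, speech["sampling_rate"], speech["audio"].squeeze())
    return out_path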
@@ -42,7 +43,7 @@ PLACEHOLDER = """
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-### ASR model
+### ASR model ###
 canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
 canary_model.eval()
 # make sure beam size always 1 for consistency
@@ -51,7 +52,7 @@ decoding_cfg = canary_model.cfg.decoding
 decoding_cfg.beam.beam_size = 1
 canary_model.change_decoding_strategy(decoding_cfg)
 
-### LLM model
+### LLM model ###
 # Load the tokenizer and model
 llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
 llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0")
@@ -64,11 +65,11 @@ terminators = [
     llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
 ]
 
-### TTS model
+### TTS model ###
 pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)
 
 
-
+### Start of functions ###
 def convert_audio(audio_filepath, tmpdir, utt_id):
     """
     Convert all files to monochannel 16 kHz wav files.
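convert_audio's body falls outside this hunk; a sketch of one common way to produce monochannel 16 kHz wav files, assuming pydub (the Space's real implementation may use a different library):

import os
from pydub import AudioSegment

def convert_audio_sketch(audio_filepath, tmpdir, utt_id):
    audio = AudioSegment.from_file(audio_filepath)
    audio = audio.set_channels(1).set_frame_rate(16000)  # mono, 16 kHz
    out_path = os.path.join(tmpdir, f"{utt_id}.wav")
    audio.export(out_path, format="wav")
    return out_path, audio.duration_seconds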
@@ -99,8 +100,8 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
 
 def transcribe(audio_filepath):
     """
-    Transcribes a converted audio file.
-    Set to english language with punctuations.
+    Transcribes a converted audio file using the asr model.
+    Set to the english language with punctuations.
     Returns the transcribed text as a string.
     """
 
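The transcribe body is likewise elided. NeMo's Canary models are typically driven through a one-line JSON manifest, so a hedged reconstruction consistent with the docstring (English in and out, punctuation on) could look like the following; the exact fields used in app.py are not visible in this diff:

import json, tempfile

def transcribe_sketch(converted_filepath, duration):
    manifest = {
        "audio_filepath": converted_filepath,
        "duration": duration,
        "taskname": "asr",
        "source_lang": "en",  # per the docstring: English only
        "target_lang": "en",
        "pnc": "yes",         # punctuation and capitalization
        "answer": "na",
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        f.write(json.dumps(manifest) + "\n")
        manifest_path = f.name
    return canary_model.transcribe(manifest_path)[0]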
@@ -136,15 +137,15 @@ def transcribe(audio_filepath):
 def add_message(history, message):
     """
     Adds the input message in the chatbot.
-    Returns the updated chatbot
+    Returns the updated chatbot.
     """
     history.append((message, None))
     return history
 
 def bot(history, message):
     """
-    Gets the bot's response and
-    Returns the appended chatbot
+    Gets the bot's response and adds it in the chatbot.
+    Returns the appended chatbot.
     """
     response = bot_response(message, history)
     lines = response.split("\n")
@@ -162,8 +163,8 @@ def bot(history, message):
 @spaces.GPU()
 def bot_response(message, history):
     """
-    Generates a streaming response using the
-    Set max_new_tokens =
+    Generates a streaming response using the llm model.
+    Set max_new_tokens = 512, temperature=0.6, and top_p=0.9
     Returns the generated response in string format.
     """
     conversation = []
@@ -175,7 +176,7 @@ def bot_response(message, history):
 
     outputs = llama3_model.generate(
         input_ids,
-        max_new_tokens =
+        max_new_tokens = 512,
         eos_token_id = terminators,
         do_sample=True,
         temperature=0.6,
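The hunk above only touches the generate() call; how input_ids is built from conversation is elided. The standard Llama 3 chat-template construction, which the elided lines presumably resemble, is:

input_ids = llm_tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,  # append the assistant header so the model answers
    return_tensors="pt"
).to(llama3_model.device)

outputs = llama3_model.generate(
    input_ids,
    max_new_tokens=512,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# Decode only the newly generated tokens, skipping the prompt.
response = llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)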
@@ -190,7 +191,7 @@ def bot_response(message, history):
 @spaces.GPU()
 def voice_player(history):
     """
-    Plays the generated response using the
+    Plays the generated response using the tts model.
     Returns the audio player with the generated response.
     """
     _, text = history[-1]
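voice_player's middle lines are also elided; a hedged reconstruction that matches the visible pieces (text taken from the last history entry, a visible gr.Audio returned) follows. The label and autoplay values are assumptions; only visible=True appears in the diff:

def voice_player_sketch(history):
    _, text = history[-1]
    speech = pipe(text)  # VITS pipeline returns {"audio": ndarray, "sampling_rate": int}
    return gr.Audio(
        value=(speech["sampling_rate"], speech["audio"].squeeze()),
        label="Voice Response",  # assumed label
        autoplay=True,           # assumed
        visible=True)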
@@ -205,7 +206,9 @@ def voice_player(history):
         visible=True)
     return voice
 
+### End of functions ###
 
+### Interface using Blocks###
 with gr.Blocks(
     title="MyAlexa",
     css="""
@@ -251,13 +254,13 @@ with gr.Blocks(
         visible=False # set to True to see processing time of asr transcription
     )
 
-    gr.HTML("<p><b>
+    gr.HTML("<p><b>[Optional]:</b> Replay MyAlexa's voice response.</p>")
 
     out_audio = gr.Audio( # Shows an audio player for the generated response
         value = None,
-        label="Response
+        label="Response Audio Player",
         show_label=True,
-        visible=False # set to True to see processing time of
+        visible=False # set to True to see processing time of the first tts audio generation
     )
 
     chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
@@ -270,6 +273,7 @@ with gr.Blocks(
         outputs = [chat_input]
     )
 
+### Queue and launch the demo ###
 demo.queue()
 if __name__ == "__main__":
     demo.launch()
|