VanYsa committed on
Commit 7771452 · 1 Parent(s): 5817424

test LLM 3

Files changed (1): app.py +201 -25
app.py CHANGED
@@ -1,24 +1,154 @@
  import gradio as gr
- import spaces
- import torch
-
+ import json
+ import librosa
+ import os
+ import soundfile as sf
+ import tempfile
+ import uuid
  import transformers
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ import time
+ import spaces
+
+ from nemo.collections.asr.models import ASRModel
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+
+ # Read the Hugging Face access token from the environment
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ SAMPLE_RATE = 16000  # Hz
+ MAX_AUDIO_SECONDS = 40  # won't try to transcribe audio longer than this
+ DESCRIPTION = '''
+ <div>
+ <h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
+ <p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response.</p>
+ <p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for automatic speech recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> for the large language model (LLM), and <a href="https://huggingface.co/docs/transformers/en/model_doc/vits"><b>VITS</b></a> for text to speech (TTS).</p>
+ <p>This demo accepts audio inputs no longer than 40 seconds.</p>
+ <p>Transcription and responses are limited to the English language.</p>
+ </div>
+ '''
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+ <img src="https://i.ibb.co/S35q17Q/My-Alexa-Logo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
+ <p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
+ </div>
+ """
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ### ASR model
+ canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
+ canary_model.eval()
 
- model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+ # make sure beam size is always 1 for consistency
+ canary_model.change_decoding_strategy(None)
+ decoding_cfg = canary_model.cfg.decoding
+ decoding_cfg.beam.beam_size = 1
+ canary_model.change_decoding_strategy(decoding_cfg)
+
+ ### LLM model
+ llm_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
 
  pipeline = transformers.pipeline(
      "text-generation",
-     model=model_name,
+     model=llm_model_name,
      model_kwargs={"torch_dtype": torch.bfloat16},
      device="cuda",
  )
 
- @spaces.GPU
- def chat_function(message, history, system_prompt, max_new_tokens, temperature):
+ def convert_audio(audio_filepath, tmpdir, utt_id):
+     """
+     Convert any input file to a mono-channel 16 kHz wav file.
+     Raise an error instead of converting if the audio is too long.
+     Returns the output filename and the duration.
+     """
+     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration > MAX_AUDIO_SECONDS:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. "
+             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+             "(click on the scissors icon to start trimming audio)."
+         )
+
+     if sr != SAMPLE_RATE:
+         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+     out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+     # save the output audio
+     sf.write(out_filename, data, SAMPLE_RATE)
+
+     return out_filename, duration
+
+ def transcribe(audio_filepath):
+     """
+     Transcribe a converted audio file.
+     Set to the English language with punctuation.
+     Returns the output text.
+     """
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+         # build the manifest file and save it
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": "en",
+             "target_lang": "en",
+             "taskname": "asr",
+             "pnc": "yes",
+             "answer": "predict",
+             "duration": str(duration),
+         }
+
+         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+         with open(manifest_filepath, 'w') as fout:
+             line = json.dumps(manifest_data)
+             fout.write(line + '\n')
+
+         # call transcribe, passing in the manifest filepath
+         output_text = canary_model.transcribe(manifest_filepath)[0]
+
+     return output_text
+
+ def add_message(history, message):
+     """
+     Add the input message to the chatbot history.
+     Returns the updated history.
+     """
+     history.append([message, None])
+     return history
+
+ def bot(history, message):
+     """
+     Stream the LLM's response into the chatbot, character by character.
+     """
+     response = bot_response(message, history, 100, 0.7)
+     history[-1][1] = ""
+     for character in response:
+         history[-1][1] += character
+         time.sleep(0.05)
+         yield history
+
+ @spaces.GPU()
+ def bot_response(message, history, max_new_tokens, temperature):
      messages = [
-         {"role": "system", "content": system_prompt},
+         {"role": "system", "content": "You are a helpful AI assistant."},
          {"role": "user", "content": message},
      ]
      prompt = pipeline.tokenizer.apply_chat_template(
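The hunk context skips the new file's lines 155-169, where the generation call itself lives. For orientation only, here is an illustrative sketch of how a chat-templated Llama 3 generation call is typically wired up with a transformers pipeline; this is not the commit's elided code, and the user message, sampling values, and terminator handling are assumptions based on the standard Llama 3 recipe:

```python
# Illustrative sketch only -- not the elided lines from this commit.
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# Render the chat as a single prompt string using the model's chat template.
prompt = pipeline.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # append the assistant header so the model replies
)

# Llama 3 uses <|eot_id|> as an end-of-turn marker in addition to eos.
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

outputs = pipeline(
    prompt,
    max_new_tokens=100,        # assumed value; the real ones come from bot()
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,                 # matches the context line visible in the diff
)

# Drop the echoed prompt, keeping only the newly generated text.
print(outputs[0]["generated_text"][len(prompt):])
```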
 
@@ -40,20 +170,66 @@ def chat_function(message, history, system_prompt, max_new_tokens, temperature):
          top_p=0.9,
      )
      return outputs[0]["generated_text"][len(prompt):]
+
 
- gr.ChatInterface(
-     chat_function,
-     chatbot=gr.Chatbot(height=400),
-     textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
-     title="LLAMA 3 8B Chat",
-     description="""
-     This space is dedicated to chatting with Meta's latest LLM - Llama 3 8B Instruct. Find this model here: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
-     Feel free to play with customization in the "Additional Inputs".
-     """,
-     theme="soft",
-     additional_inputs=[
-         gr.Textbox("You are helpful AI.", label="System Prompt"),
-         gr.Slider(512, 4096, label="Max New Tokens"),
-         gr.Slider(0, 1, label="Temperature")
-     ]
- ).launch()
+ with gr.Blocks(
+     title="MyAlexa",
+     css="""
+     textarea { font-size: 18px; }
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make the text slightly bigger (default is text_md)
+ ) as demo:
+
+     gr.HTML(DESCRIPTION)
+     chatbot = gr.Chatbot(
+         [],
+         elem_id="chatbot",
+         bubble_full_width=False,
+         placeholder=PLACEHOLDER,
+         label='MyAlexa'
+     )
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(
+                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+             )
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+         with gr.Column():
+             gr.HTML("<p><b>Step 2:</b> Submit the audio as input and wait for MyAlexa's response.</p>")
+
+             submit_button = gr.Button(
+                 value="Submit audio",
+                 variant="primary"
+             )
+
+             chat_input = gr.Textbox(
+                 label="Transcribed text:",
+                 interactive=False,
+                 placeholder="Enter message",
+                 elem_id="chat_input",
+                 visible=True
+             )
+
+     chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot])
+     bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response")
+     # bot_msg.then(lambda: gr.Textbox(interactive=False), None, [chat_input])
+
+     submit_button.click(
+         fn=transcribe,
+         inputs=[audio_file],
+         outputs=[chat_input]
+     )
+
+ demo.queue()
+ if __name__ == "__main__":
+     demo.launch()
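For reference, the manifest-driven Canary call that transcribe() makes can be exercised on its own outside Gradio. A minimal sketch, assuming NeMo is installed and a mono 16 kHz wav file exists; the sample.wav filename and the duration value are illustrative, not part of the commit:

```python
import json
import os
import tempfile

from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()

with tempfile.TemporaryDirectory() as tmpdir:
    # Canary reads a JSON-lines manifest describing each utterance and task.
    manifest_entry = {
        "audio_filepath": "sample.wav",  # illustrative path to a mono 16 kHz wav
        "source_lang": "en",
        "target_lang": "en",
        "taskname": "asr",
        "pnc": "yes",        # keep punctuation and capitalization
        "answer": "predict",
        "duration": "5.0",   # illustrative duration in seconds
    }
    manifest_filepath = os.path.join(tmpdir, "manifest.json")
    with open(manifest_filepath, "w") as fout:
        fout.write(json.dumps(manifest_entry) + "\n")

    # Same call the Space makes: transcribe via the manifest file.
    text = model.transcribe(manifest_filepath)[0]

print(text)
```

Because the manifest is JSON-lines, one object per line, the same file format can carry a whole batch of utterances; the Space writes exactly one line per request.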