sasan commited on
Commit
fc97145
·
1 Parent(s): c46f121

chore: Update TTS functionality with MeloTTS and handle import errors

Browse files
Files changed (1) hide show
  1. space.py +492 -0
space.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain.memory import ChatMessageHistory
3
+ from langchain.tools import tool
4
+ from langchain_core.utils.function_calling import convert_to_openai_tool
5
+ from loguru import logger
6
+ import spaces
7
+
8
+ from kitt.core import tts_gradio
9
+ from kitt.core import utils as kitt_utils
10
+ from kitt.core import voice_options
11
+ from kitt.core.model import generate_function_call as process_query
12
+ from kitt.core.stt import save_and_transcribe_audio
13
+ from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
14
+ from kitt.skills import (
15
+ code_interpreter,
16
+ date_time_info,
17
+ do_anything_else,
18
+ extract_func_args,
19
+ find_route,
20
+ get_forecast,
21
+ get_weather,
22
+ get_weather_current_location,
23
+ search_along_route_w_coordinates,
24
+ search_points_of_interest,
25
+ set_vehicle_destination,
26
+ set_vehicle_speed,
27
+ )
28
+ from kitt.skills.common import config, vehicle
29
+ from kitt.skills.routing import calculate_route, find_address
30
+
31
+ ORIGIN = "Luxembourg, Luxembourg"
32
+ DESTINATION = "Paris, France"
33
+ DEFAULT_LLM_BACKEND = "local"
34
+ ENABLE_HISTORY = True
35
+ ENABLE_TTS = True
36
+ TTS_BACKEND = "local"
37
+ USER_PREFERENCES = "User prefers italian food."
38
+
39
+ global_context = {
40
+ "vehicle": vehicle,
41
+ "query": "How is the weather?",
42
+ "route_points": [],
43
+ "origin": ORIGIN,
44
+ "destination": DESTINATION,
45
+ "enable_history": ENABLE_HISTORY,
46
+ "tts_enabled": ENABLE_TTS,
47
+ "tts_backend": TTS_BACKEND,
48
+ "llm_backend": DEFAULT_LLM_BACKEND,
49
+ "map_origin": ORIGIN,
50
+ "map_destination": DESTINATION,
51
+ "update_proxy": 0,
52
+ "map": None,
53
+ }
54
+
55
+ speaker_embedding_cache = {}
56
+ history = ChatMessageHistory()
57
+
58
+
59
+ # Generate options for hours (00-23)
60
+ hour_options = [f"{i:02d}:00:00" for i in range(24)]
61
+
62
+
63
+ @tool
64
+ def search_along_route(query=""):
65
+ """Search for points of interest along the route/way to the destination.
66
+
67
+ Args:
68
+ query (str, optional): The type of point of interest to search for. Defaults to "restaurant".
69
+
70
+ """
71
+ points = global_context["route_points"]
72
+ # maybe reshape
73
+ return search_along_route_w_coordinates(points, query)
74
+
75
+
76
+ def set_time(time_picker):
77
+ vehicle.time = time_picker
78
+ return vehicle
79
+
80
+
81
+ functions = [
82
+ # set_vehicle_speed,
83
+ set_vehicle_destination,
84
+ get_weather,
85
+ find_route,
86
+ search_points_of_interest,
87
+ search_along_route,
88
+ ]
89
+ openai_tools = [convert_to_openai_tool(tool) for tool in functions]
90
+
91
+
92
+ def clear_history():
93
+ logger.info("Clearing the conversation history...")
94
+ history.clear()
95
+
96
+
97
+ @spaces.GPU
98
+ def run_llama3_model(query, voice_character, state):
99
+
100
+ assert len(functions) > 0, "No functions to call"
101
+ assert len(openai_tools) > 0, "No openai tools to call"
102
+
103
+ output_text = process_query(
104
+ query,
105
+ history=history,
106
+ user_preferences=state["user_preferences"],
107
+ tools=openai_tools,
108
+ functions=functions,
109
+ backend=state["llm_backend"],
110
+ )
111
+ gr.Info(f"Output text: {output_text}\nGenerating voice output...")
112
+ output_text_tts = prep_for_tts(output_text)
113
+ voice_out = None
114
+ if global_context["tts_enabled"]:
115
+ if "Fast" in voice_character:
116
+ voice_out = run_melo_tts(output_text_tts, voice_character)
117
+ elif global_context["tts_backend"] == "replicate":
118
+ voice_out = run_tts_replicate(output_text_tts, voice_character)
119
+ else:
120
+ voice_out = tts_gradio(
121
+ output_text_tts, voice_character, speaker_embedding_cache
122
+ )[0]
123
+ return (
124
+ output_text,
125
+ voice_out,
126
+ )
127
+
128
+
129
+ def run_model(query, voice_character, state):
130
+ model = state.get("model", "llama3")
131
+ query = query.strip().replace("'", "")
132
+ logger.info(
133
+ f"Running model: {model} with query: {query}, voice_character: {voice_character} and llm_backend: {state['llm_backend']}, tts_enabled: {state['tts_enabled']}"
134
+ )
135
+ global_context["query"] = query
136
+ text, voice = run_llama3_model(query, voice_character, state)
137
+
138
+ if not state["enable_history"]:
139
+ history.clear()
140
+ global_context["update_proxy"] += 1
141
+
142
+ return (
143
+ text,
144
+ voice,
145
+ vehicle.model_dump(),
146
+ state,
147
+ dict(update_proxy=global_context["update_proxy"]),
148
+ )
149
+
150
+
151
+ def calculate_route_gradio(origin, destination):
152
+ _, points = calculate_route(origin, destination)
153
+ plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
154
+ global_context["map"] = plot
155
+ global_context["route_points"] = points
156
+ # state.value["route_points"] = points
157
+ vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
158
+ return plot, vehicle.model_dump(), 0
159
+
160
+
161
+ def update_vehicle_status(trip_progress, origin, destination, state):
162
+ if not global_context["route_points"]:
163
+ _, points = calculate_route(origin, destination)
164
+ global_context["route_points"] = points
165
+ global_context["destination"] = destination
166
+ global_context["route_points"] = global_context["route_points"]
167
+ n_points = len(global_context["route_points"])
168
+ index = min(int(trip_progress / 100 * n_points), n_points - 1)
169
+ logger.info(f"Trip progress: {trip_progress} len: {n_points}, index: {index}")
170
+ new_coords = global_context["route_points"][index]
171
+ new_coords = new_coords["latitude"], new_coords["longitude"]
172
+ logger.info(
173
+ f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}"
174
+ )
175
+ vehicle.location_coordinates = new_coords
176
+ new_vehicle_location = find_address(new_coords[0], new_coords[1])
177
+ vehicle.location = new_vehicle_location
178
+ plot = kitt_utils.plot_route(
179
+ global_context["route_points"], vehicle=vehicle.location_coordinates
180
+ )
181
+ return vehicle, plot, state
182
+
183
+
184
+ def save_and_transcribe_run_model(audio, voice_character, state):
185
+ text = save_and_transcribe_audio(audio)
186
+ out_text, out_voice, vehicle_status, state, update_proxy = run_model(
187
+ text, voice_character, state
188
+ )
189
+ return None, text, out_text, out_voice, vehicle_status, state, update_proxy
190
+
191
+
192
+ def set_tts_enabled(tts_enabled, state):
193
+ new_tts_enabled = tts_enabled == "Yes"
194
+ logger.info(
195
+ f"TTS enabled was {state['tts_enabled']} and changed to {new_tts_enabled}"
196
+ )
197
+ state["tts_enabled"] = new_tts_enabled
198
+ global_context["tts_enabled"] = new_tts_enabled
199
+ return state
200
+
201
+
202
+ def set_llm_backend(llm_backend, state):
203
+ assert llm_backend in ["Ollama", "Replicate", "Local"], "Invalid LLM backend"
204
+ new_llm_backend = llm_backend.lower()
205
+ logger.info(
206
+ f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
207
+ )
208
+ state["llm_backend"] = new_llm_backend
209
+ global_context["llm_backend"] = new_llm_backend
210
+ return state
211
+
212
+
213
+ def set_user_preferences(preferences, state):
214
+ new_preferences = preferences
215
+ logger.info(f"User preferences changed to: {new_preferences}")
216
+ state["user_preferences"] = new_preferences
217
+ global_context["user_preferences"] = new_preferences
218
+ return state
219
+
220
+
221
+ def set_enable_history(enable_history, state):
222
+ new_enable_history = enable_history == "Yes"
223
+ logger.info(
224
+ f"Enable history was {state['enable_history']} and changed to {new_enable_history}"
225
+ )
226
+ state["enable_history"] = new_enable_history
227
+ global_context["enable_history"] = new_enable_history
228
+ return state
229
+
230
+
231
+ def set_tts_backend(tts_backend, state):
232
+ new_tts_backend = tts_backend.lower()
233
+ logger.info(
234
+ f"TTS backend was {state['tts_backend']} and changed to {new_tts_backend}"
235
+ )
236
+ state["tts_backend"] = new_tts_backend
237
+ global_context["tts_backend"] = new_tts_backend
238
+ return state
239
+
240
+
241
+ def conditional_update():
242
+ if global_context["destination"] != vehicle.destination:
243
+ global_context["destination"] = vehicle.destination
244
+
245
+ if global_context["origin"] != vehicle.location:
246
+ global_context["origin"] = vehicle.location
247
+
248
+ if (
249
+ global_context["map_origin"] != vehicle.location
250
+ or global_context["map_destination"] != vehicle.destination
251
+ or global_context["update_proxy"] == 0
252
+ ):
253
+ logger.info(f"Updating the map plot... in conditional_update")
254
+ map_plot, _, _ = calculate_route_gradio(vehicle.location, vehicle.destination)
255
+ global_context["map"] = map_plot
256
+ return global_context["map"]
257
+
258
+
259
+ # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
260
+ # in "Insecure origins treated as secure", enable it and relaunch chrome
261
+
262
+ # example question:
263
+ # what's the weather like outside?
264
+ # What's the closest restaurant from here?
265
+
266
+
267
+ def create_demo(tts_server: bool = False, model="llama3"):
268
+ print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
269
+ with gr.Blocks(theme=gr.themes.Default(), title="KITT") as demo:
270
+ state = gr.State(
271
+ value={
272
+ # "context": initial_context,
273
+ "query": "",
274
+ "route_points": [],
275
+ "model": model,
276
+ "tts_enabled": ENABLE_TTS,
277
+ "llm_backend": DEFAULT_LLM_BACKEND,
278
+ "user_preferences": USER_PREFERENCES,
279
+ "enable_history": ENABLE_HISTORY,
280
+ "tts_backend": TTS_BACKEND,
281
+ "destination": DESTINATION,
282
+ }
283
+ )
284
+
285
+ plot, _, _ = calculate_route_gradio(ORIGIN, DESTINATION)
286
+ global_context["map"] = plot
287
+
288
+ with gr.Row():
289
+ # with gr.Row():
290
+ # gr.Text("KITT", interactive=False)
291
+ with gr.Column(scale=1, min_width=300):
292
+ vehicle_status = gr.JSON(
293
+ value=vehicle.model_dump(), label="Vehicle status"
294
+ )
295
+ time_picker = gr.Dropdown(
296
+ choices=hour_options,
297
+ label="What time is it? (HH:MM)",
298
+ value="08:00:00",
299
+ interactive=True,
300
+ )
301
+ voice_character = gr.Radio(
302
+ choices=voice_options,
303
+ label="Choose a voice",
304
+ value=voice_options[0],
305
+ show_label=True,
306
+ )
307
+ # voice_character = gr.Textbox(
308
+ # label="Choose a voice",
309
+ # value="freeman",
310
+ # show_label=True,
311
+ # )
312
+ origin = gr.Textbox(
313
+ value=ORIGIN,
314
+ label="Origin",
315
+ interactive=True,
316
+ )
317
+ destination = gr.Textbox(
318
+ value=DESTINATION,
319
+ label="Destination",
320
+ interactive=True,
321
+ )
322
+ preferences = gr.Textbox(
323
+ value=USER_PREFERENCES,
324
+ label="User preferences",
325
+ lines=3,
326
+ interactive=True,
327
+ )
328
+
329
+ with gr.Column(scale=2, min_width=600):
330
+ map_plot = gr.Plot(value=plot, label="Map")
331
+ trip_progress = gr.Slider(
332
+ 0, 100, step=5, label="Trip progress", interactive=True
333
+ )
334
+
335
+ # with gr.Column(scale=1, min_width=300):
336
+ # gr.Image("linkedin-1.png", label="Linkedin - Sasan Jafarnejad")
337
+ # gr.Image(
338
+ # "team-ubix.png",
339
+ # label="Research Team - UBIX - University of Luxembourg",
340
+ # )
341
+
342
+ with gr.Row():
343
+ with gr.Column():
344
+ input_audio = gr.Audio(
345
+ type="numpy",
346
+ sources=["microphone"],
347
+ label="Input audio",
348
+ elem_id="input_audio",
349
+ )
350
+ input_text = gr.Textbox(
351
+ value="How is the weather?", label="Input text", interactive=True
352
+ )
353
+ with gr.Accordion("Debug"):
354
+ input_audio_debug = gr.Audio(
355
+ type="numpy",
356
+ sources=["microphone"],
357
+ label="Input audio",
358
+ elem_id="input_audio",
359
+ )
360
+ input_text_debug = gr.Textbox(
361
+ value="How is the weather?",
362
+ label="Input text",
363
+ interactive=True,
364
+ )
365
+ update_proxy = gr.JSON(
366
+ value=dict(update_proxy=0),
367
+ label="Global context",
368
+ )
369
+ with gr.Accordion("Config"):
370
+ tts_enabled = gr.Radio(
371
+ ["Yes", "No"],
372
+ label="Enable TTS",
373
+ value="Yes" if ENABLE_TTS else "No",
374
+ interactive=True,
375
+ )
376
+ tts_backend = gr.Radio(
377
+ ["Local"],
378
+ label="TTS Backend",
379
+ value=TTS_BACKEND.title(),
380
+ interactive=True,
381
+ )
382
+ llm_backend = gr.Radio(
383
+ choices=["Ollama", "Local"],
384
+ label="LLM Backend",
385
+ value=DEFAULT_LLM_BACKEND.title(),
386
+ interactive=True,
387
+ )
388
+ enable_history = gr.Radio(
389
+ ["Yes", "No"],
390
+ label="Maintain the conversation history?",
391
+ value="Yes" if ENABLE_HISTORY else "No",
392
+ interactive=True,
393
+ )
394
+ # Push button
395
+ clear_history_btn = gr.Button(value="Clear History")
396
+ with gr.Column():
397
+ output_audio = gr.Audio(label="output audio", autoplay=True)
398
+ output_text = gr.TextArea(
399
+ value="", label="Output text", interactive=False
400
+ )
401
+
402
+ # Update plot based on the origin and destination
403
+ # Sets the current location and destination
404
+ origin.submit(
405
+ fn=calculate_route_gradio,
406
+ inputs=[origin, destination],
407
+ outputs=[map_plot, vehicle_status, trip_progress],
408
+ )
409
+ destination.submit(
410
+ fn=calculate_route_gradio,
411
+ inputs=[origin, destination],
412
+ outputs=[map_plot, vehicle_status, trip_progress],
413
+ )
414
+ preferences.submit(
415
+ fn=set_user_preferences, inputs=[preferences, state], outputs=[state]
416
+ )
417
+
418
+ # Update time based on the time picker
419
+ time_picker.select(fn=set_time, inputs=[time_picker], outputs=[vehicle_status])
420
+
421
+ # Run the model if the input text is changed
422
+ input_text.submit(
423
+ fn=run_model,
424
+ inputs=[input_text, voice_character, state],
425
+ outputs=[output_text, output_audio, vehicle_status, state, update_proxy],
426
+ )
427
+ input_text_debug.submit(
428
+ fn=run_model,
429
+ inputs=[input_text_debug, voice_character, state],
430
+ outputs=[output_text, output_audio, vehicle_status, state, update_proxy],
431
+ )
432
+
433
+ # Set the vehicle status based on the trip progress
434
+ trip_progress.release(
435
+ fn=update_vehicle_status,
436
+ inputs=[trip_progress, origin, destination, state],
437
+ outputs=[vehicle_status, map_plot, state],
438
+ )
439
+
440
+ # Save and transcribe the audio
441
+ input_audio.stop_recording(
442
+ fn=save_and_transcribe_run_model,
443
+ inputs=[input_audio, voice_character, state],
444
+ outputs=[
445
+ input_audio,
446
+ input_text,
447
+ output_text,
448
+ output_audio,
449
+ vehicle_status,
450
+ state,
451
+ update_proxy,
452
+ ],
453
+ )
454
+ input_audio_debug.stop_recording(
455
+ fn=save_and_transcribe_audio,
456
+ inputs=[input_audio_debug],
457
+ outputs=[input_text_debug],
458
+ )
459
+
460
+ # Clear the history
461
+ clear_history_btn.click(fn=clear_history, inputs=[], outputs=[])
462
+
463
+ # Config
464
+ tts_enabled.change(
465
+ fn=set_tts_enabled, inputs=[tts_enabled, state], outputs=[state]
466
+ )
467
+ tts_backend.change(
468
+ fn=set_tts_backend, inputs=[tts_backend, state], outputs=[state]
469
+ )
470
+ llm_backend.change(
471
+ fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
472
+ )
473
+ enable_history.change(
474
+ fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
475
+ )
476
+ update_proxy.change(fn=conditional_update, inputs=[], outputs=[map_plot])
477
+
478
+ return demo
479
+
480
+
481
+ # close all interfaces open to make the port available
482
+ gr.close_all()
483
+
484
+
485
+ demo = create_demo(False, "llama3")
486
+ demo.launch(
487
+ debug=True,
488
+ server_name="0.0.0.0",
489
+ server_port=7860,
490
+ ssl_verify=False,
491
+ share=False,
492
+ )