sasan commited on
Commit
60ee11d
·
1 Parent(s): 5dbfb3f

chore: A new more advanced method

Browse files
kitt/core/model.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import uuid
4
+
5
+ from langchain.memory import ChatMessageHistory
6
+ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
7
+ from langchain_core.utils.function_calling import convert_to_openai_function
8
+ import ollama
9
+ from pydantic import BaseModel
10
+ from loguru import logger
11
+
12
+
13
+ from kitt.skills import vehicle_status
14
+
15
+
16
+
17
+ class FunctionCall(BaseModel):
18
+ arguments: dict
19
+ """
20
+ The arguments to call the function with, as generated by the model in JSON
21
+ format. Note that the model does not always generate valid JSON, and may
22
+ hallucinate parameters not defined by your function schema. Validate the
23
+ arguments in your code before calling your function.
24
+ """
25
+
26
+ name: str
27
+ """The name of the function to call."""
28
+
29
+ schema_json = json.loads(FunctionCall.schema_json())
30
+ HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
31
+ <|im_start|>system
32
+ You are a function calling AI agent with self-recursion.
33
+ You can call only one function at a time and analyse data you get from function response.
34
+ You are provided with function signatures within <tools></tools> XML tags.
35
+ {car_status}
36
+
37
+ You may use agentic frameworks for reasoning and planning to help with user query.
38
+ Please call a function and wait for function results to be provided to you in the next iteration.
39
+ Don't make assumptions about what values to plug into function arguments.
40
+ Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
41
+ Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
42
+ Analyze the data once you get the results and call another function.
43
+ At each iteration please continue adding the your analysis to previous summary.
44
+ Your final response should directly answer the user query.
45
+
46
+
47
+ Here are the available tools:
48
+ <tools> {tools} </tools>
49
+ If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
50
+ <tool_call>
51
+ {{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
52
+ </tool_call>
53
+ Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
54
+ When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
55
+
56
+ When asked for the weather, lookup the weather for the current location of the car. Unless the user provides a location, then use that location.
57
+ If asked about points of interest, use the tools available to you. Do not make up points of interest.
58
+
59
+ Use the following pydantic model json schema for each tool call you will make:
60
+ {schema}
61
+
62
+ At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
63
+ Please keep a running summary with analysis of previous function results and summaries from previous iterations.
64
+ Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
65
+ If you plan to continue with analysis, always call another function.
66
+ For each function call return a valid json object (using doulbe quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
67
+ <tool_call>
68
+ {{"arguments": <args-dict>, "name": <function-name>}}
69
+ </tool_call>
70
+ <|im_end|>"""
71
+ AI_PREAMBLE = """
72
+ <|im_start|>assistant
73
+ """
74
+ HRMS_TEMPLATE_USER = """
75
+ <|im_start|>user
76
+ {user_input}<|im_end|>"""
77
+ HRMS_TEMPLATE_ASSISTANT = """
78
+ <|im_start|>assistant
79
+ {assistant_response}<|im_end|>"""
80
+ HRMS_TEMPLATE_TOOL_RESULT = """
81
+ <|im_start|>tool
82
+ {result}
83
+ <|im_end|>"""
84
+
85
+
86
+ def append_message(prompt, h):
87
+ if h.type == "human":
88
+ prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
89
+ elif h.type == "ai":
90
+ prompt += HRMS_TEMPLATE_ASSISTANT.format(assistant_response=h.content)
91
+ elif h.type == "tool":
92
+ prompt += HRMS_TEMPLATE_TOOL_RESULT.format(result=h.content)
93
+ return prompt
94
+
95
+
96
+ def get_prompt(template, history, tools, schema, car_status=None):
97
+ if not car_status:
98
+ # car_status = vehicle.dict()
99
+ car_status = vehicle_status()[0]
100
+
101
+ # "vehicle_status": vehicle_status_fn()[0]
102
+ kwargs = {"history": history, "schema": schema, "tools": tools, "car_status": car_status}
103
+
104
+
105
+ prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
106
+
107
+ if history:
108
+ for h in history.messages:
109
+ prompt = append_message(prompt, h)
110
+
111
+ # if input:
112
+ # prompt += USER_QUERY_TEMPLATE.format(user_input=input)
113
+ return prompt
114
+
115
+
116
+ def use_tool(tool_call, tools):
117
+ func_name = tool_call["name"]
118
+ kwargs = tool_call["arguments"]
119
+ for tool in tools:
120
+ if tool.name == func_name:
121
+ return tool.invoke(input=kwargs)
122
+ return None
123
+
124
+
125
+ def parse_tool_calls(text):
126
+ logger.debug(f"Start parsing tool_calls: {text}")
127
+ pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
128
+
129
+ if not text.startswith("<tool_call>"):
130
+ return [], []
131
+
132
+ matches = re.findall(pattern, text, re.DOTALL)
133
+ tool_calls = []
134
+ errors = []
135
+ for match in matches:
136
+ try:
137
+ tool_call = json.loads(match)
138
+ tool_calls.append(tool_call)
139
+ except json.JSONDecodeError as e:
140
+ errors.append(f"Invalid JSON in tool call: {e}")
141
+
142
+ logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
143
+ return tool_calls, errors
144
+
145
+
146
+ def process_response(user_query, res, history, tools, depth):
147
+ """Returns True if the response contains tool calls, False otherwise."""
148
+ logger.debug(f"Processing response: {res}")
149
+ tool_calls, errors = parse_tool_calls(res)
150
+ # TODO: Handle errors
151
+ if not tool_calls:
152
+ return False
153
+ # tool_results = ""
154
+ tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
155
+ for tool_call in tool_calls:
156
+ # TODO: Extra Validation
157
+ # Call the function
158
+ try:
159
+ result = use_tool(tool_call, tools)
160
+ if type(result) == tuple:
161
+ result = result[1]
162
+ tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
163
+ except Exception as e:
164
+ print(e)
165
+ # Currently only to mimic OpneAI's behavior
166
+ # But it could be used for tracking function calls
167
+
168
+ tool_results = tool_results.strip()
169
+ print(f"Tool results: {tool_results}")
170
+ tool_call_id = uuid.uuid4().hex
171
+ history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
172
+ return True
173
+
174
+
175
+ def run_inference_step(history, tools, schema_json, dry_run=False):
176
+ # If we decide to call a function, we need to generate the prompt for the model
177
+ # based on the history of the conversation so far.
178
+ # not break the loop
179
+ openai_tools = [convert_to_openai_function(tool) for tool in tools]
180
+ prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
181
+ print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
182
+
183
+ data = {
184
+ "prompt": prompt + AI_PREAMBLE,
185
+ # "streaming": False,
186
+ # "model": "smangrul/llama-3-8b-instruct-function-calling",
187
+ # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
188
+ # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
189
+ "model": "interstellarninja/hermes-2-pro-llama-3-8b",
190
+ "raw": True,
191
+ "options": {"temperature": 0.8,
192
+ # "max_tokens": 1500,
193
+ "num_predict": 1500,
194
+ # "num_predict": 1500,
195
+ # "max_tokens": 1500,
196
+ }
197
+ }
198
+
199
+ if dry_run:
200
+ print(prompt + AI_PREAMBLE)
201
+ return "Didn't really run it."
202
+
203
+ out = ollama.generate(**data)
204
+
205
+ res = out["response"]
206
+
207
+ return res
208
+
209
+
210
+ def process_query(user_query: str, history: ChatMessageHistory, tools):
211
+ history.add_message(HumanMessage(content=user_query))
212
+ for depth in range(10):
213
+ out = run_inference_step(history, tools, schema_json)
214
+ print(f"Inference step result:\n{out}\n------------------\n")
215
+ history.add_message(AIMessage(content=out))
216
+ if not process_response(user_query, out, history, tools, depth):
217
+ print(f"This is the answer, no more iterations: {out}")
218
+ return out
219
+ # Otherwise, tools result is already added to history, we just need to continue the loop.
220
+ # If we get here something went wrong.
221
+ history.add_message(
222
+ AIMessage(content="Sorry, I am not sure how to help you with that.")
223
+ )
224
+ return "Sorry, I am not sure how to help you with that."
kitt/core/utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Optional, Union
2
+
3
+
4
+ def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
5
+ import plotly.express as px
6
+
7
+ lats = []
8
+ lons = []
9
+
10
+ for point in points:
11
+ lats.append(point["latitude"])
12
+ lons.append(point["longitude"])
13
+ # fig = px.line_geo(lat=lats, lon=lons)
14
+ # fig.update_geos(fitbounds="locations")
15
+
16
+ fig = px.line_mapbox(
17
+ lat=lats, lon=lons, zoom=12, height=600, color_discrete_sequence=["red"]
18
+ )
19
+
20
+ if vehicle:
21
+ fig.add_trace(
22
+ px.scatter_mapbox(
23
+ lat=[vehicle[0]],
24
+ lon=[vehicle[1]],
25
+ color_discrete_sequence=["blue"],
26
+ ).data[0]
27
+ )
28
+
29
+ fig.update_layout(
30
+ mapbox_style="open-street-map",
31
+ # mapbox_zoom=12,
32
+ )
33
+ fig.update_geos(fitbounds="locations")
34
+ fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
35
+ return fig
kitt/skills/__init__.py CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
2
  import inspect
3
 
4
  from .common import execute_function_call, extract_func_args, vehicle as vehicle_obj
5
- from .weather import get_weather, get_forecast
6
  from .routing import find_route
7
  from .poi import search_points_of_interests, search_along_route_w_coordinates
8
  from .vehicle import vehicle_status
 
2
  import inspect
3
 
4
  from .common import execute_function_call, extract_func_args, vehicle as vehicle_obj
5
+ from .weather import get_weather_current_location, get_weather, get_forecast
6
  from .routing import find_route
7
  from .poi import search_points_of_interests, search_along_route_w_coordinates
8
  from .vehicle import vehicle_status
kitt/skills/routing.py CHANGED
@@ -17,31 +17,6 @@ def find_coordinates(address):
17
  return lat, lon
18
 
19
 
20
- def plot_route(points):
21
- import plotly.express as px
22
-
23
- lats = []
24
- lons = []
25
-
26
- for point in points:
27
- lats.append(point["latitude"])
28
- lons.append(point["longitude"])
29
- # fig = px.line_geo(lat=lats, lon=lons)
30
- # fig.update_geos(fitbounds="locations")
31
-
32
- fig = px.line_mapbox(
33
- lat=lats, lon=lons, zoom=12, height=600, color_discrete_sequence=["red"]
34
- )
35
-
36
- fig.update_layout(
37
- mapbox_style="open-street-map",
38
- # mapbox_zoom=12,
39
- )
40
- fig.update_geos(fitbounds="locations")
41
- fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
42
- return fig
43
-
44
-
45
  def calculate_route(origin, destination):
46
  """This function is called when the origin or destination is updated in the GUI. It calculates the route between the origin and destination."""
47
  print(f"calculate_route(origin: {origin}, destination: {destination})")
@@ -64,7 +39,7 @@ def calculate_route(origin, destination):
64
  data = response.json()
65
  points = data["routes"][0]["legs"][0]["points"]
66
 
67
- return plot_route(points), vehicle.model_dump_json(), points
68
 
69
 
70
  def find_route_tomtom(
 
17
  return lat, lon
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def calculate_route(origin, destination):
21
  """This function is called when the origin or destination is updated in the GUI. It calculates the route between the origin and destination."""
22
  print(f"calculate_route(origin: {origin}, destination: {destination})")
 
39
  data = response.json()
40
  points = data["routes"][0]["legs"][0]["points"]
41
 
42
+ return vehicle.model_dump_json(), points
43
 
44
 
45
  def find_route_tomtom(
kitt/skills/vehicle.py CHANGED
@@ -2,8 +2,11 @@ from .common import vehicle
2
 
3
 
4
  STATUS_TEMPLATE = """
5
- We are at {location}, coordinates: {lat}, {lon},
6
- current time: {time}, current date: {date} and our destination is: {destination}.
 
 
 
7
  """
8
 
9
 
 
2
 
3
 
4
  STATUS_TEMPLATE = """
5
+ The current location is:{location}
6
+ The current Geo coordinates: {lat}, {lon}
7
+ The current time: {time}
8
+ The current date: {date}
9
+ The current destination is: {destination}
10
  """
11
 
12
 
kitt/skills/weather.py CHANGED
@@ -3,12 +3,32 @@ import requests
3
  from .common import config, vehicle
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  # current weather API
7
  def get_weather(location: str = ""):
8
  """
9
- Returns the CURRENT weather in a specified location.
 
 
10
  Args:
11
- location (string) : Required. The name of the location, could be a city or lat/longitude in the following format latitude,longitude (example: 37.7749,-122.4194). If the location is not specified, the function will return the weather in the current location.
 
 
 
12
  """
13
 
14
  if location == "":
@@ -56,10 +76,10 @@ def get_weather(location: str = ""):
56
  # weather forecast API
57
  def get_forecast(city_name: str = "", when=0, **kwargs):
58
  """
59
- Returns the weather forecast in a specified number of days for a specified city .
60
  Args:
61
- city_name (string) : Required. The name of the city.
62
- when (int) : Required. in number of days (until the day for which we want to know the forecast) (example: tomorrow is 1, in two days is 2, etc.)
63
  """
64
 
65
  when += 1
 
3
  from .common import config, vehicle
4
 
5
 
6
+ def get_weather_current_location():
7
+ """
8
+ Returns the CURRENT weather in current location.
9
+ When responding to user, only mention the weather condition, temperature, and the temperature that it feels like, unless the user asks for more information.
10
+
11
+ Returns:
12
+ dict: The weather data in the specified location.
13
+ """
14
+ print(
15
+ f"get_weather: location is empty, using the vehicle location. ({vehicle.location})"
16
+ )
17
+ location = vehicle.location
18
+ return get_weather(location)
19
+
20
+
21
  # current weather API
22
  def get_weather(location: str = ""):
23
  """
24
+ Get the current weather in a specified location.
25
+ When responding to user, only mention the weather condition, temperature, and the temperature that it feels like, unless the user asks for more information.
26
+
27
  Args:
28
+ location (string) : Optional. The name of the location, if empty, the vehicle location is used.
29
+
30
+ Returns:
31
+ dict: The weather data in the specified location.
32
  """
33
 
34
  if location == "":
 
76
  # weather forecast API
77
  def get_forecast(city_name: str = "", when=0, **kwargs):
78
  """
79
+ Get the weather forecast in a specified number of days for a specified location.
80
  Args:
81
+ city_name (string) : Required. The name of the city.
82
+ when (int) : Required. in number of days (until the day for which we want to know the forecast) (example: tomorrow is 1, in two days is 2, etc.)
83
  """
84
 
85
  when += 1
main.py CHANGED
@@ -11,6 +11,10 @@ from kitt.skills.routing import calculate_route
11
  import ollama
12
 
13
  from langchain.tools.base import StructuredTool
 
 
 
 
14
 
15
  from kitt.skills import (
16
  get_weather,
@@ -21,9 +25,12 @@ from kitt.skills import (
21
  search_along_route_w_coordinates,
22
  do_anything_else,
23
  date_time_info,
 
24
  )
25
  from kitt.skills import extract_func_args
26
  from kitt.core import voice_options, tts_gradio
 
 
27
 
28
 
29
  global_context = {
@@ -33,6 +40,7 @@ global_context = {
33
  }
34
 
35
  speaker_embedding_cache = {}
 
36
 
37
  MODEL_FUNC = "nexusraven"
38
  MODEL_GENERAL = "llama3:instruct"
@@ -111,11 +119,12 @@ def get_vehicle_status(state):
111
  tools = [
112
  StructuredTool.from_function(get_weather),
113
  StructuredTool.from_function(find_route),
114
- # StructuredTool.from_function(vehicle_status),
115
  StructuredTool.from_function(search_points_of_interests),
116
  StructuredTool.from_function(search_along_route),
117
  StructuredTool.from_function(date_time_info),
118
- StructuredTool.from_function(do_anything_else),
 
119
  ]
120
 
121
 
@@ -133,6 +142,9 @@ def run_generic_model(query):
133
  return out["response"]
134
 
135
 
 
 
 
136
 
137
  def run_nexusraven_model(query, voice_character):
138
  global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
@@ -169,36 +181,13 @@ def run_nexusraven_model(query, voice_character):
169
 
170
 
171
  def run_llama3_model(query, voice_character):
172
- global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
173
- print("Prompt: ", global_context["prompt"])
174
- data = {
175
- "prompt": global_context["prompt"],
176
- # "streaming": False,
177
- # "model": "smangrul/llama-3-8b-instruct-function-calling",
178
- "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
179
- "raw": True,
180
- "options": {"temperature": 0.5, "stop": ["\nReflection:", "\nThought:"]},
181
- }
182
- out = ollama.generate(**data)
183
- llm_response = out["response"]
184
- if "Call: " in llm_response:
185
- print(f"llm_response: {llm_response}")
186
- llm_response = llm_response.replace("<bot_end>", " ")
187
- func_name, kwargs = extract_func_args(llm_response)
188
- print(f"Function: {func_name}, Args: {kwargs}")
189
- if func_name == "do_anything_else":
190
- output_text = run_generic_model(query)
191
- else:
192
- output_text = use_tool(func_name, kwargs, tools)
193
- else:
194
- output_text = out["response"]
195
-
196
- if type(output_text) == tuple:
197
- output_text = output_text[0]
198
  gr.Info(f"Output text: {output_text}, generating voice output...")
 
 
199
  return (
200
  output_text,
201
- tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
202
  )
203
 
204
 
@@ -216,22 +205,28 @@ def run_model(query, voice_character, state):
216
 
217
 
218
  def calculate_route_gradio(origin, destination):
219
- plot, vehicle_status, points = calculate_route(origin, destination)
 
220
  global_context["route_points"] = points
221
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
222
- return plot, vehicle_status
223
 
224
 
225
- def update_vehicle_status(trip_progress):
 
 
 
226
  n_points = len(global_context["route_points"])
227
- new_coords = global_context["route_points"][
228
- min(int(trip_progress / 100 * n_points), n_points - 1)
229
- ]
230
  new_coords = new_coords["latitude"], new_coords["longitude"]
231
  print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
232
  vehicle.location_coordinates = new_coords
233
  vehicle.location = ""
234
- return vehicle.model_dump_json()
 
 
235
 
236
 
237
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -278,7 +273,7 @@ def save_and_transcribe_audio(audio):
278
  # What's the closest restaurant from here?
279
 
280
 
281
- def create_demo(tts_server: bool = False, model="llama3"):
282
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
283
  with gr.Blocks(theme=gr.themes.Default()) as demo:
284
  state = gr.State(
@@ -287,6 +282,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
287
  "query": "",
288
  "route_points": [],
289
  "model": model,
 
290
  }
291
  )
292
  trip_points = gr.State(value=[])
@@ -344,6 +340,8 @@ def create_demo(tts_server: bool = False, model="llama3"):
344
  vehicle_status = gr.JSON(
345
  value=vehicle.model_dump_json(), label="Vehicle status"
346
  )
 
 
347
  with gr.Column():
348
  output_audio = gr.Audio(label="output audio", autoplay=True)
349
  output_text = gr.TextArea(
@@ -355,12 +353,12 @@ def create_demo(tts_server: bool = False, model="llama3"):
355
  origin.submit(
356
  fn=calculate_route_gradio,
357
  inputs=[origin, destination],
358
- outputs=[map_plot, vehicle_status],
359
  )
360
  destination.submit(
361
  fn=calculate_route_gradio,
362
  inputs=[origin, destination],
363
- outputs=[map_plot, vehicle_status],
364
  )
365
 
366
  # Update time based on the time picker
@@ -375,13 +373,17 @@ def create_demo(tts_server: bool = False, model="llama3"):
375
 
376
  # Set the vehicle status based on the trip progress
377
  trip_progress.release(
378
- fn=update_vehicle_status, inputs=[trip_progress], outputs=[vehicle_status]
379
  )
380
 
381
  # Save and transcribe the audio
382
  input_audio.stop_recording(
383
  fn=save_and_transcribe_audio, inputs=[input_audio], outputs=[input_text]
384
  )
 
 
 
 
385
  return demo
386
 
387
 
@@ -389,7 +391,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
389
  gr.close_all()
390
 
391
 
392
- demo = create_demo(False, "llama3")
393
  demo.launch(
394
  debug=True,
395
  server_name="0.0.0.0",
 
11
  import ollama
12
 
13
  from langchain.tools.base import StructuredTool
14
+ from langchain.memory import ChatMessageHistory
15
+ from langchain_core.utils.function_calling import convert_to_openai_function
16
+ from loguru import logger
17
+
18
 
19
  from kitt.skills import (
20
  get_weather,
 
25
  search_along_route_w_coordinates,
26
  do_anything_else,
27
  date_time_info,
28
+ get_weather_current_location
29
  )
30
  from kitt.skills import extract_func_args
31
  from kitt.core import voice_options, tts_gradio
32
+ from kitt.core.model import process_query
33
+ from kitt.core import utils as kitt_utils
34
 
35
 
36
  global_context = {
 
40
  }
41
 
42
  speaker_embedding_cache = {}
43
+ history = ChatMessageHistory()
44
 
45
  MODEL_FUNC = "nexusraven"
46
  MODEL_GENERAL = "llama3:instruct"
 
119
  tools = [
120
  StructuredTool.from_function(get_weather),
121
  StructuredTool.from_function(find_route),
122
+ StructuredTool.from_function(vehicle_status_fn),
123
  StructuredTool.from_function(search_points_of_interests),
124
  StructuredTool.from_function(search_along_route),
125
  StructuredTool.from_function(date_time_info),
126
+ StructuredTool.from_function(get_weather_current_location),
127
+ # StructuredTool.from_function(do_anything_else),
128
  ]
129
 
130
 
 
142
  return out["response"]
143
 
144
 
145
+ def clear_history():
146
+ history.clear()
147
+
148
 
149
  def run_nexusraven_model(query, voice_character):
150
  global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
 
181
 
182
 
183
  def run_llama3_model(query, voice_character):
184
+ output_text = process_query(query, history, tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  gr.Info(f"Output text: {output_text}, generating voice output...")
186
+ # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
187
+ voice_out = None
188
  return (
189
  output_text,
190
+ voice_out,
191
  )
192
 
193
 
 
205
 
206
 
207
  def calculate_route_gradio(origin, destination):
208
+ vehicle_status, points = calculate_route(origin, destination)
209
+ plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
210
  global_context["route_points"] = points
211
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
212
+ return plot, vehicle_status, 0
213
 
214
 
215
+ def update_vehicle_status(trip_progress, origin, destination):
216
+ if not global_context["route_points"]:
217
+ vehicle_status, points = calculate_route(origin, destination)
218
+ global_context["route_points"] = points
219
  n_points = len(global_context["route_points"])
220
+ index = min(int(trip_progress / 100 * n_points), n_points - 1)
221
+ print(f"Trip progress: {trip_progress} len: {n_points}, index: {index}")
222
+ new_coords = global_context["route_points"][index]
223
  new_coords = new_coords["latitude"], new_coords["longitude"]
224
  print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
225
  vehicle.location_coordinates = new_coords
226
  vehicle.location = ""
227
+
228
+ plot = kitt_utils.plot_route(global_context["route_points"], vehicle=vehicle.location_coordinates)
229
+ return vehicle.model_dump_json(), plot
230
 
231
 
232
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
273
  # What's the closest restaurant from here?
274
 
275
 
276
+ def create_demo(tts_server: bool = False, model="llama3", tts=True):
277
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
278
  with gr.Blocks(theme=gr.themes.Default()) as demo:
279
  state = gr.State(
 
282
  "query": "",
283
  "route_points": [],
284
  "model": model,
285
+ "tts": tts,
286
  }
287
  )
288
  trip_points = gr.State(value=[])
 
340
  vehicle_status = gr.JSON(
341
  value=vehicle.model_dump_json(), label="Vehicle status"
342
  )
343
+ # Push button
344
+ clear_history_btn = gr.Button(value="Clear History")
345
  with gr.Column():
346
  output_audio = gr.Audio(label="output audio", autoplay=True)
347
  output_text = gr.TextArea(
 
353
  origin.submit(
354
  fn=calculate_route_gradio,
355
  inputs=[origin, destination],
356
+ outputs=[map_plot, vehicle_status, trip_progress],
357
  )
358
  destination.submit(
359
  fn=calculate_route_gradio,
360
  inputs=[origin, destination],
361
+ outputs=[map_plot, vehicle_status, trip_progress],
362
  )
363
 
364
  # Update time based on the time picker
 
373
 
374
  # Set the vehicle status based on the trip progress
375
  trip_progress.release(
376
+ fn=update_vehicle_status, inputs=[trip_progress, origin, destination], outputs=[vehicle_status, map_plot]
377
  )
378
 
379
  # Save and transcribe the audio
380
  input_audio.stop_recording(
381
  fn=save_and_transcribe_audio, inputs=[input_audio], outputs=[input_text]
382
  )
383
+
384
+ # Clear the history
385
+ clear_history_btn.click(fn=clear_history, inputs=[], outputs=[])
386
+
387
  return demo
388
 
389
 
 
391
  gr.close_all()
392
 
393
 
394
+ demo = create_demo(False, "llama3", tts=False)
395
  demo.launch(
396
  debug=True,
397
  server_name="0.0.0.0",