chore: A new, more advanced method
- kitt/core/model.py +224 -0
- kitt/core/utils.py +35 -0
- kitt/skills/__init__.py +1 -1
- kitt/skills/routing.py +1 -26
- kitt/skills/vehicle.py +5 -2
- kitt/skills/weather.py +25 -5
- main.py +43 -41
kitt/core/model.py
ADDED
@@ -0,0 +1,224 @@
+import json
+import re
+import uuid
+
+from langchain.memory import ChatMessageHistory
+from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
+from langchain_core.utils.function_calling import convert_to_openai_function
+import ollama
+from pydantic import BaseModel
+from loguru import logger
+
+from kitt.skills import vehicle_status
+
+
+class FunctionCall(BaseModel):
+    arguments: dict
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+
+schema_json = json.loads(FunctionCall.schema_json())
+HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
+<|im_start|>system
+You are a function calling AI agent with self-recursion.
+You can call only one function at a time and analyse the data you get from the function response.
+You are provided with function signatures within <tools></tools> XML tags.
+{car_status}
+
+You may use agentic frameworks for reasoning and planning to help with the user query.
+Please call a function and wait for function results to be provided to you in the next iteration.
+Don't make assumptions about what values to plug into function arguments.
+Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
+Don't make assumptions about tool results if <tool_response> XML tags are not present, since the function hasn't been executed yet.
+Analyze the data once you get the results and call another function.
+At each iteration please continue adding your analysis to the previous summary.
+Your final response should directly answer the user query.
+
+
+Here are the available tools:
+<tools> {tools} </tools>
+If the provided function signatures don't include the function you must call, you may write executable python code in markdown syntax and call the code_interpreter() function as follows:
+<tool_call>
+{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
+</tool_call>
+Make sure that the json object above with the code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
+When using tools, ensure to only use the tools provided and not make up any data, and do not provide any explanation as to which tool you are using and why.
+
+When asked for the weather, look up the weather for the current location of the car, unless the user provides a location; in that case use that location.
+If asked about points of interest, use the tools available to you. Do not make up points of interest.
+
+Use the following pydantic model json schema for each tool call you will make:
+{schema}
+
+At the very first turn you don't have <tool_results>, so you shouldn't make up the results.
+Please keep a running summary with analysis of previous function results and summaries from previous iterations.
+Do not stop calling functions until the task has been accomplished or you've reached the max iteration count of 10.
+If you plan to continue with analysis, always call another function.
+For each function call return a valid json object (using double quotes) with the function name and arguments within <tool_call></tool_call> XML tags as follows:
+<tool_call>
+{{"arguments": <args-dict>, "name": <function-name>}}
+</tool_call>
+<|im_end|>"""
+AI_PREAMBLE = """
+<|im_start|>assistant
+"""
+HRMS_TEMPLATE_USER = """
+<|im_start|>user
+{user_input}<|im_end|>"""
+HRMS_TEMPLATE_ASSISTANT = """
+<|im_start|>assistant
+{assistant_response}<|im_end|>"""
+HRMS_TEMPLATE_TOOL_RESULT = """
+<|im_start|>tool
+{result}
+<|im_end|>"""
+
+
+def append_message(prompt, h):
+    if h.type == "human":
+        prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
+    elif h.type == "ai":
+        prompt += HRMS_TEMPLATE_ASSISTANT.format(assistant_response=h.content)
+    elif h.type == "tool":
+        prompt += HRMS_TEMPLATE_TOOL_RESULT.format(result=h.content)
+    return prompt
+
+
+def get_prompt(template, history, tools, schema, car_status=None):
+    if not car_status:
+        # car_status = vehicle.dict()
+        car_status = vehicle_status()[0]
+
+    # "vehicle_status": vehicle_status_fn()[0]
+    kwargs = {"history": history, "schema": schema, "tools": tools, "car_status": car_status}
+
+    prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
+
+    if history:
+        for h in history.messages:
+            prompt = append_message(prompt, h)
+
+    # if input:
+    #     prompt += USER_QUERY_TEMPLATE.format(user_input=input)
+    return prompt
+
+
+def use_tool(tool_call, tools):
+    func_name = tool_call["name"]
+    kwargs = tool_call["arguments"]
+    for tool in tools:
+        if tool.name == func_name:
+            return tool.invoke(input=kwargs)
+    return None
+
+
+def parse_tool_calls(text):
+    logger.debug(f"Start parsing tool_calls: {text}")
+    pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
+
+    if not text.startswith("<tool_call>"):
+        return [], []
+
+    matches = re.findall(pattern, text, re.DOTALL)
+    tool_calls = []
+    errors = []
+    for match in matches:
+        try:
+            tool_call = json.loads(match)
+            tool_calls.append(tool_call)
+        except json.JSONDecodeError as e:
+            errors.append(f"Invalid JSON in tool call: {e}")
+
+    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
+    return tool_calls, errors
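
A quick illustration of what parse_tool_calls consumes and returns, using the <tool_call> wrapper defined in HRMS_SYSTEM_PROMPT above (the values are illustrative, not part of the commit):

    text = '<tool_call>\n{"arguments": {"location": "Paris"}, "name": "get_weather"}\n</tool_call>'
    tool_calls, errors = parse_tool_calls(text)
    # tool_calls -> [{"arguments": {"location": "Paris"}, "name": "get_weather"}]
    # errors     -> []
    # Responses that do not *start* with <tool_call> (e.g. a plain-language
    # answer) yield ([], []), which is what ends the agent loop below.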
+
+
+def process_response(user_query, res, history, tools, depth):
+    """Returns True if the response contains tool calls, False otherwise."""
+    logger.debug(f"Processing response: {res}")
+    tool_calls, errors = parse_tool_calls(res)
+    # TODO: Handle errors
+    if not tool_calls:
+        return False
+    # tool_results = ""
+    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
+    for tool_call in tool_calls:
+        # TODO: Extra validation
+        # Call the function
+        try:
+            result = use_tool(tool_call, tools)
+            if type(result) == tuple:
+                result = result[1]
+            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
+        except Exception as e:
+            print(e)
+    # Currently only to mimic OpenAI's behavior,
+    # but it could be used for tracking function calls.
+
+    tool_results = tool_results.strip()
+    print(f"Tool results: {tool_results}")
+    tool_call_id = uuid.uuid4().hex
+    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
+    return True
+
+
+def run_inference_step(history, tools, schema_json, dry_run=False):
+    # If we decide to call a function, we need to generate the prompt for the model
+    # based on the history of the conversation so far.
+    openai_tools = [convert_to_openai_function(tool) for tool in tools]
+    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
+    print(f"Prompt is: {prompt + AI_PREAMBLE}\n------------------\n")
+
+    data = {
+        "prompt": prompt + AI_PREAMBLE,
+        # "streaming": False,
+        # "model": "smangrul/llama-3-8b-instruct-function-calling",
+        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
+        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
+        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        "raw": True,
+        "options": {
+            "temperature": 0.8,
+            # "max_tokens": 1500,
+            "num_predict": 1500,
+        },
+    }
+
+    if dry_run:
+        print(prompt + AI_PREAMBLE)
+        return "Didn't really run it."
+
+    out = ollama.generate(**data)
+    res = out["response"]
+
+    return res
+
+
+def process_query(user_query: str, history: ChatMessageHistory, tools):
+    history.add_message(HumanMessage(content=user_query))
+    for depth in range(10):
+        out = run_inference_step(history, tools, schema_json)
+        print(f"Inference step result:\n{out}\n------------------\n")
+        history.add_message(AIMessage(content=out))
+        if not process_response(user_query, out, history, tools, depth):
+            print(f"This is the answer, no more iterations: {out}")
+            return out
+        # Otherwise, the tool result is already added to history; just continue the loop.
+
+    # If we get here, something went wrong.
+    history.add_message(
+        AIMessage(content="Sorry, I am not sure how to help you with that.")
+    )
+    return "Sorry, I am not sure how to help you with that."
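
A minimal driver sketch for process_query, assuming the StructuredTool setup used in main.py below (the query string is made up):

    from langchain.memory import ChatMessageHistory
    from langchain.tools.base import StructuredTool

    from kitt.core.model import process_query
    from kitt.skills import get_weather, find_route

    tools = [
        StructuredTool.from_function(get_weather),
        StructuredTool.from_function(find_route),
    ]
    history = ChatMessageHistory()
    # Runs up to 10 generate/tool-call iterations and returns the final answer string.
    answer = process_query("What is the weather like right now?", history, tools)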
kitt/core/utils.py
ADDED
@@ -0,0 +1,35 @@
+from typing import Optional, Tuple
+
+
+def plot_route(points, vehicle: Optional[Tuple[float, float]] = None):
+    import plotly.express as px
+
+    lats = []
+    lons = []
+
+    for point in points:
+        lats.append(point["latitude"])
+        lons.append(point["longitude"])
+    # fig = px.line_geo(lat=lats, lon=lons)
+    # fig.update_geos(fitbounds="locations")
+
+    fig = px.line_mapbox(
+        lat=lats, lon=lons, zoom=12, height=600, color_discrete_sequence=["red"]
+    )
+
+    if vehicle:
+        # Mark the vehicle's current position on the route.
+        fig.add_trace(
+            px.scatter_mapbox(
+                lat=[vehicle[0]],
+                lon=[vehicle[1]],
+                color_discrete_sequence=["blue"],
+            ).data[0]
+        )
+
+    fig.update_layout(
+        mapbox_style="open-street-map",
+        # mapbox_zoom=12,
+    )
+    fig.update_geos(fitbounds="locations")
+    fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
+    return fig
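
A usage sketch for plot_route; the point dicts follow the TomTom leg format consumed elsewhere in the repo, and the coordinates here are made up:

    points = [
        {"latitude": 49.6116, "longitude": 6.1319},
        {"latitude": 49.5999, "longitude": 6.1340},
    ]
    # Red route line plus a blue marker at the vehicle position.
    fig = plot_route(points, vehicle=(49.6116, 6.1319))
    fig.show()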
kitt/skills/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
 import inspect
 
 from .common import execute_function_call, extract_func_args, vehicle as vehicle_obj
-from .weather import get_weather, get_forecast
+from .weather import get_weather_current_location, get_weather, get_forecast
 from .routing import find_route
 from .poi import search_points_of_interests, search_along_route_w_coordinates
 from .vehicle import vehicle_status
kitt/skills/routing.py
CHANGED
@@ -17,31 +17,6 @@ def find_coordinates(address):
     return lat, lon
 
 
-def plot_route(points):
-    import plotly.express as px
-
-    lats = []
-    lons = []
-
-    for point in points:
-        lats.append(point["latitude"])
-        lons.append(point["longitude"])
-    # fig = px.line_geo(lat=lats, lon=lons)
-    # fig.update_geos(fitbounds="locations")
-
-    fig = px.line_mapbox(
-        lat=lats, lon=lons, zoom=12, height=600, color_discrete_sequence=["red"]
-    )
-
-    fig.update_layout(
-        mapbox_style="open-street-map",
-        # mapbox_zoom=12,
-    )
-    fig.update_geos(fitbounds="locations")
-    fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
-    return fig
-
-
 def calculate_route(origin, destination):
     """This function is called when the origin or destination is updated in the GUI. It calculates the route between the origin and destination."""
     print(f"calculate_route(origin: {origin}, destination: {destination})")
@@ -64,7 +39,7 @@ def calculate_route(origin, destination):
     data = response.json()
     points = data["routes"][0]["legs"][0]["points"]
 
-    return
+    return vehicle.model_dump_json(), points
 
 
 def find_route_tomtom(
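
With this change calculate_route returns a (vehicle_json, points) pair instead of None, so callers can unpack both; a sketch mirroring the call site in main.py below (the addresses are illustrative):

    vehicle_json, points = calculate_route("Luxembourg", "Paris")
    start = points[0]  # e.g. {"latitude": ..., "longitude": ...}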
kitt/skills/vehicle.py
CHANGED
@@ -2,8 +2,11 @@ from .common import vehicle
 
 
 STATUS_TEMPLATE = """
-
-current
+The current location is: {location}
+The current Geo coordinates: {lat}, {lon}
+The current time: {time}
+The current date: {date}
+The current destination is: {destination}
 """
 
 
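
Filling the template is a plain str.format call; a sketch with made-up values for the fields above:

    print(STATUS_TEMPLATE.format(
        location="Luxembourg",
        lat=49.6116,
        lon=6.1319,
        time="08:30",
        date="2024-05-21",
        destination="Kirchberg",
    ))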
kitt/skills/weather.py
CHANGED
@@ -3,12 +3,32 @@ import requests
 from .common import config, vehicle
 
 
+def get_weather_current_location():
+    """
+    Returns the CURRENT weather in the current location.
+    When responding to the user, only mention the weather condition, the temperature, and the temperature it feels like, unless the user asks for more information.
+
+    Returns:
+        dict: The weather data for the current location.
+    """
+    print(
+        f"get_weather: location is empty, using the vehicle location. ({vehicle.location})"
+    )
+    location = vehicle.location
+    return get_weather(location)
+
+
 # current weather API
 def get_weather(location: str = ""):
     """
-
+    Get the current weather in a specified location.
+    When responding to the user, only mention the weather condition, the temperature, and the temperature it feels like, unless the user asks for more information.
+
     Args:
-
+        location (string): Optional. The name of the location; if empty, the vehicle location is used.
+
+    Returns:
+        dict: The weather data for the specified location.
     """
 
     if location == "":
@@ -56,10 +76,10 @@ def get_weather(location: str = ""):
 # weather forecast API
 def get_forecast(city_name: str = "", when=0, **kwargs):
     """
-
+    Get the weather forecast a given number of days ahead for a specified city.
+
     Args:
-
-
+        city_name (string): Required. The name of the city.
+        when (int): Required. The number of days until the day we want the forecast for (tomorrow is 1, in two days is 2, etc.).
     """
 
     when += 1
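
A sketch of the two entry points, assuming the weather API key is configured in kitt.skills.common.config (the city and day values are illustrative):

    # Current conditions at the vehicle's location.
    weather = get_weather_current_location()

    # Forecast for Paris in two days; get_forecast shifts the day
    # index by one internally (when += 1) before querying the API.
    forecast = get_forecast("Paris", when=2)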
main.py
CHANGED
@@ -11,6 +11,10 @@ from kitt.skills.routing import calculate_route
 import ollama
 
 from langchain.tools.base import StructuredTool
+from langchain.memory import ChatMessageHistory
+from langchain_core.utils.function_calling import convert_to_openai_function
+from loguru import logger
+
 
 from kitt.skills import (
     get_weather,
@@ -21,9 +25,12 @@ from kitt.skills import (
     search_along_route_w_coordinates,
     do_anything_else,
     date_time_info,
+    get_weather_current_location,
 )
 from kitt.skills import extract_func_args
 from kitt.core import voice_options, tts_gradio
+from kitt.core.model import process_query
+from kitt.core import utils as kitt_utils
 
 
 global_context = {
@@ -33,6 +40,7 @@ global_context = {
 }
 
 speaker_embedding_cache = {}
+history = ChatMessageHistory()
 
 MODEL_FUNC = "nexusraven"
 MODEL_GENERAL = "llama3:instruct"
@@ -111,11 +119,12 @@ def get_vehicle_status(state):
 tools = [
     StructuredTool.from_function(get_weather),
     StructuredTool.from_function(find_route),
-
+    StructuredTool.from_function(vehicle_status_fn),
     StructuredTool.from_function(search_points_of_interests),
     StructuredTool.from_function(search_along_route),
     StructuredTool.from_function(date_time_info),
-    StructuredTool.from_function(do_anything_else),
+    StructuredTool.from_function(get_weather_current_location),
+    # StructuredTool.from_function(do_anything_else),
 ]
 
 
@@ -133,6 +142,9 @@ def run_generic_model(query):
     return out["response"]
 
 
+def clear_history():
+    history.clear()
+
 
 def run_nexusraven_model(query, voice_character):
     global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
@@ -169,36 +181,13 @@ def run_nexusraven_model(query, voice_character):
 
 
 def run_llama3_model(query, voice_character):
-
-    print("Prompt: ", global_context["prompt"])
-    data = {
-        "prompt": global_context["prompt"],
-        # "streaming": False,
-        # "model": "smangrul/llama-3-8b-instruct-function-calling",
-        "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
-        "raw": True,
-        "options": {"temperature": 0.5, "stop": ["\nReflection:", "\nThought:"]},
-    }
-    out = ollama.generate(**data)
-    llm_response = out["response"]
-    if "Call: " in llm_response:
-        print(f"llm_response: {llm_response}")
-        llm_response = llm_response.replace("<bot_end>", " ")
-        func_name, kwargs = extract_func_args(llm_response)
-        print(f"Function: {func_name}, Args: {kwargs}")
-        if func_name == "do_anything_else":
-            output_text = run_generic_model(query)
-        else:
-            output_text = use_tool(func_name, kwargs, tools)
-    else:
-        output_text = out["response"]
-
-    if type(output_text) == tuple:
-        output_text = output_text[0]
+    output_text = process_query(query, history, tools)
     gr.Info(f"Output text: {output_text}, generating voice output...")
+    # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
+    voice_out = None
     return (
         output_text,
-
+        voice_out,
     )
 
 
@@ -216,22 +205,28 @@ def run_model(query, voice_character, state):
 
 
 def calculate_route_gradio(origin, destination):
-
+    vehicle_status, points = calculate_route(origin, destination)
+    plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
     global_context["route_points"] = points
     vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
-    return plot, vehicle_status
+    return plot, vehicle_status, 0
 
 
-def update_vehicle_status(trip_progress):
+def update_vehicle_status(trip_progress, origin, destination):
+    if not global_context["route_points"]:
+        vehicle_status, points = calculate_route(origin, destination)
+        global_context["route_points"] = points
     n_points = len(global_context["route_points"])
-
-
-    ]
+    index = min(int(trip_progress / 100 * n_points), n_points - 1)
+    print(f"Trip progress: {trip_progress} len: {n_points}, index: {index}")
+    new_coords = global_context["route_points"][index]
     new_coords = new_coords["latitude"], new_coords["longitude"]
     print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
    vehicle.location_coordinates = new_coords
     vehicle.location = ""
-
+
+    plot = kitt_utils.plot_route(global_context["route_points"], vehicle=vehicle.location_coordinates)
+    return vehicle.model_dump_json(), plot
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -278,7 +273,7 @@ def save_and_transcribe_audio(audio):
 # What's the closest restaurant from here?
 
 
-def create_demo(tts_server: bool = False, model="llama3"):
+def create_demo(tts_server: bool = False, model="llama3", tts=True):
     print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
     with gr.Blocks(theme=gr.themes.Default()) as demo:
         state = gr.State(
@@ -287,6 +282,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
             "query": "",
             "route_points": [],
             "model": model,
+            "tts": tts,
         }
     )
     trip_points = gr.State(value=[])
@@ -344,6 +340,8 @@ def create_demo(tts_server: bool = False, model="llama3"):
                 vehicle_status = gr.JSON(
                     value=vehicle.model_dump_json(), label="Vehicle status"
                 )
+                # Push button
+                clear_history_btn = gr.Button(value="Clear History")
             with gr.Column():
                 output_audio = gr.Audio(label="output audio", autoplay=True)
                 output_text = gr.TextArea(
@@ -355,12 +353,12 @@ def create_demo(tts_server: bool = False, model="llama3"):
         origin.submit(
             fn=calculate_route_gradio,
             inputs=[origin, destination],
-            outputs=[map_plot, vehicle_status],
+            outputs=[map_plot, vehicle_status, trip_progress],
         )
         destination.submit(
             fn=calculate_route_gradio,
             inputs=[origin, destination],
-            outputs=[map_plot, vehicle_status],
+            outputs=[map_plot, vehicle_status, trip_progress],
        )
 
         # Update time based on the time picker
@@ -375,13 +373,17 @@ def create_demo(tts_server: bool = False, model="llama3"):
 
         # Set the vehicle status based on the trip progress
         trip_progress.release(
-            fn=update_vehicle_status, inputs=[trip_progress], outputs=[vehicle_status]
+            fn=update_vehicle_status, inputs=[trip_progress, origin, destination], outputs=[vehicle_status, map_plot]
         )
 
         # Save and transcribe the audio
         input_audio.stop_recording(
             fn=save_and_transcribe_audio, inputs=[input_audio], outputs=[input_text]
         )
+
+        # Clear the history
+        clear_history_btn.click(fn=clear_history, inputs=[], outputs=[])
+
     return demo
 
 
@@ -389,7 +391,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
 gr.close_all()
 
 
-demo = create_demo(False, "llama3")
+demo = create_demo(False, "llama3", tts=False)
 demo.launch(
     debug=True,
     server_name="0.0.0.0",
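
For reference, the slider-to-route mapping in update_vehicle_status is a clamped linear index; a worked example assuming a 200-point route:

    n_points = 200
    trip_progress = 33  # slider value in percent
    index = min(int(trip_progress / 100 * n_points), n_points - 1)
    # index == 66, so the vehicle is placed at route point 66 of 200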