Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,223 +1,219 @@
|
|
1 |
from __future__ import annotations as _annotations
|
2 |
|
|
|
3 |
import json
|
4 |
-
import
|
|
|
|
|
|
|
5 |
from dataclasses import dataclass
|
6 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
from httpx import AsyncClient
|
11 |
-
from pydantic_ai import Agent, ModelRetry, RunContext
|
12 |
-
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
|
17 |
-
@
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
geo_api_key: str | None
|
22 |
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
system_prompt="You are an expert packer. A user will ask you for help packing for a trip given a destination. Use your weather tools to provide a concise and effective packing list. Also ask follow up questions if neccessary.",
|
27 |
-
deps_type=Deps,
|
28 |
-
retries=2,
|
29 |
-
)
|
30 |
|
31 |
|
32 |
-
@
|
33 |
-
async def
|
34 |
-
|
35 |
-
) -> dict[str, float]:
|
36 |
-
"""Get the latitude and longitude of a location.
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
"""
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
return {"lat": data[0]["lat"], "lng": data[0]["lon"]}
|
56 |
-
else:
|
57 |
-
raise ModelRetry("Could not find the location")
|
58 |
-
|
59 |
-
|
60 |
-
@weather_agent.tool
|
61 |
-
async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
|
62 |
-
"""Get the weather at a location.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
ctx: The context.
|
66 |
-
lat: Latitude of the location.
|
67 |
-
lng: Longitude of the location.
|
68 |
-
"""
|
69 |
-
if ctx.deps.weather_api_key is None:
|
70 |
-
# if no API key is provided, return a dummy response
|
71 |
-
return {"temperature": "21 °C", "description": "Sunny"}
|
72 |
-
|
73 |
-
params = {
|
74 |
-
"apikey": ctx.deps.weather_api_key,
|
75 |
-
"location": f"{lat},{lng}",
|
76 |
-
"units": "metric",
|
77 |
-
}
|
78 |
-
r = await ctx.deps.client.get(
|
79 |
-
"https://api.tomorrow.io/v4/weather/realtime", params=params
|
80 |
-
)
|
81 |
-
r.raise_for_status()
|
82 |
-
data = r.json()
|
83 |
-
|
84 |
-
values = data["data"]["values"]
|
85 |
-
# https://docs.tomorrow.io/reference/data-layers-weather-codes
|
86 |
-
code_lookup = {
|
87 |
-
1000: "Clear, Sunny",
|
88 |
-
1100: "Mostly Clear",
|
89 |
-
1101: "Partly Cloudy",
|
90 |
-
1102: "Mostly Cloudy",
|
91 |
-
1001: "Cloudy",
|
92 |
-
2000: "Fog",
|
93 |
-
2100: "Light Fog",
|
94 |
-
4000: "Drizzle",
|
95 |
-
4001: "Rain",
|
96 |
-
4200: "Light Rain",
|
97 |
-
4201: "Heavy Rain",
|
98 |
-
5000: "Snow",
|
99 |
-
5001: "Flurries",
|
100 |
-
5100: "Light Snow",
|
101 |
-
5101: "Heavy Snow",
|
102 |
-
6000: "Freezing Drizzle",
|
103 |
-
6001: "Freezing Rain",
|
104 |
-
6200: "Light Freezing Rain",
|
105 |
-
6201: "Heavy Freezing Rain",
|
106 |
-
7000: "Ice Pellets",
|
107 |
-
7101: "Heavy Ice Pellets",
|
108 |
-
7102: "Light Ice Pellets",
|
109 |
-
8000: "Thunderstorm",
|
110 |
-
}
|
111 |
-
return {
|
112 |
-
"temperature": f'{values["temperatureApparent"]:0.0f}°C',
|
113 |
-
"description": code_lookup.get(values["weatherCode"], "Unknown"),
|
114 |
-
}
|
115 |
-
|
116 |
-
|
117 |
-
TOOL_TO_DISPLAY_NAME = {"get_lat_lng": "Geocoding API", "get_weather": "Weather API"}
|
118 |
-
|
119 |
-
client = AsyncClient()
|
120 |
-
weather_api_key = os.getenv("WEATHER_API_KEY")
|
121 |
-
# create a free API key at https://geocode.maps.co/
|
122 |
-
geo_api_key = os.getenv("GEO_API_KEY")
|
123 |
-
deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
|
124 |
-
|
125 |
-
|
126 |
-
async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
|
127 |
-
chatbot.append({"role": "user", "content": prompt})
|
128 |
-
yield gr.Textbox(interactive=False, value=""), chatbot, gr.skip()
|
129 |
-
async with weather_agent.run_stream(
|
130 |
-
prompt, deps=deps, message_history=past_messages
|
131 |
-
) as result:
|
132 |
-
for message in result.new_messages():
|
133 |
-
past_messages.append(message)
|
134 |
-
if isinstance(message, ModelStructuredResponse):
|
135 |
-
for call in message.calls:
|
136 |
-
gr_message = {
|
137 |
-
"role": "assistant",
|
138 |
-
"content": "",
|
139 |
-
"metadata": {
|
140 |
-
"title": f"### 🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}",
|
141 |
-
"id": call.tool_id,
|
142 |
-
},
|
143 |
-
}
|
144 |
-
chatbot.append(gr_message)
|
145 |
-
if isinstance(message, ToolReturn):
|
146 |
-
for gr_message in chatbot:
|
147 |
-
if gr_message.get("metadata", {}).get("id", "") == message.tool_id:
|
148 |
-
gr_message["content"] = f"Output: {json.dumps(message.content)}"
|
149 |
-
yield gr.skip(), chatbot, gr.skip()
|
150 |
-
chatbot.append({"role": "assistant", "content": ""})
|
151 |
-
async for message in result.stream_text():
|
152 |
-
chatbot[-1]["content"] = message
|
153 |
-
yield gr.skip(), chatbot, gr.skip()
|
154 |
-
data = await result.get_data()
|
155 |
-
past_messages.append(ModelTextResponse(content=data))
|
156 |
-
yield gr.Textbox(interactive=True), gr.skip(), past_messages
|
157 |
-
|
158 |
-
|
159 |
-
async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
|
160 |
-
new_history = chatbot[: retry_data.index]
|
161 |
-
previous_prompt = chatbot[retry_data.index]["content"]
|
162 |
-
past_messages = past_messages[: retry_data.index]
|
163 |
-
async for update in stream_from_agent(previous_prompt, new_history, past_messages):
|
164 |
-
yield update
|
165 |
-
|
166 |
-
|
167 |
-
def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
|
168 |
-
new_history = chatbot[: undo_data.index]
|
169 |
-
past_messages = past_messages[: undo_data.index]
|
170 |
-
return chatbot[undo_data.index]["content"], new_history, past_messages
|
171 |
-
|
172 |
-
|
173 |
-
def select_data(message: gr.SelectData) -> str:
|
174 |
-
return message.value["text"]
|
175 |
-
|
176 |
-
|
177 |
-
with gr.Blocks() as demo:
|
178 |
-
gr.HTML(
|
179 |
-
"""
|
180 |
-
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
|
181 |
-
<img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
|
182 |
-
<div>
|
183 |
-
<h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
|
184 |
-
<h3 style="margin: 0 0 0.5rem 0">
|
185 |
-
This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
|
186 |
-
</h3>
|
187 |
-
<h3 style="margin: 0">
|
188 |
-
Feel free to ask for help with any other questions you have about your trip!
|
189 |
-
</h3>
|
190 |
-
</div>
|
191 |
-
</div>
|
192 |
-
"""
|
193 |
-
)
|
194 |
-
past_messages = gr.State([])
|
195 |
-
chatbot = gr.Chatbot(
|
196 |
-
label="Packing Assistant",
|
197 |
-
type="messages",
|
198 |
-
avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
|
199 |
-
examples=[
|
200 |
-
{"text": "I am going to Paris for the holidays, what should I pack?"},
|
201 |
-
{"text": "I am going to Tokyo this week."},
|
202 |
-
],
|
203 |
)
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
)
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
|
|
|
|
221 |
|
222 |
-
|
|
|
|
|
223 |
demo.launch()
|
|
|
1 |
from __future__ import annotations as _annotations
|
2 |
|
3 |
+
import asyncio
|
4 |
import json
|
5 |
+
import sqlite3
|
6 |
+
from collections.abc import AsyncIterator
|
7 |
+
from concurrent.futures.thread import ThreadPoolExecutor
|
8 |
+
from contextlib import asynccontextmanager
|
9 |
from dataclasses import dataclass
|
10 |
+
from datetime import datetime, timezone
|
11 |
+
from functools import partial
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import Annotated, Any, Callable, Literal, TypeVar
|
14 |
+
|
15 |
+
import fastapi
|
16 |
+
import logfire
|
17 |
+
from fastapi import Depends, Request
|
18 |
+
from fastapi.responses import FileResponse, Response, StreamingResponse
|
19 |
+
from typing_extensions import LiteralString, ParamSpec, TypedDict
|
20 |
+
|
21 |
+
from pydantic_ai import Agent
|
22 |
+
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
23 |
+
from pydantic_ai.messages import (
|
24 |
+
ModelMessage,
|
25 |
+
ModelMessagesTypeAdapter,
|
26 |
+
ModelRequest,
|
27 |
+
ModelResponse,
|
28 |
+
TextPart,
|
29 |
+
UserPromptPart,
|
30 |
+
)
|
31 |
|
32 |
+
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
|
33 |
+
logfire.configure(send_to_logfire='if-token-present')
|
|
|
|
|
|
|
34 |
|
35 |
+
agent = Agent('openai:gpt-4o')
|
36 |
+
THIS_DIR = Path(__file__).parent
|
37 |
|
38 |
|
39 |
+
@asynccontextmanager
|
40 |
+
async def lifespan(_app: fastapi.FastAPI):
|
41 |
+
async with Database.connect() as db:
|
42 |
+
yield {'db': db}
|
|
|
43 |
|
44 |
|
45 |
+
app = fastapi.FastAPI(lifespan=lifespan)
|
46 |
+
logfire.instrument_fastapi(app)
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
+
@app.get('/')
|
50 |
+
async def index() -> FileResponse:
|
51 |
+
return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
|
|
|
|
|
52 |
|
53 |
+
|
54 |
+
@app.get('/chat_app.ts')
|
55 |
+
async def main_ts() -> FileResponse:
|
56 |
+
"""Get the raw typescript code, it's compiled in the browser, forgive me."""
|
57 |
+
return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
|
58 |
+
|
59 |
+
|
60 |
+
async def get_db(request: Request) -> Database:
|
61 |
+
return request.state.db
|
62 |
+
|
63 |
+
|
64 |
+
@app.get('/chat/')
|
65 |
+
async def get_chat(database: Database = Depends(get_db)) -> Response:
|
66 |
+
msgs = await database.get_messages()
|
67 |
+
return Response(
|
68 |
+
b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
|
69 |
+
media_type='text/plain',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
)
|
71 |
+
|
72 |
+
|
73 |
+
class ChatMessage(TypedDict):
|
74 |
+
"""Format of messages sent to the browser."""
|
75 |
+
|
76 |
+
role: Literal['user', 'model']
|
77 |
+
timestamp: str
|
78 |
+
content: str
|
79 |
+
|
80 |
+
|
81 |
+
def to_chat_message(m: ModelMessage) -> ChatMessage:
|
82 |
+
first_part = m.parts[0]
|
83 |
+
if isinstance(m, ModelRequest):
|
84 |
+
if isinstance(first_part, UserPromptPart):
|
85 |
+
return {
|
86 |
+
'role': 'user',
|
87 |
+
'timestamp': first_part.timestamp.isoformat(),
|
88 |
+
'content': first_part.content,
|
89 |
+
}
|
90 |
+
elif isinstance(m, ModelResponse):
|
91 |
+
if isinstance(first_part, TextPart):
|
92 |
+
return {
|
93 |
+
'role': 'model',
|
94 |
+
'timestamp': m.timestamp.isoformat(),
|
95 |
+
'content': first_part.content,
|
96 |
+
}
|
97 |
+
raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
|
98 |
+
|
99 |
+
|
100 |
+
@app.post('/chat/')
|
101 |
+
async def post_chat(
|
102 |
+
prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
|
103 |
+
) -> StreamingResponse:
|
104 |
+
async def stream_messages():
|
105 |
+
"""Streams new line delimited JSON `Message`s to the client."""
|
106 |
+
# stream the user prompt so that can be displayed straight away
|
107 |
+
yield (
|
108 |
+
json.dumps(
|
109 |
+
{
|
110 |
+
'role': 'user',
|
111 |
+
'timestamp': datetime.now(tz=timezone.utc).isoformat(),
|
112 |
+
'content': prompt,
|
113 |
+
}
|
114 |
+
).encode('utf-8')
|
115 |
+
+ b'\n'
|
116 |
)
|
117 |
+
# get the chat history so far to pass as context to the agent
|
118 |
+
messages = await database.get_messages()
|
119 |
+
# run the agent with the user prompt and the chat history
|
120 |
+
async with agent.run_stream(prompt, message_history=messages) as result:
|
121 |
+
async for text in result.stream(debounce_by=0.01):
|
122 |
+
# text here is a `str` and the frontend wants
|
123 |
+
# JSON encoded ModelResponse, so we create one
|
124 |
+
m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
|
125 |
+
yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
|
126 |
+
|
127 |
+
# add new messages (e.g. the user prompt and the agent response in this case) to the database
|
128 |
+
await database.add_messages(result.new_messages_json())
|
129 |
+
|
130 |
+
return StreamingResponse(stream_messages(), media_type='text/plain')
|
131 |
+
|
132 |
+
|
133 |
+
P = ParamSpec('P')
|
134 |
+
R = TypeVar('R')
|
135 |
+
|
136 |
+
|
137 |
+
@dataclass
|
138 |
+
class Database:
|
139 |
+
"""Rudimentary database to store chat messages in SQLite.
|
140 |
+
|
141 |
+
The SQLite standard library package is synchronous, so we
|
142 |
+
use a thread pool executor to run queries asynchronously.
|
143 |
+
"""
|
144 |
+
|
145 |
+
con: sqlite3.Connection
|
146 |
+
_loop: asyncio.AbstractEventLoop
|
147 |
+
_executor: ThreadPoolExecutor
|
148 |
+
|
149 |
+
@classmethod
|
150 |
+
@asynccontextmanager
|
151 |
+
async def connect(
|
152 |
+
cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
|
153 |
+
) -> AsyncIterator[Database]:
|
154 |
+
with logfire.span('connect to DB'):
|
155 |
+
loop = asyncio.get_event_loop()
|
156 |
+
executor = ThreadPoolExecutor(max_workers=1)
|
157 |
+
con = await loop.run_in_executor(executor, cls._connect, file)
|
158 |
+
slf = cls(con, loop, executor)
|
159 |
+
try:
|
160 |
+
yield slf
|
161 |
+
finally:
|
162 |
+
await slf._asyncify(con.close)
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def _connect(file: Path) -> sqlite3.Connection:
|
166 |
+
con = sqlite3.connect(str(file))
|
167 |
+
con = logfire.instrument_sqlite3(con)
|
168 |
+
cur = con.cursor()
|
169 |
+
cur.execute(
|
170 |
+
'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
|
171 |
+
)
|
172 |
+
con.commit()
|
173 |
+
return con
|
174 |
+
|
175 |
+
async def add_messages(self, messages: bytes):
|
176 |
+
await self._asyncify(
|
177 |
+
self._execute,
|
178 |
+
'INSERT INTO messages (message_list) VALUES (?);',
|
179 |
+
messages,
|
180 |
+
commit=True,
|
181 |
+
)
|
182 |
+
await self._asyncify(self.con.commit)
|
183 |
+
|
184 |
+
async def get_messages(self) -> list[ModelMessage]:
|
185 |
+
c = await self._asyncify(
|
186 |
+
self._execute, 'SELECT message_list FROM messages order by id'
|
187 |
+
)
|
188 |
+
rows = await self._asyncify(c.fetchall)
|
189 |
+
messages: list[ModelMessage] = []
|
190 |
+
for row in rows:
|
191 |
+
messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
|
192 |
+
return messages
|
193 |
+
|
194 |
+
def _execute(
|
195 |
+
self, sql: LiteralString, *args: Any, commit: bool = False
|
196 |
+
) -> sqlite3.Cursor:
|
197 |
+
cur = self.con.cursor()
|
198 |
+
cur.execute(sql, args)
|
199 |
+
if commit:
|
200 |
+
self.con.commit()
|
201 |
+
return cur
|
202 |
+
|
203 |
+
async def _asyncify(
|
204 |
+
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
205 |
+
) -> R:
|
206 |
+
return await self._loop.run_in_executor( # type: ignore
|
207 |
+
self._executor,
|
208 |
+
partial(func, **kwargs),
|
209 |
+
*args, # type: ignore
|
210 |
+
)
|
211 |
+
|
212 |
|
213 |
+
if __name__ == '__main__':
|
214 |
+
#import uvicorn
|
215 |
|
216 |
+
#uvicorn.run(
|
217 |
+
# 'pydantic_ai_examples.chat_app:app', reload=True, reload_dirs=[str(THIS_DIR)]
|
218 |
+
#)
|
219 |
demo.launch()
|