darsoarafa commited on
Commit
b92eeaa
·
verified ·
1 Parent(s): 9b35006

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -205
app.py CHANGED
@@ -1,223 +1,219 @@
1
  from __future__ import annotations as _annotations
2
 
 
3
  import json
4
- import os
 
 
 
5
  from dataclasses import dataclass
6
- from typing import Any
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- import gradio as gr
9
- from dotenv import load_dotenv
10
- from httpx import AsyncClient
11
- from pydantic_ai import Agent, ModelRetry, RunContext
12
- from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
13
 
14
- load_dotenv()
 
15
 
16
 
17
- @dataclass
18
- class Deps:
19
- client: AsyncClient
20
- weather_api_key: str | None
21
- geo_api_key: str | None
22
 
23
 
24
- weather_agent = Agent(
25
- "openai:gpt-4o",
26
- system_prompt="You are an expert packer. A user will ask you for help packing for a trip given a destination. Use your weather tools to provide a concise and effective packing list. Also ask follow up questions if neccessary.",
27
- deps_type=Deps,
28
- retries=2,
29
- )
30
 
31
 
32
- @weather_agent.tool
33
- async def get_lat_lng(
34
- ctx: RunContext[Deps], location_description: str
35
- ) -> dict[str, float]:
36
- """Get the latitude and longitude of a location.
37
 
38
- Args:
39
- ctx: The context.
40
- location_description: A description of a location.
41
- """
42
- if ctx.deps.geo_api_key is None:
43
- # if no API key is provided, return a dummy response (London)
44
- return {"lat": 51.1, "lng": -0.1}
45
-
46
- params = {
47
- "q": location_description,
48
- "api_key": ctx.deps.geo_api_key,
49
- }
50
- r = await ctx.deps.client.get("https://geocode.maps.co/search", params=params)
51
- r.raise_for_status()
52
- data = r.json()
53
-
54
- if data:
55
- return {"lat": data[0]["lat"], "lng": data[0]["lon"]}
56
- else:
57
- raise ModelRetry("Could not find the location")
58
-
59
-
60
- @weather_agent.tool
61
- async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
62
- """Get the weather at a location.
63
-
64
- Args:
65
- ctx: The context.
66
- lat: Latitude of the location.
67
- lng: Longitude of the location.
68
- """
69
- if ctx.deps.weather_api_key is None:
70
- # if no API key is provided, return a dummy response
71
- return {"temperature": "21 °C", "description": "Sunny"}
72
-
73
- params = {
74
- "apikey": ctx.deps.weather_api_key,
75
- "location": f"{lat},{lng}",
76
- "units": "metric",
77
- }
78
- r = await ctx.deps.client.get(
79
- "https://api.tomorrow.io/v4/weather/realtime", params=params
80
- )
81
- r.raise_for_status()
82
- data = r.json()
83
-
84
- values = data["data"]["values"]
85
- # https://docs.tomorrow.io/reference/data-layers-weather-codes
86
- code_lookup = {
87
- 1000: "Clear, Sunny",
88
- 1100: "Mostly Clear",
89
- 1101: "Partly Cloudy",
90
- 1102: "Mostly Cloudy",
91
- 1001: "Cloudy",
92
- 2000: "Fog",
93
- 2100: "Light Fog",
94
- 4000: "Drizzle",
95
- 4001: "Rain",
96
- 4200: "Light Rain",
97
- 4201: "Heavy Rain",
98
- 5000: "Snow",
99
- 5001: "Flurries",
100
- 5100: "Light Snow",
101
- 5101: "Heavy Snow",
102
- 6000: "Freezing Drizzle",
103
- 6001: "Freezing Rain",
104
- 6200: "Light Freezing Rain",
105
- 6201: "Heavy Freezing Rain",
106
- 7000: "Ice Pellets",
107
- 7101: "Heavy Ice Pellets",
108
- 7102: "Light Ice Pellets",
109
- 8000: "Thunderstorm",
110
- }
111
- return {
112
- "temperature": f'{values["temperatureApparent"]:0.0f}°C',
113
- "description": code_lookup.get(values["weatherCode"], "Unknown"),
114
- }
115
-
116
-
117
- TOOL_TO_DISPLAY_NAME = {"get_lat_lng": "Geocoding API", "get_weather": "Weather API"}
118
-
119
- client = AsyncClient()
120
- weather_api_key = os.getenv("WEATHER_API_KEY")
121
- # create a free API key at https://geocode.maps.co/
122
- geo_api_key = os.getenv("GEO_API_KEY")
123
- deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
124
-
125
-
126
- async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
127
- chatbot.append({"role": "user", "content": prompt})
128
- yield gr.Textbox(interactive=False, value=""), chatbot, gr.skip()
129
- async with weather_agent.run_stream(
130
- prompt, deps=deps, message_history=past_messages
131
- ) as result:
132
- for message in result.new_messages():
133
- past_messages.append(message)
134
- if isinstance(message, ModelStructuredResponse):
135
- for call in message.calls:
136
- gr_message = {
137
- "role": "assistant",
138
- "content": "",
139
- "metadata": {
140
- "title": f"### 🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}",
141
- "id": call.tool_id,
142
- },
143
- }
144
- chatbot.append(gr_message)
145
- if isinstance(message, ToolReturn):
146
- for gr_message in chatbot:
147
- if gr_message.get("metadata", {}).get("id", "") == message.tool_id:
148
- gr_message["content"] = f"Output: {json.dumps(message.content)}"
149
- yield gr.skip(), chatbot, gr.skip()
150
- chatbot.append({"role": "assistant", "content": ""})
151
- async for message in result.stream_text():
152
- chatbot[-1]["content"] = message
153
- yield gr.skip(), chatbot, gr.skip()
154
- data = await result.get_data()
155
- past_messages.append(ModelTextResponse(content=data))
156
- yield gr.Textbox(interactive=True), gr.skip(), past_messages
157
-
158
-
159
- async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
160
- new_history = chatbot[: retry_data.index]
161
- previous_prompt = chatbot[retry_data.index]["content"]
162
- past_messages = past_messages[: retry_data.index]
163
- async for update in stream_from_agent(previous_prompt, new_history, past_messages):
164
- yield update
165
-
166
-
167
- def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
168
- new_history = chatbot[: undo_data.index]
169
- past_messages = past_messages[: undo_data.index]
170
- return chatbot[undo_data.index]["content"], new_history, past_messages
171
-
172
-
173
- def select_data(message: gr.SelectData) -> str:
174
- return message.value["text"]
175
-
176
-
177
- with gr.Blocks() as demo:
178
- gr.HTML(
179
- """
180
- <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
181
- <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
182
- <div>
183
- <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
184
- <h3 style="margin: 0 0 0.5rem 0">
185
- This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
186
- </h3>
187
- <h3 style="margin: 0">
188
- Feel free to ask for help with any other questions you have about your trip!
189
- </h3>
190
- </div>
191
- </div>
192
- """
193
- )
194
- past_messages = gr.State([])
195
- chatbot = gr.Chatbot(
196
- label="Packing Assistant",
197
- type="messages",
198
- avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
199
- examples=[
200
- {"text": "I am going to Paris for the holidays, what should I pack?"},
201
- {"text": "I am going to Tokyo this week."},
202
- ],
203
  )
204
- with gr.Row():
205
- prompt = gr.Textbox(
206
- lines=1,
207
- show_label=False,
208
- placeholder="I am planning a trip to Miami, what should I pack?",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  )
210
- generation = prompt.submit(
211
- stream_from_agent,
212
- inputs=[prompt, chatbot, past_messages],
213
- outputs=[prompt, chatbot, past_messages],
214
- )
215
- chatbot.example_select(select_data, None, [prompt])
216
- chatbot.retry(
217
- handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages]
218
- )
219
- chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
 
 
221
 
222
- if __name__ == "__main__":
 
 
223
  demo.launch()
 
1
  from __future__ import annotations as _annotations
2
 
3
+ import asyncio
4
  import json
5
+ import sqlite3
6
+ from collections.abc import AsyncIterator
7
+ from concurrent.futures.thread import ThreadPoolExecutor
8
+ from contextlib import asynccontextmanager
9
  from dataclasses import dataclass
10
+ from datetime import datetime, timezone
11
+ from functools import partial
12
+ from pathlib import Path
13
+ from typing import Annotated, Any, Callable, Literal, TypeVar
14
+
15
+ import fastapi
16
+ import logfire
17
+ from fastapi import Depends, Request
18
+ from fastapi.responses import FileResponse, Response, StreamingResponse
19
+ from typing_extensions import LiteralString, ParamSpec, TypedDict
20
+
21
+ from pydantic_ai import Agent
22
+ from pydantic_ai.exceptions import UnexpectedModelBehavior
23
+ from pydantic_ai.messages import (
24
+ ModelMessage,
25
+ ModelMessagesTypeAdapter,
26
+ ModelRequest,
27
+ ModelResponse,
28
+ TextPart,
29
+ UserPromptPart,
30
+ )
31
 
32
+ # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
33
+ logfire.configure(send_to_logfire='if-token-present')
 
 
 
34
 
35
+ agent = Agent('openai:gpt-4o')
36
+ THIS_DIR = Path(__file__).parent
37
 
38
 
39
+ @asynccontextmanager
40
+ async def lifespan(_app: fastapi.FastAPI):
41
+ async with Database.connect() as db:
42
+ yield {'db': db}
 
43
 
44
 
45
+ app = fastapi.FastAPI(lifespan=lifespan)
46
+ logfire.instrument_fastapi(app)
 
 
 
 
47
 
48
 
49
+ @app.get('/')
50
+ async def index() -> FileResponse:
51
+ return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
 
 
52
 
53
+
54
+ @app.get('/chat_app.ts')
55
+ async def main_ts() -> FileResponse:
56
+ """Get the raw typescript code, it's compiled in the browser, forgive me."""
57
+ return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
58
+
59
+
60
+ async def get_db(request: Request) -> Database:
61
+ return request.state.db
62
+
63
+
64
+ @app.get('/chat/')
65
+ async def get_chat(database: Database = Depends(get_db)) -> Response:
66
+ msgs = await database.get_messages()
67
+ return Response(
68
+ b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
69
+ media_type='text/plain',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
+
72
+
73
+ class ChatMessage(TypedDict):
74
+ """Format of messages sent to the browser."""
75
+
76
+ role: Literal['user', 'model']
77
+ timestamp: str
78
+ content: str
79
+
80
+
81
+ def to_chat_message(m: ModelMessage) -> ChatMessage:
82
+ first_part = m.parts[0]
83
+ if isinstance(m, ModelRequest):
84
+ if isinstance(first_part, UserPromptPart):
85
+ return {
86
+ 'role': 'user',
87
+ 'timestamp': first_part.timestamp.isoformat(),
88
+ 'content': first_part.content,
89
+ }
90
+ elif isinstance(m, ModelResponse):
91
+ if isinstance(first_part, TextPart):
92
+ return {
93
+ 'role': 'model',
94
+ 'timestamp': m.timestamp.isoformat(),
95
+ 'content': first_part.content,
96
+ }
97
+ raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
98
+
99
+
100
+ @app.post('/chat/')
101
+ async def post_chat(
102
+ prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
103
+ ) -> StreamingResponse:
104
+ async def stream_messages():
105
+ """Streams new line delimited JSON `Message`s to the client."""
106
+ # stream the user prompt so that can be displayed straight away
107
+ yield (
108
+ json.dumps(
109
+ {
110
+ 'role': 'user',
111
+ 'timestamp': datetime.now(tz=timezone.utc).isoformat(),
112
+ 'content': prompt,
113
+ }
114
+ ).encode('utf-8')
115
+ + b'\n'
116
  )
117
+ # get the chat history so far to pass as context to the agent
118
+ messages = await database.get_messages()
119
+ # run the agent with the user prompt and the chat history
120
+ async with agent.run_stream(prompt, message_history=messages) as result:
121
+ async for text in result.stream(debounce_by=0.01):
122
+ # text here is a `str` and the frontend wants
123
+ # JSON encoded ModelResponse, so we create one
124
+ m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
125
+ yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
126
+
127
+ # add new messages (e.g. the user prompt and the agent response in this case) to the database
128
+ await database.add_messages(result.new_messages_json())
129
+
130
+ return StreamingResponse(stream_messages(), media_type='text/plain')
131
+
132
+
133
+ P = ParamSpec('P')
134
+ R = TypeVar('R')
135
+
136
+
137
+ @dataclass
138
+ class Database:
139
+ """Rudimentary database to store chat messages in SQLite.
140
+
141
+ The SQLite standard library package is synchronous, so we
142
+ use a thread pool executor to run queries asynchronously.
143
+ """
144
+
145
+ con: sqlite3.Connection
146
+ _loop: asyncio.AbstractEventLoop
147
+ _executor: ThreadPoolExecutor
148
+
149
+ @classmethod
150
+ @asynccontextmanager
151
+ async def connect(
152
+ cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
153
+ ) -> AsyncIterator[Database]:
154
+ with logfire.span('connect to DB'):
155
+ loop = asyncio.get_event_loop()
156
+ executor = ThreadPoolExecutor(max_workers=1)
157
+ con = await loop.run_in_executor(executor, cls._connect, file)
158
+ slf = cls(con, loop, executor)
159
+ try:
160
+ yield slf
161
+ finally:
162
+ await slf._asyncify(con.close)
163
+
164
+ @staticmethod
165
+ def _connect(file: Path) -> sqlite3.Connection:
166
+ con = sqlite3.connect(str(file))
167
+ con = logfire.instrument_sqlite3(con)
168
+ cur = con.cursor()
169
+ cur.execute(
170
+ 'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
171
+ )
172
+ con.commit()
173
+ return con
174
+
175
+ async def add_messages(self, messages: bytes):
176
+ await self._asyncify(
177
+ self._execute,
178
+ 'INSERT INTO messages (message_list) VALUES (?);',
179
+ messages,
180
+ commit=True,
181
+ )
182
+ await self._asyncify(self.con.commit)
183
+
184
+ async def get_messages(self) -> list[ModelMessage]:
185
+ c = await self._asyncify(
186
+ self._execute, 'SELECT message_list FROM messages order by id'
187
+ )
188
+ rows = await self._asyncify(c.fetchall)
189
+ messages: list[ModelMessage] = []
190
+ for row in rows:
191
+ messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
192
+ return messages
193
+
194
+ def _execute(
195
+ self, sql: LiteralString, *args: Any, commit: bool = False
196
+ ) -> sqlite3.Cursor:
197
+ cur = self.con.cursor()
198
+ cur.execute(sql, args)
199
+ if commit:
200
+ self.con.commit()
201
+ return cur
202
+
203
+ async def _asyncify(
204
+ self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
205
+ ) -> R:
206
+ return await self._loop.run_in_executor( # type: ignore
207
+ self._executor,
208
+ partial(func, **kwargs),
209
+ *args, # type: ignore
210
+ )
211
+
212
 
213
+ if __name__ == '__main__':
214
+ #import uvicorn
215
 
216
+ #uvicorn.run(
217
+ # 'pydantic_ai_examples.chat_app:app', reload=True, reload_dirs=[str(THIS_DIR)]
218
+ #)
219
  demo.launch()