from __future__ import annotations

import asyncio
import functools
import hashlib
import hmac
import json
import os
import re
import shutil
import sys
from collections import deque
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass as python_dataclass
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    AsyncGenerator,
    BinaryIO,
    Callable,
    List,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import urlparse

import anyio
import fastapi
import gradio_client.utils as client_utils
import httpx
import multipart
from gradio_client.documentation import document
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
from gradio.state_holder import SessionState

if TYPE_CHECKING:
    from gradio.blocks import Blocks
    from gradio.routes import App

class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
    """

    def __init__(self, dict_):
        self.__dict__.update(dict_)
        for key, value in dict_.items():
            if isinstance(value, (dict, list)):
                value = Obj(value)
            setattr(self, key, value)

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def get(self, item, default=None):
        if item in self:
            return self.__dict__[item]
        return default

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
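
# Illustrative sketch (not executed by the module): `Obj` wraps nested dicts so
# keys are reachable as attributes while dict-style access keeps working.
#
#     o = Obj({"headers": {"x-api-key": "abc"}, "count": 2})
#     o.count                   # 2
#     o.headers["x-api-key"]    # "abc" (via __getitem__)
#     "x-api-key" in o          # True: __contains__ also searches nested Obj values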

@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters, and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, `session_hash`, and
    `path_params`. If auth is enabled, the `username` attribute can be used to get the
    logged-in user.
    Example:
        import gradio as gr
        def echo(text, request: gr.Request):
            if request:
                print("Request headers dictionary:", request.headers)
                print("IP address:", request.client.host)
                print("Query parameters:", dict(request.query_params))
                print("Session hash:", request.session_hash)
            return text
        io = gr.Interface(echo, "textbox", "textbox").launch()
    Demos: request_ip_headers
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        session_hash: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for queueing).
        Parameters:
            request: A fastapi.Request
            username: The username of the logged-in user (if auth is enabled)
            session_hash: The session hash of the current session. It is unique for each page load.
        """
        self.request = request
        self.username = username
        self.session_hash = session_hash
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        if self.request:
            return self.dict_to_obj(getattr(self.request, name))
        else:
            try:
                obj = self.kwargs[name]
            except KeyError as ke:
                raise AttributeError(
                    f"'Request' object has no attribute '{name}'"
                ) from ke
            return self.dict_to_obj(obj)
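
# Illustrative sketch: when constructed without a fastapi.Request (the queueing
# path), attribute access falls through to the kwargs, with dicts converted to
# attribute-accessible Obj instances:
#
#     req = Request(username="admin", query_params={"page": "1"})
#     req.username            # "admin" (set directly in __init__)
#     req.query_params.page   # "1" (kwargs dict converted via dict_to_obj)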

class FnIndexInferError(Exception):
    pass


def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    if body.fn_index is None:
        for i, fn in enumerate(app.get_blocks().fns):
            if fn.api_name == api_name:
                return i
        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    else:
        return body.fn_index

def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    # If this fn_index cancels jobs, then the only input we need is the
    # current session hash
    if app.get_blocks().fns[fn_index_inferred].cancels:
        body.data = [body.session_hash]
    if body.request:
        if body.batched:
            gr_request = [
                Request(username=username, request=request)
                for request in body.request
            ]
        else:
            gr_request = Request(
                username=username, request=body.request, session_hash=body.session_hash
            )
    else:
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        gr_request = Request(
            username=username, request=request, session_hash=body.session_hash
        )
    return gr_request

def restore_session_state(app: App, body: PredictBody):
    event_id = body.event_id
    session_hash = getattr(body, "session_hash", None)
    if session_hash is not None:
        session_state = app.state_holder[session_hash]
        # The should_reset set keeps track of the fn_indices
        # that have been cancelled. When a job is cancelled,
        # the /reset route will mark the jobs as having been reset.
        # That way if the cancel job finishes BEFORE the job being cancelled,
        # the job being cancelled will not overwrite the state of the iterator.
        if event_id is None:
            iterator = None
        elif event_id in app.iterators_to_reset:
            iterator = None
            app.iterators_to_reset.remove(event_id)
        else:
            iterator = app.iterators.get(event_id)
    else:
        session_state = SessionState(app.get_blocks())
        iterator = None
    return session_state, iterator


def prepare_event_data(
    blocks: Blocks,
    body: PredictBody,
) -> EventData:
    target = body.trigger_id
    event_data = EventData(
        blocks.blocks.get(target) if target else None,
        body.event_data,
    )
    return event_data

async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred: int,
    root_path: str,
):
    session_state, iterator = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().fns[fn_index_inferred]
    event_data = prepare_event_data(app.get_blocks(), body)
    event_id = body.event_id

    session_hash = getattr(body, "session_hash", None)
    inputs = body.data

    batch_in_single_out = not body.batched and dependency.batch
    if batch_in_single_out:
        inputs = [inputs]

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterator=iterator,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
                in_event_listener=True,
                simple_format=body.simple_format,
                root_path=root_path,
            )
        iterator = output.pop("iterator", None)
        if event_id is not None:
            app.iterators[event_id] = iterator  # type: ignore
        if isinstance(output, Error):
            raise output
    except BaseException:
        iterator = app.iterators.get(event_id) if event_id is not None else None
        if iterator is not None:  # close off any streams that are still open
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]

    return output
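
# Illustrative note: if the endpoint is batched but this particular call is a
# single (non-batched) request, the inputs are wrapped into a batch of one on
# the way in and unwrapped on the way out:
#
#     inputs = ["hello"]            # single request to a batch=True fn
#     # -> process_api receives [["hello"]]
#     # -> output["data"] == [["HELLO"]] is unwrapped back to ["HELLO"]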

def get_root_url(
    request: fastapi.Request, route_path: str, root_path: str | None
) -> str:
    """
    Gets the root url of the Gradio app (i.e. the public url of the app) without a trailing slash.
    This is how the root_url is resolved:
    1. If a user provides a `root_path` manually that is a full URL, it is returned directly.
    2. If the request has an x-forwarded-host header (e.g. because it is behind a proxy), the root url is
    constructed from the x-forwarded-host header. In this case, `route_path` is not used to construct the root url.
    3. Otherwise, the root url is constructed from the request url. The query parameters and `route_path` are stripped off.
    And if a relative `root_path` is provided, and it is not already the subpath of the URL, it is appended to the root url.
    In cases (2) and (3), we also check whether the x-forwarded-proto header is present, and if so, convert the root url to https.
    And if there are multiple hosts in the x-forwarded-host or multiple protocols in the x-forwarded-proto, the first one is used.
    """

    def get_first_header_value(header_name: str):
        header_value = request.headers.get(header_name)
        if header_value:
            return header_value.split(",")[0].strip()
        return None

    if root_path and client_utils.is_http_url_like(root_path):
        return root_path.rstrip("/")

    x_forwarded_host = get_first_header_value("x-forwarded-host")
    root_url = f"http://{x_forwarded_host}" if x_forwarded_host else str(request.url)
    root_url = httpx.URL(root_url)
    root_url = root_url.copy_with(query=None)
    root_url = str(root_url).rstrip("/")

    if get_first_header_value("x-forwarded-proto") == "https":
        root_url = root_url.replace("http://", "https://")

    route_path = route_path.rstrip("/")
    if len(route_path) > 0 and not x_forwarded_host:
        root_url = root_url[: -len(route_path)]
    root_url = root_url.rstrip("/")

    root_url = httpx.URL(root_url)
    if root_path and root_url.path != root_path:
        root_url = root_url.copy_with(path=root_path)

    return str(root_url).rstrip("/")
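
# Illustrative resolutions (example values assumed for the sketch):
#
#     request url "http://gradio.app/api/predict?x=1", route_path="/api/predict",
#     root_path=None, no proxy headers
#         -> "http://gradio.app"
#
#     same request behind a proxy sending
#     x-forwarded-host: "example.com" and x-forwarded-proto: "https"
#         -> "https://example.com"
#
#     root_path="https://example.com/sub" (a full URL)
#         -> "https://example.com/sub" (returned directly, trailing slash stripped)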

def _user_safe_decode(src: bytes, codec: str) -> str:
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")
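
# Illustrative behavior: bytes that are invalid in the requested codec fall back
# to latin-1, which can decode any byte sequence:
#
#     _user_safe_decode(b"caf\xc3\xa9", "utf-8")  # "café"
#     _user_safe_decode(b"\xff", "utf-8")         # "ÿ" (latin-1 fallback)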

class GradioUploadFile(UploadFile):
    """UploadFile with a sha attribute."""

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        super().__init__(file, size=size, filename=filename, headers=headers)
        self.sha = hashlib.sha1()
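
# Illustrative use: the multipart parser below feeds each streamed chunk into
# `sha`, so a content hash is available as soon as the upload finishes:
#
#     upload_file.sha.update(chunk)         # called per chunk during parsing
#     digest = upload_file.sha.hexdigest()  # hash of the full upload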

@python_dataclass
class FileUploadProgressUnit:
    filename: str
    chunk_size: int


@python_dataclass
class FileUploadProgressTracker:
    deque: deque[FileUploadProgressUnit]
    is_done: bool


class FileUploadProgressNotTrackedError(Exception):
    pass


class FileUploadProgressNotQueuedError(Exception):
    pass

class FileUploadProgress:
    def __init__(self) -> None:
        self._statuses: dict[str, FileUploadProgressTracker] = {}

    def track(self, upload_id: str):
        if upload_id not in self._statuses:
            self._statuses[upload_id] = FileUploadProgressTracker(deque(), False)

    def append(self, upload_id: str, filename: str, message_bytes: bytes):
        if upload_id not in self._statuses:
            self.track(upload_id)
        queue = self._statuses[upload_id].deque
        if len(queue) == 0:
            queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
        else:
            last_unit = queue.popleft()
            if last_unit.filename != filename:
                queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
            else:
                queue.append(
                    FileUploadProgressUnit(
                        filename,
                        last_unit.chunk_size + len(message_bytes),
                    )
                )

    def set_done(self, upload_id: str):
        if upload_id not in self._statuses:
            self.track(upload_id)
        self._statuses[upload_id].is_done = True

    def is_done(self, upload_id: str):
        if upload_id not in self._statuses:
            raise FileUploadProgressNotTrackedError()
        return self._statuses[upload_id].is_done

    def stop_tracking(self, upload_id: str):
        if upload_id in self._statuses:
            del self._statuses[upload_id]

    def pop(self, upload_id: str) -> FileUploadProgressUnit:
        if upload_id not in self._statuses:
            raise FileUploadProgressNotTrackedError()
        try:
            return self._statuses[upload_id].deque.pop()
        except IndexError as e:
            raise FileUploadProgressNotQueuedError() from e
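
# Illustrative lifecycle sketch (the upload_id "abc" is assumed for the example):
#
#     progress = FileUploadProgress()
#     progress.track("abc")
#     progress.append("abc", "photo.png", b"...")  # accumulates chunk_size per file
#     unit = progress.pop("abc")                   # latest FileUploadProgressUnit
#     progress.set_done("abc")
#     progress.is_done("abc")                      # True
#     progress.stop_tracking("abc")                # free the tracker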

class GradioMultiPartParser:
    """Vendored from starlette.MultipartParser.

    Thanks starlette!

    Made the following modifications:
    - Use GradioUploadFile instead of UploadFile
    - Use NamedTemporaryFile instead of SpooledTemporaryFile
    - Compute hash of data as the request is streamed
    """

    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: Union[int, float] = 1000,
        max_fields: Union[int, float] = 1000,
        upload_id: str | None = None,
        upload_progress: FileUploadProgress | None = None,
        max_file_size: int | float,
    ) -> None:
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: List[Tuple[str, Union[str, UploadFile]]] = []
        self.upload_id = upload_id
        self.upload_progress = upload_progress
        self._current_files = 0
        self._current_fields = 0
        self.max_file_size = max_file_size
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        self._file_parts_to_write: List[Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: List[MultipartPart] = []
        self._files_to_close_on_error: List[_TemporaryFileWrapper] = []
    def on_part_begin(self) -> None:
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        if self.upload_progress is not None:
            self.upload_progress.append(
                self.upload_id,  # type: ignore
                self._current_part.file.filename,  # type: ignore
                message_bytes,
            )
        if self._current_part.file is None:
            self._current_part.data += message_bytes
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))
    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        _, options = parse_options_header(self._current_part.content_disposition or b"")
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], str(self._charset)
            )
        except KeyError as e:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            ) from e
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], str(self._charset))
            tempfile = NamedTemporaryFile(delete=False)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = GradioUploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        pass
    async def parse(self) -> FormData:
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError as e:
            raise MultiPartException("Missing boundary in multipart.") from e

        # Callbacks dictionary.
        callbacks: multipart.multipart.MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data. It needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*;
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers  # noqa: S101
                    await part.file.write(data)
                    part.file.sha.update(data)  # type: ignore
                    if os.stat(part.file.file.name).st_size > self.max_file_size:
                        if self.upload_progress is not None:
                            self.upload_progress.set_done(self.upload_id)  # type: ignore
                        raise MultiPartException(
                            f"File size exceeded maximum allowed size of {self.max_file_size} bytes."
                        )
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers  # noqa: S101
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except MultiPartException as exc:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
                Path(file.name).unlink()
            raise exc

        parser.finalize()
        if self.upload_progress is not None:
            self.upload_progress.set_done(self.upload_id)  # type: ignore
        return FormData(self.items)
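
# Minimal usage sketch (assumes `request` is a starlette/FastAPI Request; the
# 10 MB limit and upload id are arbitrary example values):
#
#     parser = GradioMultiPartParser(
#         request.headers,
#         request.stream(),
#         max_file_size=10 * 1024 * 1024,
#         upload_id="some-upload-id",
#         upload_progress=FileUploadProgress(),
#     )
#     form = await parser.parse()  # FormData of (name, str | GradioUploadFile) items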

def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
    for file, dest in zip(files, destinations):
        shutil.move(file, dest)


def update_root_in_config(config: dict, root: str) -> dict:
    """
    Updates the root "key" in the config dictionary to the new root url. If the
    root url has changed, all of the urls in the config that correspond to component
    file urls are updated to use the new root url.
    """
    previous_root = config.get("root")
    if previous_root is None or previous_root != root:
        config["root"] = root
        config = processing_utils.add_root_url(config, root, previous_root)
    return config
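
# Illustrative sketch (config keys assumed for the example): a changed root
# rewrites the "root" key, and add_root_url rewrites component file urls to
# match:
#
#     config = {"root": "http://old-host"}
#     config = update_root_in_config(config, "https://new-host")
#     config["root"]  # "https://new-host"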

def compare_passwords_securely(input_password: str, correct_password: str) -> bool:
    return hmac.compare_digest(input_password.encode(), correct_password.encode())
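
# hmac.compare_digest is a timing-attack-resistant comparison, so password
# checks do not leak how many leading characters matched:
#
#     compare_passwords_securely("secret", "secret")  # True
#     compare_passwords_securely("secret", "Secret")  # False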

def starts_with_protocol(string: str) -> bool:
    """This regex matches strings that start with a scheme (one or more characters not including colon, slash, or space)
    followed by ://, or start with just //, \\/, /\\, or \\ as they are interpreted as SMB paths on Windows.
    """
    pattern = r"^(?:[a-zA-Z][a-zA-Z0-9+\-.]*://|//|\\\\|\\/|/\\)"
    return re.match(pattern, string) is not None
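
# Illustrative matches:
#
#     starts_with_protocol("https://example.com")  # True  (scheme://)
#     starts_with_protocol("//share/folder")       # True  (protocol-relative)
#     starts_with_protocol("\\\\server\\share")    # True  (Windows SMB path)
#     starts_with_protocol("relative/path.txt")    # False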

def get_hostname(url: str) -> str:
    """
    Returns the hostname of a given url, or an empty string if the url cannot be parsed.
    Examples:
        get_hostname("https://www.gradio.app") -> "www.gradio.app"
        get_hostname("localhost:7860") -> "localhost"
        get_hostname("127.0.0.1") -> "127.0.0.1"
    """
    if not url:
        return ""
    if "://" not in url:
        url = "http://" + url
    try:
        return urlparse(url).hostname or ""
    except Exception:
        return ""

class CustomCORSMiddleware:
    # This is a modified version of the Starlette CORSMiddleware that restricts
    # the allowed origins when the host is localhost.
    # Adapted from: https://github.com/encode/starlette/blob/89fae174a1ea10f59ae248fe030d9b7e83d0b8a0/starlette/middleware/cors.py

    def __init__(
        self,
        app: ASGIApp,
    ) -> None:
        self.app = app
        self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
        self.preflight_headers = {
            "Access-Control-Allow-Methods": ", ".join(self.all_methods),
            "Access-Control-Max-Age": str(600),
        }
        self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
        # Any of these hosts suggests that the Gradio app is running locally.
        # Note: "null" is a special case that happens if a Gradio app is running
        # as an embedded web component in a local static webpage.
        self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        headers = Headers(scope=scope)
        origin = headers.get("origin")
        if origin is None:
            await self.app(scope, receive, send)
            return
        if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
            response = self.preflight_response(request_headers=headers)
            await response(scope, receive, send)
            return
        await self.simple_response(scope, receive, send, request_headers=headers)

    def preflight_response(self, request_headers: Headers) -> Response:
        headers = dict(self.preflight_headers)
        origin = request_headers["Origin"]
        if self.is_valid_origin(request_headers):
            headers["Access-Control-Allow-Origin"] = origin
        requested_headers = request_headers.get("access-control-request-headers")
        if requested_headers is not None:
            headers["Access-Control-Allow-Headers"] = requested_headers
        return PlainTextResponse("OK", status_code=200, headers=headers)

    async def simple_response(
        self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
    ) -> None:
        send = functools.partial(self._send, send=send, request_headers=request_headers)
        await self.app(scope, receive, send)

    async def _send(
        self, message: Message, send: Send, request_headers: Headers
    ) -> None:
        if message["type"] != "http.response.start":
            await send(message)
            return
        message.setdefault("headers", [])
        headers = MutableHeaders(scope=message)
        headers.update(self.simple_headers)
        has_cookie = "cookie" in request_headers
        origin = request_headers["Origin"]
        if has_cookie or self.is_valid_origin(request_headers):
            self.allow_explicit_origin(headers, origin)
        await send(message)

    def is_valid_origin(self, request_headers: Headers) -> bool:
        origin = request_headers["Origin"]
        host = request_headers["Host"]
        host_name = get_hostname(host)
        origin_name = get_hostname(origin)
        return (
            host_name not in self.localhost_aliases
            or origin_name in self.localhost_aliases
        )
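
    # Illustrative decision table for is_valid_origin (Host header vs. Origin header):
    #
    #     host "localhost:7860", origin "http://localhost:3000" -> True  (both local)
    #     host "localhost:7860", origin "https://evil.com"      -> False (local host, remote origin)
    #     host "myapp.com",      origin "https://anything.com"  -> True  (non-local host)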
    @staticmethod
    def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
        headers["Access-Control-Allow-Origin"] = origin
        headers.add_vary_header("Origin")

def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
    """Delete files that are older than age. If age is None, delete all files."""
    dont_delete = set()
    for component in blocks.blocks.values():
        dont_delete.update(getattr(component, "keep_in_cache", set()))
    for temp_set in blocks.temp_file_sets:
        # We use a copy of the set to avoid modifying the set while iterating over it,
        # otherwise we would get an exception: Set changed size during iteration
        to_remove = set()
        for file in temp_set:
            if file in dont_delete:
                continue
            try:
                file_path = Path(file)
                modified_time = datetime.fromtimestamp(file_path.lstat().st_ctime)
                if age is None or (datetime.now() - modified_time).seconds > age:
                    os.remove(file)
                    to_remove.add(file)
            except FileNotFoundError:
                continue
        temp_set -= to_remove

async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None:
    """Startup task to delete files created by the app based on time since last modification."""
    while True:
        await asyncio.sleep(frequency)
        await anyio.to_thread.run_sync(
            delete_files_created_by_app, app.get_blocks(), age
        )

@asynccontextmanager
async def _lifespan_handler(
    app: App, frequency: int = 1, age: int = 1
) -> AsyncGenerator:
    """A context manager that triggers the startup and shutdown events of the app."""
    asyncio.create_task(delete_files_on_schedule(app, frequency, age))
    yield
    delete_files_created_by_app(app.get_blocks(), age=None)


async def _delete_state(app: App):
    """Delete all expired state every second."""
    while True:
        app.state_holder.delete_all_expired_state()
        await asyncio.sleep(1)


@asynccontextmanager
async def _delete_state_handler(app: App):
    """When the server launches, regularly delete expired state."""
    # The stop event needs to get the current event loop for python 3.8,
    # but the loop parameter is deprecated for 3.8+
    if sys.version_info < (3, 10):
        loop = asyncio.get_running_loop()
        app.stop_event = asyncio.Event(loop=loop)
    asyncio.create_task(_delete_state(app))
    yield

def create_lifespan_handler(
    user_lifespan: Callable[[App], AsyncContextManager] | None,
    frequency: int | None = 1,
    age: int | None = 1,
) -> Callable[[App], AsyncContextManager]:
    """Return a context manager that applies _lifespan_handler and user_lifespan if it exists."""

    @asynccontextmanager
    async def _handler(app: App):
        async with AsyncExitStack() as stack:
            await stack.enter_async_context(_delete_state_handler(app))
            if frequency and age:
                await stack.enter_async_context(_lifespan_handler(app, frequency, age))
            if user_lifespan is not None:
                await stack.enter_async_context(user_lifespan(app))
            yield

    return _handler
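
# Usage sketch (hypothetical wiring; in practice gradio passes the returned
# handler to the FastAPI app's `lifespan` parameter internally):
#
#     lifespan = create_lifespan_handler(None, frequency=60, age=3600)
#     app = fastapi.FastAPI(lifespan=lifespan)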