Spaces:
Runtime error
Runtime error
File size: 7,565 Bytes
ad93086 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import os
import tempfile
from collections import namedtuple
from pathlib import Path
import gradio.components
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
def register_tmp_file(gradio_app, filename):
if hasattr(gradio_app, 'temp_file_sets'): # gradio 3.15
if hasattr(gr.utils, 'abspath'): # gradio 4.19
filename = gr.utils.abspath(filename)
else:
filename = os.path.abspath(filename)
gradio_app.temp_file_sets[0] = gradio_app.temp_file_sets[0] | {filename}
if hasattr(gradio_app, 'temp_dirs'): # gradio 3.9
gradio_app.temp_dirs = gradio_app.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
def check_tmp_file(gradio_app, filename):
if hasattr(gradio_app, 'temp_file_sets'):
if hasattr(gr.utils, 'abspath'): # gradio 4.19
filename = gr.utils.abspath(filename)
else:
filename = os.path.abspath(filename)
return any(filename in fileset for fileset in gradio_app.temp_file_sets)
if hasattr(gradio_app, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio_app.temp_dirs)
return False
def save_pil_to_file(pil_image, cache_dir=None, format="png"):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
register_tmp_file(shared.demo, filename_with_mtime)
return filename_with_mtime
if shared.opts.temp_dir:
dir = shared.opts.temp_dir
else:
dir = cache_dir
os.makedirs(dir, exist_ok=True)
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj.name
async def async_move_files_to_cache(data, block, postprocess=False, check_in_upload_folder=False, keep_in_cache=False):
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
Also handles the case where the file is on an external Gradio app (/proxy=...).
Runs after .postprocess() and before .preprocess().
Copied from gradio's processing_utils.py
Args:
data: The input or output data for a component. Can be a dictionary or a dataclass
block: The component whose data is being processed
postprocess: Whether its running from postprocessing
check_in_upload_folder: If True, instead of moving the file to cache, checks if the file is in already in cache (exception if not).
keep_in_cache: If True, the file will not be deleted from cache when the server is shut down.
"""
from gradio import FileData
from gradio.data_classes import GradioRootModel
from gradio.data_classes import GradioModel
from gradio_client import utils as client_utils
from gradio.utils import get_upload_folder, is_in_or_equal, is_static_file
async def _move_to_cache(d: dict):
payload = FileData(**d)
# EDITED
payload.path = payload.path.rsplit('?', 1)[0]
# If the gradio app developer is returning a URL from
# postprocess, it means the component can display a URL
# without it being served from the gradio server
# This makes it so that the URL is not downloaded and speeds up event processing
if payload.url and postprocess and client_utils.is_http_url_like(payload.url):
payload.path = payload.url
elif is_static_file(payload):
pass
elif not block.proxy_url:
# EDITED
if check_tmp_file(shared.demo, payload.path):
temp_file_path = payload.path
else:
# If the file is on a remote server, do not move it to cache.
if check_in_upload_folder and not client_utils.is_http_url_like(
payload.path
):
path = os.path.abspath(payload.path)
if not is_in_or_equal(path, get_upload_folder()):
raise ValueError(
f"File {path} is not in the upload folder and cannot be accessed."
)
if not payload.is_stream:
temp_file_path = await block.async_move_resource_to_block_cache(
payload.path
)
if temp_file_path is None:
raise ValueError("Did not determine a file path for the resource.")
payload.path = temp_file_path
if keep_in_cache:
block.keep_in_cache.add(payload.path)
url_prefix = "/stream/" if payload.is_stream else "/file="
if block.proxy_url:
proxy_url = block.proxy_url.rstrip("/")
url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
f"{url_prefix}"
):
url = payload.path
else:
url = f"{url_prefix}{payload.path}"
payload.url = url
return payload.model_dump()
if isinstance(data, (GradioRootModel, GradioModel)):
data = data.model_dump()
return await client_utils.async_traverse(
data, _move_to_cache, client_utils.is_file_obj
)
def install_ui_tempdir_override():
"""
override save to file function so that it also writes PNG info.
override gradio4's move_files_to_cache function to prevent it from writing a copy into a temporary directory.
"""
gradio.processing_utils.save_pil_to_cache = save_pil_to_file
gradio.processing_utils.async_move_files_to_cache = async_move_files_to_cache
def on_tmpdir_changed():
if shared.opts.temp_dir == "" or shared.demo is None:
return
os.makedirs(shared.opts.temp_dir, exist_ok=True)
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr():
temp_dir = shared.opts.temp_dir
if temp_dir == "" or not os.path.isdir(temp_dir):
return
for root, _, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
continue
filename = os.path.join(root, name)
os.remove(filename)
def is_gradio_temp_path(path):
"""
Check if the path is a temp dir used by gradio
"""
path = Path(path)
if shared.opts.temp_dir and path.is_relative_to(shared.opts.temp_dir):
return True
if gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR"):
if path.is_relative_to(gradio_temp_dir):
return True
if path.is_relative_to(Path(tempfile.gettempdir()) / "gradio"):
return True
return False
|