|
import os |
|
import tempfile |
|
from collections import namedtuple |
|
from pathlib import Path |
|
|
|
import gradio.components |
|
|
|
from PIL import PngImagePlugin |
|
|
|
from modules import shared |
|
|
|
|
|
Savedfile = namedtuple("Savedfile", ["name"]) |
|
|
|
|
|
def register_tmp_file(gradio, filename): |
|
if hasattr(gradio, 'temp_file_sets'): |
|
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} |
|
|
|
if hasattr(gradio, 'temp_dirs'): |
|
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} |
|
|
|
|
|
def check_tmp_file(gradio, filename): |
|
if hasattr(gradio, 'temp_file_sets'): |
|
return any(filename in fileset for fileset in gradio.temp_file_sets) |
|
|
|
if hasattr(gradio, 'temp_dirs'): |
|
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) |
|
|
|
return False |
|
|
|
|
|
def save_pil_to_file(self, pil_image, dir=None, format="png"): |
|
already_saved_as = getattr(pil_image, 'already_saved_as', None) |
|
if already_saved_as and os.path.isfile(already_saved_as): |
|
register_tmp_file(shared.demo, already_saved_as) |
|
filename = already_saved_as |
|
|
|
if not shared.opts.save_images_add_number: |
|
filename += f'?{os.path.getmtime(already_saved_as)}' |
|
|
|
return filename |
|
|
|
if shared.opts.temp_dir != "": |
|
dir = shared.opts.temp_dir |
|
else: |
|
os.makedirs(dir, exist_ok=True) |
|
|
|
use_metadata = False |
|
metadata = PngImagePlugin.PngInfo() |
|
for key, value in pil_image.info.items(): |
|
if isinstance(key, str) and isinstance(value, str): |
|
metadata.add_text(key, value) |
|
use_metadata = True |
|
|
|
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) |
|
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) |
|
return file_obj.name |
|
|
|
|
|
def install_ui_tempdir_override(): |
|
"""override save to file function so that it also writes PNG info""" |
|
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file |
|
|
|
|
|
def on_tmpdir_changed(): |
|
if shared.opts.temp_dir == "" or shared.demo is None: |
|
return |
|
|
|
os.makedirs(shared.opts.temp_dir, exist_ok=True) |
|
|
|
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) |
|
|
|
|
|
def cleanup_tmpdr(): |
|
temp_dir = shared.opts.temp_dir |
|
if temp_dir == "" or not os.path.isdir(temp_dir): |
|
return |
|
|
|
for root, _, files in os.walk(temp_dir, topdown=False): |
|
for name in files: |
|
_, extension = os.path.splitext(name) |
|
if extension != ".png": |
|
continue |
|
|
|
filename = os.path.join(root, name) |
|
os.remove(filename) |
|
|