import gradio as gr
import asyncio
from threading import RLock
from pathlib import Path
import os
from typing import Union

HF_TOKEN = os.getenv("HF_TOKEN", None)
server_timeout = 600
inference_timeout = 600

lock = RLock()
loaded_models = {}


def rename_image(image_path: Union[str, None], model_name: str, save_path: Union[str, None] = None):
    import shutil
    from datetime import datetime, timezone, timedelta
    if image_path is None:
        return None
    dt_now = datetime.now(timezone(timedelta(hours=9)))  # JST timestamp for the default filename
    filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
    try:
        if not Path(image_path).exists():
            return None
        png_path = "image.png"
        if str(Path(image_path).resolve()) != str(Path(png_path).resolve()):
            shutil.copy(image_path, png_path)
        if save_path is not None:
            new_path = str(Path(png_path).resolve().rename(Path(save_path).resolve()))
        else:
            new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
        return new_path
    except Exception as e:
        print(e)
        return None


# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
from typing import Literal


def load_from_model(model_name: str, hf_token: Union[str, Literal[False], None] = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")
    headers = {} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"}
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, "
            "please provide your Hugging Face access token "
            "(https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    pipeline_tag = response.json().get("pipeline_tag")
    if pipeline_tag != "text-to-image":
        raise ModelNotFoundError(f"This model is not a text-to-image model or is unsupported: {model_name}.")
    headers["X-Wait-For-Model"] = "true"
    kwargs = {}
    if hf_token is not None:
        kwargs["token"] = hf_token
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
                                             timeout=server_timeout, **kwargs)
    inputs = gr.components.Textbox(label="Input")
    outputs = gr.components.Image(label="Output")
    fn = client.text_to_image

    def query_huggingface_inference_endpoints(*data, **kwargs):
        try:
            data = fn(*data, **kwargs)  # type: ignore
        except huggingface_hub.utils.HfHubHTTPError as e:
            print(e)
            if "429" in str(e):
                raise TooManyRequestsError() from e
        except Exception as e:
            print(e)
            raise Exception() from e
        return data

    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
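
# A minimal usage sketch for load_from_model. The repo id below is only an
# illustration (any public text-to-image model id works), and calling this
# helper performs a real Inference API request.
def _example_load_and_generate(model_id: str = "stabilityai/stable-diffusion-2-1"):
    iface = load_from_model(model_id, hf_token=HF_TOKEN)
    # iface.fn wraps InferenceClient.text_to_image and returns a PIL.Image.Image.
    return iface.fn("a watercolor painting of a fox")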
def load_model(model_name: str):
    global loaded_models
    if model_name in loaded_models.keys():
        return loaded_models[model_name]
    try:
        loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
        print(f"Loaded: {model_name}")
    except Exception as e:
        if model_name in loaded_models.keys():
            del loaded_models[model_name]
        print(f"Failed to load: {model_name}")
        print(e)
        return None
    return loaded_models[model_name]


def load_models(models: list):
    for model in models:
        load_model(model)


def warm_model(model_name: str):
    model = load_model(model_name)
    if model:
        try:
            print(f"Warming model: {model_name}")
            infer_body(model, model_name, " ")
        except Exception as e:
            print(e)


def warm_models(models: list[str]):
    # Fire off the warm-up requests concurrently without blocking the caller.
    from threading import Thread
    for model in models:
        Thread(target=warm_model, args=(model,), daemon=True).start()


# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
def infer_body(client: Union[gr.Interface, object], model_str: str, prompt: str, neg_prompt: str = "",
               height: int = 0, width: int = 0, steps: int = 0, cfg: int = 0, seed: int = -1):
    png_path = "image.png"
    # Only forward parameters the caller actually set; zero means "use the model default".
    kwargs = {}
    if height > 0:
        kwargs["height"] = height
    if width > 0:
        kwargs["width"] = width
    if steps > 0:
        kwargs["num_inference_steps"] = steps
    if cfg > 0:
        kwargs["guidance_scale"] = cfg
    kwargs["seed"] = randomize_seed() if seed == -1 else seed
    try:
        if isinstance(client, gr.Interface):
            image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
        else:
            return None
        if isinstance(image, tuple):
            return None
        return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
    except Exception as e:
        print(e)
        raise Exception(e) from e


async def infer(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
                steps: int = 0, cfg: int = 0, seed: int = -1, save_path: Union[str, None] = None,
                timeout: float = inference_timeout):
    model = load_model(model_name)
    if not model:
        return None
    # Run the blocking inference call in a worker thread and enforce a timeout.
    task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
                                                 height, width, steps, cfg, seed))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done():
            task.cancel()
        raise Exception(f"Task timed out: {model_name}") from e
    except Exception as e:
        print(e)
        if not task.done():
            task.cancel()
        raise Exception(e) from e
    if task.done() and result is not None:
        with lock:
            image = rename_image(result, model_name, save_path)
        return image
    return None
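
# A minimal sketch of driving the async entry point directly; the model id and
# save path are illustrative assumptions, not values from this module.
def _example_infer_once():
    return asyncio.run(infer("stabilityai/stable-diffusion-2-1", "a watercolor fox",
                             steps=28, cfg=7, save_path="fox.png"))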
def save_image(image, savefile, modelname, prompt, nprompt, height=0, width=0, steps=0, cfg=0, seed=-1):
    from PIL import PngImagePlugin
    import json
    try:
        # Embed the generation parameters as a JSON tEXt chunk in the PNG.
        metadata = {"prompt": prompt, "negative_prompt": nprompt,
                    "Model": {"Model": modelname.split("/")[-1]}}
        if steps > 0:
            metadata["num_inference_steps"] = steps
        if cfg > 0:
            metadata["guidance_scale"] = cfg
        if seed != -1:
            metadata["seed"] = seed
        if width > 0 and height > 0:
            metadata["resolution"] = f"{width} x {height}"
        metadata_str = json.dumps(metadata)
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", metadata_str)
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        raise Exception("Failed to save image file") from e


def randomize_seed():
    from random import seed, randint
    MAX_SEED = 2**32 - 1
    seed()
    return randint(0, MAX_SEED)


def gen_image(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
              steps: int = 0, cfg: int = 0, seed: int = -1):
    if model_name in ["NA", ""]:
        return gr.update()
    # Always run the coroutine on a fresh event loop; run_until_complete cannot
    # be called on a loop that is already running.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
                                               steps, cfg, seed, None, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}, Error: {e}")
        raise gr.Error(f"Task aborted: {model_name}, Error: {e}") from e
    finally:
        loop.close()
    return result


def generate_image_hf(model_name: str, prompt: str, negative_prompt: str, use_defaults: bool,
                      resolution: str, guidance_scale: float, num_inference_steps: int,
                      seed: int, randomize_seed: bool, progress=gr.Progress()):
    if randomize_seed:
        seed = -1  # infer_body draws the actual random seed when seed == -1
    if use_defaults:
        prompt = f"{prompt}, best quality, amazing quality, very aesthetic"
        negative_prompt = f"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], {negative_prompt}"
    width, height = map(int, resolution.split('x'))
    image = gen_image(model_name, prompt, negative_prompt, height, width,
                      num_inference_steps, guidance_scale, seed)
    metadata_text = (f"{prompt}\nNegative prompt: {negative_prompt}\n"
                     f"Steps: {num_inference_steps}, Sampler: Euler a, Size: {width}x{height}, "
                     f"Seed: {seed}, CFG scale: {guidance_scale}")
    return image, seed, metadata_text
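
# A minimal sketch of wiring generate_image_hf into a Gradio UI. The model id,
# resolution choices, and default values below are illustrative assumptions,
# not part of this module.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model", value="stabilityai/stable-diffusion-2-1")
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative prompt")
        use_defaults = gr.Checkbox(label="Use default quality tags", value=True)
        resolution = gr.Dropdown(label="Resolution", choices=["1024x1024", "832x1216", "1216x832"],
                                 value="1024x1024")
        guidance_scale = gr.Slider(label="CFG scale", minimum=1.0, maximum=20.0, value=7.0)
        num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        seed = gr.Number(label="Seed", value=-1, precision=0)
        randomize = gr.Checkbox(label="Randomize seed", value=True)
        image = gr.Image(label="Output")
        used_seed = gr.Number(label="Used seed", precision=0)
        metadata = gr.Textbox(label="Metadata")
        gr.Button("Generate").click(
            generate_image_hf,
            inputs=[model_name, prompt, negative_prompt, use_defaults, resolution,
                    guidance_scale, num_inference_steps, seed, randomize],
            outputs=[image, used_seed, metadata],
        )
    demo.launch()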